refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -32,8 +32,10 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@@ -46,9 +48,9 @@ type State struct {
|
||||
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
|
||||
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
|
||||
|
||||
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
|
||||
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
|
||||
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
|
||||
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
|
||||
|
||||
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
|
||||
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
|
||||
@@ -71,8 +73,8 @@ func init() {
|
||||
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
|
||||
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
|
||||
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
|
||||
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
|
||||
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
|
||||
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
|
||||
@@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
|
||||
return s.GroupChoices[nodeKey]
|
||||
}
|
||||
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
choice, ok := choices[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if choice == -1 { // this group picks none of the elements
|
||||
return nodes.FieldNotStream, nil
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
sInfos, ok := s.SourceInfos[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
groupInfo, ok := sInfos[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if groupInfo.SubSources == nil {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
}
|
||||
|
||||
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
}
|
||||
|
||||
if subInfo.FieldType != nodes.FieldMaybeStream {
|
||||
if subInfo.FieldType != schema2.FieldMaybeStream {
|
||||
return subInfo.FieldType, nil
|
||||
}
|
||||
|
||||
@@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
|
||||
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
|
||||
}
|
||||
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
|
||||
result := make(map[string]nodes.FieldStreamType)
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
|
||||
result := make(map[string]schema2.FieldStreamType)
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return result, nil
|
||||
@@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
|
||||
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
|
||||
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
|
||||
ExecutedNodes: make(map[vo.NodeKey]bool),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
|
||||
GroupChoices: make(map[vo.NodeKey]map[string]int),
|
||||
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
|
||||
LLMToResumeData: make(map[vo.NodeKey]string),
|
||||
@@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePreHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
|
||||
@@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
}
|
||||
return in, nil
|
||||
})
|
||||
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
|
||||
} else if entity.NodeMetaByNodeType(s.Type).IsComposite {
|
||||
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
|
||||
state.Inputs[s.Key] = in
|
||||
@@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
}
|
||||
|
||||
if len(handlers) > 0 || !stream {
|
||||
handlerForVars := s.statePreHandlerForVars()
|
||||
handlerForVars := statePreHandlerForVars(s)
|
||||
if handlerForVars != nil {
|
||||
handlers = append(handlers, handlerForVars)
|
||||
}
|
||||
@@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
|
||||
if s.Type == entity.NodeTypeVariableAggregator {
|
||||
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
state.SourceInfos[s.Key] = s.FullSources
|
||||
return in, nil
|
||||
})
|
||||
}
|
||||
|
||||
handlerForVars := s.streamStatePreHandlerForVars()
|
||||
handlerForVars := streamStatePreHandlerForVars(s)
|
||||
if handlerForVars != nil {
|
||||
streamHandlers = append(streamHandlers, handlerForVars)
|
||||
}
|
||||
@@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
|
||||
func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
@@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
@@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// if it does not have source info, do not add this pre handler
|
||||
if s.Configs == nil {
|
||||
return nil
|
||||
@@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
|
||||
return nil
|
||||
case entity.NodeTypeExit:
|
||||
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
terminatePlan := s.Configs.(*exit.Config).TerminatePlan
|
||||
if terminatePlan != vo.ReturnVariables {
|
||||
return nil
|
||||
}
|
||||
@@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
|
||||
}
|
||||
|
||||
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
sourceInfo := s.FullSources
|
||||
if len(sourceInfo) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
|
||||
var (
|
||||
anyStream bool
|
||||
checker func(source *nodes.SourceInfo) bool
|
||||
checker func(source *schema2.SourceInfo) bool
|
||||
)
|
||||
checker = func(source *nodes.SourceInfo) bool {
|
||||
if source.FieldType != nodes.FieldNotStream {
|
||||
checker = func(source *schema2.SourceInfo) bool {
|
||||
if source.FieldType != schema2.FieldNotStream {
|
||||
return true
|
||||
}
|
||||
for _, subSource := range source.SubSources {
|
||||
@@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
resolved := map[string]resolvedStreamSource{}
|
||||
|
||||
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
if source.IsIntermediate {
|
||||
result = &resolvedStreamSource{
|
||||
intermediate: true,
|
||||
@@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
}
|
||||
|
||||
streamType := source.FieldType
|
||||
if streamType == nodes.FieldMaybeStream {
|
||||
if streamType == schema2.FieldMaybeStream {
|
||||
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if streamType == nodes.FieldNotStream {
|
||||
if streamType == schema2.FieldNotStream {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePostHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
|
||||
@@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.streamStatePostHandlerForVars()
|
||||
forVars := streamStatePostHandlerForVars(s)
|
||||
if forVars != nil {
|
||||
streamHandlers = append(streamHandlers, forVars)
|
||||
}
|
||||
@@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.statePostHandlerForVars()
|
||||
forVars := statePostHandlerForVars(s)
|
||||
if forVars != nil {
|
||||
handlers = append(handlers, forVars)
|
||||
}
|
||||
@@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return compose.WithStatePostHandler(handler)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
|
||||
func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variable,
|
||||
// use the state's variableHandler to get the variables and set them to the output
|
||||
var vars []*vo.FieldInfo
|
||||
@@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
|
||||
var vars []*vo.FieldInfo
|
||||
for _, output := range s.OutputSources {
|
||||
|
||||
Reference in New Issue
Block a user