From 72656e4fd16b3aa999ad6f3c99f74fb6967c883b Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Mon, 4 Aug 2025 20:19:36 +0800 Subject: [PATCH] fix: app var store can be resumed (#546) --- .../plugin/pluginmock/plugin_mock.go | 17 +++- .../internal/canvas/adaptor/canvas_test.go | 7 +- .../domain/workflow/internal/compose/state.go | 98 ++++++++++--------- .../workflow/internal/execute/context.go | 61 ++++++++++-- .../nodes/variableassigner/variable_assign.go | 52 ++-------- 5 files changed, 137 insertions(+), 98 deletions(-) diff --git a/backend/domain/workflow/crossdomain/plugin/pluginmock/plugin_mock.go b/backend/domain/workflow/crossdomain/plugin/pluginmock/plugin_mock.go index f08e5bb8..3a5e9522 100644 --- a/backend/domain/workflow/crossdomain/plugin/pluginmock/plugin_mock.go +++ b/backend/domain/workflow/crossdomain/plugin/pluginmock/plugin_mock.go @@ -29,8 +29,9 @@ import ( context "context" reflect "reflect" - plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" schema "github.com/cloudwego/eino/schema" + plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" + vo "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" gomock "go.uber.org/mock/gomock" ) @@ -103,6 +104,20 @@ func (mr *MockServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockService)(nil).GetPluginToolsInfo), ctx, req) } +// UnwrapArrayItemFieldsInVariable mocks base method. +func (m *MockService) UnwrapArrayItemFieldsInVariable(v *vo.Variable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnwrapArrayItemFieldsInVariable", v) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnwrapArrayItemFieldsInVariable indicates an expected call of UnwrapArrayItemFieldsInVariable. +func (mr *MockServiceMockRecorder) UnwrapArrayItemFieldsInVariable(v any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapArrayItemFieldsInVariable", reflect.TypeOf((*MockService)(nil).UnwrapArrayItemFieldsInVariable), v) +} + // MockInvokableTool is a mock of InvokableTool interface. type MockInvokableTool struct { ctrl *gomock.Controller diff --git a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go index 4f0e2e78..bf8311f8 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go +++ b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go @@ -655,11 +655,10 @@ func TestKnowledgeNodes(t *testing.T) { mockKnowledgeOperator.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(rResponse, nil) mockGlobalAppVarStore := mockvar.NewMockStore(ctrl) mockGlobalAppVarStore.EXPECT().Get(gomock.Any(), gomock.Any()).Return("v1", nil).AnyTimes() - mockGlobalAppVarStore.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - variable.SetVariableHandler(&variable.Handler{ - AppVarStore: mockGlobalAppVarStore, - }) + variable.SetVariableHandler(&variable.Handler{AppVarStore: mockGlobalAppVarStore}) + + mockey.Mock(execute.GetAppVarStore).Return(&execute.AppVariables{Vars: map[string]any{}}).Build() ctx := t.Context() ctx = ctxcache.Init(ctx) diff --git a/backend/domain/workflow/internal/compose/state.go b/backend/domain/workflow/internal/compose/state.go index cc5e1c64..bd4b691a 100644 --- a/backend/domain/workflow/internal/compose/state.go +++ b/backend/domain/workflow/internal/compose/state.go @@ -34,7 +34,6 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -53,7 +52,6 @@ type State struct { ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"` LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"` - AppVariableStore *variableassigner.AppVariables `json:"variable_app_store,omitempty"` } func init() { @@ -85,15 +83,7 @@ func init() { _ = compose.RegisterSerializableType[vo.SyncPattern]("sync_pattern") _ = compose.RegisterSerializableType[vo.Locator]("wf_locator") _ = compose.RegisterSerializableType[vo.BizType]("biz_type") - _ = compose.RegisterSerializableType[*variableassigner.AppVariables]("app_variables") -} - -func (s *State) SetAppVariableValue(key string, value any) { - s.AppVariableStore.Set(key, value) -} - -func (s *State) GetAppVariableValue(key string) (any, bool) { - return s.AppVariableStore.Get(key) + _ = compose.RegisterSerializableType[*execute.AppVariables]("app_variables") } func (s *State) AddQuestion(nodeKey vo.NodeKey, question *qa.Question) { @@ -271,19 +261,6 @@ func (s *State) NodeExecuted(key vo.NodeKey) bool { func GenState() compose.GenLocalState[*State] { return func(ctx context.Context) (state *State) { - var parentState *State - _ = compose.ProcessState(ctx, func(ctx context.Context, s *State) error { - parentState = s - return nil - }) - - var appVariableStore *variableassigner.AppVariables - if parentState == nil { - appVariableStore = variableassigner.NewAppVariables() - } else { - appVariableStore = parentState.AppVariableStore - } - return &State{ Answers: make(map[vo.NodeKey][]string), Questions: make(map[vo.NodeKey][]*qa.Question), @@ -296,7 +273,6 @@ func GenState() compose.GenLocalState[*State] { GroupChoices: make(map[vo.NodeKey]map[string]int), ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent), LLMToResumeData: make(map[vo.NodeKey]string), - AppVariableStore: appVariableStore, } } } @@ -422,10 +398,9 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string intermediateVarStore := &nodes.ParentIntermediateStore{} return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) { - opts := make([]variable.OptionFn, 0, 1) - - if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil { + var exeCtx *execute.Context + if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil { exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{ AgentID: exeCfg.AgentID, @@ -452,13 +427,20 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string case vo.GlobalAPP: var ok bool path := strings.Join(input.Source.Ref.FromPath, ".") - if v, ok = state.GetAppVariableValue(path); !ok { + if exeCtx == nil || exeCtx.AppVarStore == nil { v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) if err != nil { return nil, err } + } else { + if v, ok = exeCtx.AppVarStore.Get(path); !ok { + v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) + if err != nil { + return nil, err + } - state.SetAppVariableValue(path, v) + exeCtx.AppVarStore.Set(path, v) + } } default: return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType) @@ -494,15 +476,18 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle var ( variables = make(map[string]any) opts = make([]variable.OptionFn, 0, 1) - exeCfg = execute.GetExeCtx(ctx).RootCtx.ExeCfg + exeCtx *execute.Context ) - opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{ - AgentID: exeCfg.AgentID, - AppID: exeCfg.AppID, - ConnectorID: exeCfg.ConnectorID, - ConnectorUID: exeCfg.ConnectorUID, - })) + if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil { + exeCfg := exeCtx.RootCtx.ExeCfg + opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{ + AgentID: exeCfg.AgentID, + AppID: exeCfg.AppID, + ConnectorID: exeCfg.ConnectorID, + ConnectorUID: exeCfg.ConnectorUID, + })) + } for _, input := range vars { if input == nil { @@ -518,13 +503,20 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle case vo.GlobalAPP: var ok bool path := strings.Join(input.Source.Ref.FromPath, ".") - if v, ok = state.GetAppVariableValue(path); !ok { + if exeCtx == nil || exeCtx.AppVarStore == nil { v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) if err != nil { return nil, err } + } else { + if v, ok = exeCtx.AppVarStore.Get(path); !ok { + v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) + if err != nil { + return nil, err + } - state.SetAppVariableValue(path, v) + exeCtx.AppVarStore.Set(path, v) + } } default: return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType) @@ -776,7 +768,8 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) { opts := make([]variable.OptionFn, 0, 1) - if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil { + var exeCtx *execute.Context + if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil { exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{ AgentID: exeCfg.AgentID, @@ -801,13 +794,20 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri case vo.GlobalAPP: var ok bool path := strings.Join(input.Source.Ref.FromPath, ".") - if v, ok = state.GetAppVariableValue(path); !ok { + if exeCtx == nil || exeCtx.AppVarStore == nil { v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) if err != nil { return nil, err } + } else { + if v, ok = exeCtx.AppVarStore.Get(path); !ok { + v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) + if err != nil { + return nil, err + } - state.SetAppVariableValue(path, v) + exeCtx.AppVarStore.Set(path, v) + } } default: return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType) @@ -845,9 +845,10 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand var ( variables = make(map[string]any) opts = make([]variable.OptionFn, 0, 1) + exeCtx *execute.Context ) - if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil { + if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil { exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{ AgentID: exeCfg.AgentID, @@ -869,13 +870,20 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand case vo.GlobalAPP: var ok bool path := strings.Join(input.Source.Ref.FromPath, ".") - if v, ok = state.GetAppVariableValue(path); !ok { + if exeCtx == nil || exeCtx.AppVarStore == nil { v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) if err != nil { return nil, err } + } else { + if v, ok = exeCtx.AppVarStore.Get(path); !ok { + v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...) + if err != nil { + return nil, err + } - state.SetAppVariableValue(path, v) + exeCtx.AppVarStore.Set(path, v) + } } default: return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType) diff --git a/backend/domain/workflow/internal/execute/context.go b/backend/domain/workflow/internal/execute/context.go index 5fc6e50b..93e2f2e3 100644 --- a/backend/domain/workflow/internal/execute/context.go +++ b/backend/domain/workflow/internal/execute/context.go @@ -22,6 +22,7 @@ import ( "fmt" "strconv" "strings" + "sync" "time" "github.com/cloudwego/eino/compose" @@ -45,6 +46,8 @@ type Context struct { StartTime int64 // UnixMilli CheckPointID string + + AppVarStore *AppVariables } type RootCtx struct { @@ -106,12 +109,15 @@ func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Contex } storedCtx.ResumeEvent = h.resumeEvent + currentC := GetExeCtx(ctx) + if currentC != nil { + // restore the parent-child relationship between token collectors + if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil { + currentTokenCollector := currentC.TokenCollector + storedCtx.TokenCollector.Parent = currentTokenCollector + } - // restore the parent-child relationship between token collectors - if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil { - currentC := GetExeCtx(ctx) - currentTokenCollector := currentC.TokenCollector - storedCtx.TokenCollector.Parent = currentTokenCollector + storedCtx.AppVarStore = currentC.AppVarStore } return context.WithValue(ctx, contextKey{}, storedCtx), nil @@ -150,13 +156,16 @@ func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent } + currentC := GetExeCtx(ctx) + // restore the parent-child relationship between token collectors if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil { - currentC := GetExeCtx(ctx) currentTokenCollector := currentC.TokenCollector storedCtx.TokenCollector.Parent = currentTokenCollector } + storedCtx.AppVarStore = currentC.AppVarStore + storedCtx.NodeCtx.CurrentRetryCount = 0 return context.WithValue(ctx, contextKey{}, storedCtx), nil @@ -184,6 +193,7 @@ func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context existingC := GetExeCtx(ctx) if existingC != nil { storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent + storedCtx.AppVarStore = existingC.AppVarStore } // restore the parent-child relationship between token collectors @@ -213,6 +223,7 @@ func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector), StartTime: time.Now().UnixMilli(), + AppVarStore: NewAppVariables(), } if h.requireCheckpoint { @@ -266,6 +277,7 @@ func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireChec TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector), CheckPointID: newCheckpointID, StartTime: time.Now().UnixMilli(), + AppVarStore: c.AppVarStore, } if requireCheckpoint { @@ -308,6 +320,7 @@ func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string, BatchInfo: c.BatchInfo, StartTime: time.Now().UnixMilli(), CheckPointID: c.CheckPointID, + AppVarStore: c.AppVarStore, } if c.NodeCtx == nil { // node within top level workflow, also not under composite node @@ -354,6 +367,7 @@ func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string CompositeNodeKey: c.NodeCtx.NodeKey, }, CheckPointID: newCheckpointID, + AppVarStore: c.AppVarStore, }), newCheckpointID } @@ -363,3 +377,38 @@ type ExeContextStore interface { GetWorkflowCtx() (*Context, bool, error) SetWorkflowCtx(value *Context) error } + +type AppVariables struct { + Vars map[string]any + mu sync.RWMutex +} + +func NewAppVariables() *AppVariables { + return &AppVariables{ + Vars: make(map[string]any), + } +} + +func (av *AppVariables) Set(key string, value any) { + av.mu.Lock() + av.Vars[key] = value + av.mu.Unlock() +} + +func (av *AppVariables) Get(key string) (any, bool) { + av.mu.RLock() + defer av.mu.RUnlock() + + if value, ok := av.Vars[key]; ok { + return value, ok + } + return nil, false +} + +func GetAppVarStore(ctx context.Context) *AppVariables { + c := ctx.Value(contextKey{}) + if c == nil { + return nil + } + return c.(*Context).AppVarStore +} diff --git a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go index ba3ecec1..06b00412 100644 --- a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go +++ b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go @@ -18,9 +18,9 @@ package variableassigner import ( "context" + "errors" "fmt" "strings" - "sync" "github.com/cloudwego/eino/compose" @@ -32,38 +32,6 @@ import ( "github.com/coze-dev/coze-studio/backend/types/errno" ) -type AppVariables struct { - vars map[string]any - mu sync.RWMutex -} - -func NewAppVariables() *AppVariables { - return &AppVariables{ - vars: make(map[string]any), - } -} - -func (av *AppVariables) Set(key string, value any) { - av.mu.Lock() - av.vars[key] = value - av.mu.Unlock() -} - -func (av *AppVariables) Get(key string) (any, bool) { - av.mu.RLock() - defer av.mu.RUnlock() - - if value, ok := av.vars[key]; ok { - return value, ok - } - return nil, false -} - -type AppVariableStore interface { - GetAppVariableValue(key string) (any, bool) - SetAppVariableValue(key string, value any) -} - type VariableAssigner struct { config *Config } @@ -109,16 +77,16 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s vType := *pair.Left.VariableType switch vType { case vo.GlobalAPP: - err := compose.ProcessState(ctx, func(ctx context.Context, appVarsStore AppVariableStore) error { - if len(pair.Left.FromPath) != 1 { - return fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath) - } - appVarsStore.SetAppVariableValue(pair.Left.FromPath[0], right) - return nil - }) - if err != nil { - return nil, err + appVS := execute.GetAppVarStore(ctx) + if appVS == nil { + return nil, errors.New("exeCtx or AppVarStore not found for variable assigner") } + + if len(pair.Left.FromPath) != 1 { + return nil, fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath) + } + + appVS.Set(pair.Left.FromPath[0], right) case vo.GlobalUser: opts := make([]variable.OptionFn, 0, 1) if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {