fix: app var store can be resumed (#546)

This commit is contained in:
shentongmartin 2025-08-04 20:19:36 +08:00 committed by GitHub
parent f80d4f757b
commit 72656e4fd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 137 additions and 98 deletions

View File

@ -29,8 +29,9 @@ import (
context "context"
reflect "reflect"
plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
schema "github.com/cloudwego/eino/schema"
plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
vo "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
gomock "go.uber.org/mock/gomock"
)
@ -103,6 +104,20 @@ func (mr *MockServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockService)(nil).GetPluginToolsInfo), ctx, req)
}
// UnwrapArrayItemFieldsInVariable mocks base method.
func (m *MockService) UnwrapArrayItemFieldsInVariable(v *vo.Variable) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnwrapArrayItemFieldsInVariable", v)
ret0, _ := ret[0].(error)
return ret0
}
// UnwrapArrayItemFieldsInVariable indicates an expected call of UnwrapArrayItemFieldsInVariable.
func (mr *MockServiceMockRecorder) UnwrapArrayItemFieldsInVariable(v any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapArrayItemFieldsInVariable", reflect.TypeOf((*MockService)(nil).UnwrapArrayItemFieldsInVariable), v)
}
// MockInvokableTool is a mock of InvokableTool interface.
type MockInvokableTool struct {
ctrl *gomock.Controller

View File

@ -655,11 +655,10 @@ func TestKnowledgeNodes(t *testing.T) {
mockKnowledgeOperator.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(rResponse, nil)
mockGlobalAppVarStore := mockvar.NewMockStore(ctrl)
mockGlobalAppVarStore.EXPECT().Get(gomock.Any(), gomock.Any()).Return("v1", nil).AnyTimes()
mockGlobalAppVarStore.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
variable.SetVariableHandler(&variable.Handler{
AppVarStore: mockGlobalAppVarStore,
})
variable.SetVariableHandler(&variable.Handler{AppVarStore: mockGlobalAppVarStore})
mockey.Mock(execute.GetAppVarStore).Return(&execute.AppVariables{Vars: map[string]any{}}).Build()
ctx := t.Context()
ctx = ctxcache.Init(ctx)

View File

@ -34,7 +34,6 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -53,7 +52,6 @@ type State struct {
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
AppVariableStore *variableassigner.AppVariables `json:"variable_app_store,omitempty"`
}
func init() {
@ -85,15 +83,7 @@ func init() {
_ = compose.RegisterSerializableType[vo.SyncPattern]("sync_pattern")
_ = compose.RegisterSerializableType[vo.Locator]("wf_locator")
_ = compose.RegisterSerializableType[vo.BizType]("biz_type")
_ = compose.RegisterSerializableType[*variableassigner.AppVariables]("app_variables")
}
func (s *State) SetAppVariableValue(key string, value any) {
s.AppVariableStore.Set(key, value)
}
func (s *State) GetAppVariableValue(key string) (any, bool) {
return s.AppVariableStore.Get(key)
_ = compose.RegisterSerializableType[*execute.AppVariables]("app_variables")
}
func (s *State) AddQuestion(nodeKey vo.NodeKey, question *qa.Question) {
@ -271,19 +261,6 @@ func (s *State) NodeExecuted(key vo.NodeKey) bool {
func GenState() compose.GenLocalState[*State] {
return func(ctx context.Context) (state *State) {
var parentState *State
_ = compose.ProcessState(ctx, func(ctx context.Context, s *State) error {
parentState = s
return nil
})
var appVariableStore *variableassigner.AppVariables
if parentState == nil {
appVariableStore = variableassigner.NewAppVariables()
} else {
appVariableStore = parentState.AppVariableStore
}
return &State{
Answers: make(map[vo.NodeKey][]string),
Questions: make(map[vo.NodeKey][]*qa.Question),
@ -296,7 +273,6 @@ func GenState() compose.GenLocalState[*State] {
GroupChoices: make(map[vo.NodeKey]map[string]int),
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
LLMToResumeData: make(map[vo.NodeKey]string),
AppVariableStore: appVariableStore,
}
}
}
@ -422,10 +398,9 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
intermediateVarStore := &nodes.ParentIntermediateStore{}
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
opts := make([]variable.OptionFn, 0, 1)
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
var exeCtx *execute.Context
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
AgentID: exeCfg.AgentID,
@ -452,13 +427,20 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
case vo.GlobalAPP:
var ok bool
path := strings.Join(input.Source.Ref.FromPath, ".")
if v, ok = state.GetAppVariableValue(path); !ok {
if exeCtx == nil || exeCtx.AppVarStore == nil {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
} else {
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
state.SetAppVariableValue(path, v)
exeCtx.AppVarStore.Set(path, v)
}
}
default:
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
@ -494,15 +476,18 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
var (
variables = make(map[string]any)
opts = make([]variable.OptionFn, 0, 1)
exeCfg = execute.GetExeCtx(ctx).RootCtx.ExeCfg
exeCtx *execute.Context
)
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
AgentID: exeCfg.AgentID,
AppID: exeCfg.AppID,
ConnectorID: exeCfg.ConnectorID,
ConnectorUID: exeCfg.ConnectorUID,
}))
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
exeCfg := exeCtx.RootCtx.ExeCfg
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
AgentID: exeCfg.AgentID,
AppID: exeCfg.AppID,
ConnectorID: exeCfg.ConnectorID,
ConnectorUID: exeCfg.ConnectorUID,
}))
}
for _, input := range vars {
if input == nil {
@ -518,13 +503,20 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
case vo.GlobalAPP:
var ok bool
path := strings.Join(input.Source.Ref.FromPath, ".")
if v, ok = state.GetAppVariableValue(path); !ok {
if exeCtx == nil || exeCtx.AppVarStore == nil {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
} else {
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
state.SetAppVariableValue(path, v)
exeCtx.AppVarStore.Set(path, v)
}
}
default:
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
@ -776,7 +768,8 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
opts := make([]variable.OptionFn, 0, 1)
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
var exeCtx *execute.Context
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
AgentID: exeCfg.AgentID,
@ -801,13 +794,20 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
case vo.GlobalAPP:
var ok bool
path := strings.Join(input.Source.Ref.FromPath, ".")
if v, ok = state.GetAppVariableValue(path); !ok {
if exeCtx == nil || exeCtx.AppVarStore == nil {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
} else {
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
state.SetAppVariableValue(path, v)
exeCtx.AppVarStore.Set(path, v)
}
}
default:
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
@ -845,9 +845,10 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand
var (
variables = make(map[string]any)
opts = make([]variable.OptionFn, 0, 1)
exeCtx *execute.Context
)
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
AgentID: exeCfg.AgentID,
@ -869,13 +870,20 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand
case vo.GlobalAPP:
var ok bool
path := strings.Join(input.Source.Ref.FromPath, ".")
if v, ok = state.GetAppVariableValue(path); !ok {
if exeCtx == nil || exeCtx.AppVarStore == nil {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
} else {
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
if err != nil {
return nil, err
}
state.SetAppVariableValue(path, v)
exeCtx.AppVarStore.Set(path, v)
}
}
default:
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)

View File

@ -22,6 +22,7 @@ import (
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/cloudwego/eino/compose"
@ -45,6 +46,8 @@ type Context struct {
StartTime int64 // UnixMilli
CheckPointID string
AppVarStore *AppVariables
}
type RootCtx struct {
@ -106,12 +109,15 @@ func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Contex
}
storedCtx.ResumeEvent = h.resumeEvent
currentC := GetExeCtx(ctx)
if currentC != nil {
// restore the parent-child relationship between token collectors
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
currentTokenCollector := currentC.TokenCollector
storedCtx.TokenCollector.Parent = currentTokenCollector
}
// restore the parent-child relationship between token collectors
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
currentC := GetExeCtx(ctx)
currentTokenCollector := currentC.TokenCollector
storedCtx.TokenCollector.Parent = currentTokenCollector
storedCtx.AppVarStore = currentC.AppVarStore
}
return context.WithValue(ctx, contextKey{}, storedCtx), nil
@ -150,13 +156,16 @@ func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
}
currentC := GetExeCtx(ctx)
// restore the parent-child relationship between token collectors
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
currentC := GetExeCtx(ctx)
currentTokenCollector := currentC.TokenCollector
storedCtx.TokenCollector.Parent = currentTokenCollector
}
storedCtx.AppVarStore = currentC.AppVarStore
storedCtx.NodeCtx.CurrentRetryCount = 0
return context.WithValue(ctx, contextKey{}, storedCtx), nil
@ -184,6 +193,7 @@ func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context
existingC := GetExeCtx(ctx)
if existingC != nil {
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
storedCtx.AppVarStore = existingC.AppVarStore
}
// restore the parent-child relationship between token collectors
@ -213,6 +223,7 @@ func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context
TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector),
StartTime: time.Now().UnixMilli(),
AppVarStore: NewAppVariables(),
}
if h.requireCheckpoint {
@ -266,6 +277,7 @@ func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireChec
TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector),
CheckPointID: newCheckpointID,
StartTime: time.Now().UnixMilli(),
AppVarStore: c.AppVarStore,
}
if requireCheckpoint {
@ -308,6 +320,7 @@ func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string,
BatchInfo: c.BatchInfo,
StartTime: time.Now().UnixMilli(),
CheckPointID: c.CheckPointID,
AppVarStore: c.AppVarStore,
}
if c.NodeCtx == nil { // node within top level workflow, also not under composite node
@ -354,6 +367,7 @@ func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string
CompositeNodeKey: c.NodeCtx.NodeKey,
},
CheckPointID: newCheckpointID,
AppVarStore: c.AppVarStore,
}), newCheckpointID
}
@ -363,3 +377,38 @@ type ExeContextStore interface {
GetWorkflowCtx() (*Context, bool, error)
SetWorkflowCtx(value *Context) error
}
type AppVariables struct {
Vars map[string]any
mu sync.RWMutex
}
func NewAppVariables() *AppVariables {
return &AppVariables{
Vars: make(map[string]any),
}
}
func (av *AppVariables) Set(key string, value any) {
av.mu.Lock()
av.Vars[key] = value
av.mu.Unlock()
}
func (av *AppVariables) Get(key string) (any, bool) {
av.mu.RLock()
defer av.mu.RUnlock()
if value, ok := av.Vars[key]; ok {
return value, ok
}
return nil, false
}
func GetAppVarStore(ctx context.Context) *AppVariables {
c := ctx.Value(contextKey{})
if c == nil {
return nil
}
return c.(*Context).AppVarStore
}

View File

@ -18,9 +18,9 @@ package variableassigner
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/cloudwego/eino/compose"
@ -32,38 +32,6 @@ import (
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type AppVariables struct {
vars map[string]any
mu sync.RWMutex
}
func NewAppVariables() *AppVariables {
return &AppVariables{
vars: make(map[string]any),
}
}
func (av *AppVariables) Set(key string, value any) {
av.mu.Lock()
av.vars[key] = value
av.mu.Unlock()
}
func (av *AppVariables) Get(key string) (any, bool) {
av.mu.RLock()
defer av.mu.RUnlock()
if value, ok := av.vars[key]; ok {
return value, ok
}
return nil, false
}
type AppVariableStore interface {
GetAppVariableValue(key string) (any, bool)
SetAppVariableValue(key string, value any)
}
type VariableAssigner struct {
config *Config
}
@ -109,16 +77,16 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
vType := *pair.Left.VariableType
switch vType {
case vo.GlobalAPP:
err := compose.ProcessState(ctx, func(ctx context.Context, appVarsStore AppVariableStore) error {
if len(pair.Left.FromPath) != 1 {
return fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
}
appVarsStore.SetAppVariableValue(pair.Left.FromPath[0], right)
return nil
})
if err != nil {
return nil, err
appVS := execute.GetAppVarStore(ctx)
if appVS == nil {
return nil, errors.New("exeCtx or AppVarStore not found for variable assigner")
}
if len(pair.Left.FromPath) != 1 {
return nil, fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
}
appVS.Set(pair.Left.FromPath[0], right)
case vo.GlobalUser:
opts := make([]variable.OptionFn, 0, 1)
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {