fix: app var store can be resumed (#546)
This commit is contained in:
parent
f80d4f757b
commit
72656e4fd1
|
|
@ -29,8 +29,9 @@ import (
|
|||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
schema "github.com/cloudwego/eino/schema"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
vo "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
|
@ -103,6 +104,20 @@ func (mr *MockServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockService)(nil).GetPluginToolsInfo), ctx, req)
|
||||
}
|
||||
|
||||
// UnwrapArrayItemFieldsInVariable mocks base method.
|
||||
func (m *MockService) UnwrapArrayItemFieldsInVariable(v *vo.Variable) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnwrapArrayItemFieldsInVariable", v)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnwrapArrayItemFieldsInVariable indicates an expected call of UnwrapArrayItemFieldsInVariable.
|
||||
func (mr *MockServiceMockRecorder) UnwrapArrayItemFieldsInVariable(v any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapArrayItemFieldsInVariable", reflect.TypeOf((*MockService)(nil).UnwrapArrayItemFieldsInVariable), v)
|
||||
}
|
||||
|
||||
// MockInvokableTool is a mock of InvokableTool interface.
|
||||
type MockInvokableTool struct {
|
||||
ctrl *gomock.Controller
|
||||
|
|
|
|||
|
|
@ -655,11 +655,10 @@ func TestKnowledgeNodes(t *testing.T) {
|
|||
mockKnowledgeOperator.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(rResponse, nil)
|
||||
mockGlobalAppVarStore := mockvar.NewMockStore(ctrl)
|
||||
mockGlobalAppVarStore.EXPECT().Get(gomock.Any(), gomock.Any()).Return("v1", nil).AnyTimes()
|
||||
mockGlobalAppVarStore.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
variable.SetVariableHandler(&variable.Handler{
|
||||
AppVarStore: mockGlobalAppVarStore,
|
||||
})
|
||||
variable.SetVariableHandler(&variable.Handler{AppVarStore: mockGlobalAppVarStore})
|
||||
|
||||
mockey.Mock(execute.GetAppVarStore).Return(&execute.AppVariables{Vars: map[string]any{}}).Build()
|
||||
|
||||
ctx := t.Context()
|
||||
ctx = ctxcache.Init(ctx)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
|
|
@ -53,7 +52,6 @@ type State struct {
|
|||
|
||||
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
|
||||
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
|
||||
AppVariableStore *variableassigner.AppVariables `json:"variable_app_store,omitempty"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
@ -85,15 +83,7 @@ func init() {
|
|||
_ = compose.RegisterSerializableType[vo.SyncPattern]("sync_pattern")
|
||||
_ = compose.RegisterSerializableType[vo.Locator]("wf_locator")
|
||||
_ = compose.RegisterSerializableType[vo.BizType]("biz_type")
|
||||
_ = compose.RegisterSerializableType[*variableassigner.AppVariables]("app_variables")
|
||||
}
|
||||
|
||||
func (s *State) SetAppVariableValue(key string, value any) {
|
||||
s.AppVariableStore.Set(key, value)
|
||||
}
|
||||
|
||||
func (s *State) GetAppVariableValue(key string) (any, bool) {
|
||||
return s.AppVariableStore.Get(key)
|
||||
_ = compose.RegisterSerializableType[*execute.AppVariables]("app_variables")
|
||||
}
|
||||
|
||||
func (s *State) AddQuestion(nodeKey vo.NodeKey, question *qa.Question) {
|
||||
|
|
@ -271,19 +261,6 @@ func (s *State) NodeExecuted(key vo.NodeKey) bool {
|
|||
|
||||
func GenState() compose.GenLocalState[*State] {
|
||||
return func(ctx context.Context) (state *State) {
|
||||
var parentState *State
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, s *State) error {
|
||||
parentState = s
|
||||
return nil
|
||||
})
|
||||
|
||||
var appVariableStore *variableassigner.AppVariables
|
||||
if parentState == nil {
|
||||
appVariableStore = variableassigner.NewAppVariables()
|
||||
} else {
|
||||
appVariableStore = parentState.AppVariableStore
|
||||
}
|
||||
|
||||
return &State{
|
||||
Answers: make(map[vo.NodeKey][]string),
|
||||
Questions: make(map[vo.NodeKey][]*qa.Question),
|
||||
|
|
@ -296,7 +273,6 @@ func GenState() compose.GenLocalState[*State] {
|
|||
GroupChoices: make(map[vo.NodeKey]map[string]int),
|
||||
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
|
||||
LLMToResumeData: make(map[vo.NodeKey]string),
|
||||
AppVariableStore: appVariableStore,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -422,10 +398,9 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
|||
intermediateVarStore := &nodes.ParentIntermediateStore{}
|
||||
|
||||
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
var exeCtx *execute.Context
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
|
|
@ -452,13 +427,20 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
|||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
|
|
@ -494,15 +476,18 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
|||
var (
|
||||
variables = make(map[string]any)
|
||||
opts = make([]variable.OptionFn, 0, 1)
|
||||
exeCfg = execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
exeCtx *execute.Context
|
||||
)
|
||||
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := exeCtx.RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
|
||||
for _, input := range vars {
|
||||
if input == nil {
|
||||
|
|
@ -518,13 +503,20 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
|||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
|
|
@ -776,7 +768,8 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
|||
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
var exeCtx *execute.Context
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
|
|
@ -801,13 +794,20 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
|||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
|
|
@ -845,9 +845,10 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand
|
|||
var (
|
||||
variables = make(map[string]any)
|
||||
opts = make([]variable.OptionFn, 0, 1)
|
||||
exeCtx *execute.Context
|
||||
)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
|
|
@ -869,13 +870,20 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand
|
|||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
|
@ -45,6 +46,8 @@ type Context struct {
|
|||
StartTime int64 // UnixMilli
|
||||
|
||||
CheckPointID string
|
||||
|
||||
AppVarStore *AppVariables
|
||||
}
|
||||
|
||||
type RootCtx struct {
|
||||
|
|
@ -106,14 +109,17 @@ func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Contex
|
|||
}
|
||||
|
||||
storedCtx.ResumeEvent = h.resumeEvent
|
||||
|
||||
currentC := GetExeCtx(ctx)
|
||||
if currentC != nil {
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.AppVarStore = currentC.AppVarStore
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
}
|
||||
|
||||
|
|
@ -150,13 +156,16 @@ func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity
|
|||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
}
|
||||
|
||||
currentC := GetExeCtx(ctx)
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.AppVarStore = currentC.AppVarStore
|
||||
|
||||
storedCtx.NodeCtx.CurrentRetryCount = 0
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
|
|
@ -184,6 +193,7 @@ func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context
|
|||
existingC := GetExeCtx(ctx)
|
||||
if existingC != nil {
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
storedCtx.AppVarStore = existingC.AppVarStore
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
|
|
@ -213,6 +223,7 @@ func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context
|
|||
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector),
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
AppVarStore: NewAppVariables(),
|
||||
}
|
||||
|
||||
if h.requireCheckpoint {
|
||||
|
|
@ -266,6 +277,7 @@ func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireChec
|
|||
TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector),
|
||||
CheckPointID: newCheckpointID,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
AppVarStore: c.AppVarStore,
|
||||
}
|
||||
|
||||
if requireCheckpoint {
|
||||
|
|
@ -308,6 +320,7 @@ func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string,
|
|||
BatchInfo: c.BatchInfo,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
CheckPointID: c.CheckPointID,
|
||||
AppVarStore: c.AppVarStore,
|
||||
}
|
||||
|
||||
if c.NodeCtx == nil { // node within top level workflow, also not under composite node
|
||||
|
|
@ -354,6 +367,7 @@ func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string
|
|||
CompositeNodeKey: c.NodeCtx.NodeKey,
|
||||
},
|
||||
CheckPointID: newCheckpointID,
|
||||
AppVarStore: c.AppVarStore,
|
||||
}), newCheckpointID
|
||||
}
|
||||
|
||||
|
|
@ -363,3 +377,38 @@ type ExeContextStore interface {
|
|||
GetWorkflowCtx() (*Context, bool, error)
|
||||
SetWorkflowCtx(value *Context) error
|
||||
}
|
||||
|
||||
type AppVariables struct {
|
||||
Vars map[string]any
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewAppVariables() *AppVariables {
|
||||
return &AppVariables{
|
||||
Vars: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (av *AppVariables) Set(key string, value any) {
|
||||
av.mu.Lock()
|
||||
av.Vars[key] = value
|
||||
av.mu.Unlock()
|
||||
}
|
||||
|
||||
func (av *AppVariables) Get(key string) (any, bool) {
|
||||
av.mu.RLock()
|
||||
defer av.mu.RUnlock()
|
||||
|
||||
if value, ok := av.Vars[key]; ok {
|
||||
return value, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func GetAppVarStore(ctx context.Context) *AppVariables {
|
||||
c := ctx.Value(contextKey{})
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.(*Context).AppVarStore
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@ package variableassigner
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
|
|
@ -32,38 +32,6 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type AppVariables struct {
|
||||
vars map[string]any
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewAppVariables() *AppVariables {
|
||||
return &AppVariables{
|
||||
vars: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (av *AppVariables) Set(key string, value any) {
|
||||
av.mu.Lock()
|
||||
av.vars[key] = value
|
||||
av.mu.Unlock()
|
||||
}
|
||||
|
||||
func (av *AppVariables) Get(key string) (any, bool) {
|
||||
av.mu.RLock()
|
||||
defer av.mu.RUnlock()
|
||||
|
||||
if value, ok := av.vars[key]; ok {
|
||||
return value, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type AppVariableStore interface {
|
||||
GetAppVariableValue(key string) (any, bool)
|
||||
SetAppVariableValue(key string, value any)
|
||||
}
|
||||
|
||||
type VariableAssigner struct {
|
||||
config *Config
|
||||
}
|
||||
|
|
@ -109,16 +77,16 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
|
|||
vType := *pair.Left.VariableType
|
||||
switch vType {
|
||||
case vo.GlobalAPP:
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, appVarsStore AppVariableStore) error {
|
||||
appVS := execute.GetAppVarStore(ctx)
|
||||
if appVS == nil {
|
||||
return nil, errors.New("exeCtx or AppVarStore not found for variable assigner")
|
||||
}
|
||||
|
||||
if len(pair.Left.FromPath) != 1 {
|
||||
return fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
|
||||
}
|
||||
appVarsStore.SetAppVariableValue(pair.Left.FromPath[0], right)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
|
||||
}
|
||||
|
||||
appVS.Set(pair.Left.FromPath[0], right)
|
||||
case vo.GlobalUser:
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
|
|
|
|||
Loading…
Reference in New Issue