feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
365
backend/domain/workflow/internal/execute/context.go
Normal file
365
backend/domain/workflow/internal/execute/context.go
Normal file
@@ -0,0 +1,365 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
RootCtx
|
||||
|
||||
*SubWorkflowCtx
|
||||
|
||||
*NodeCtx
|
||||
|
||||
*BatchInfo
|
||||
|
||||
TokenCollector *TokenCollector
|
||||
|
||||
StartTime int64 // UnixMilli
|
||||
|
||||
CheckPointID string
|
||||
}
|
||||
|
||||
type RootCtx struct {
|
||||
RootWorkflowBasic *entity.WorkflowBasic
|
||||
RootExecuteID int64
|
||||
ResumeEvent *entity.InterruptEvent
|
||||
ExeCfg vo.ExecuteConfig
|
||||
}
|
||||
|
||||
type SubWorkflowCtx struct {
|
||||
SubWorkflowBasic *entity.WorkflowBasic
|
||||
SubExecuteID int64
|
||||
}
|
||||
|
||||
type NodeCtx struct {
|
||||
NodeKey vo.NodeKey
|
||||
NodeExecuteID int64
|
||||
NodeName string
|
||||
NodeType entity.NodeType
|
||||
NodePath []string
|
||||
TerminatePlan *vo.TerminatePlan
|
||||
|
||||
ResumingEvent *entity.InterruptEvent
|
||||
SubWorkflowExeID int64 // if this node is subworkflow node, the execute id of the sub workflow
|
||||
|
||||
CurrentRetryCount int
|
||||
}
|
||||
|
||||
type BatchInfo struct {
|
||||
Index int
|
||||
Items map[string]any
|
||||
CompositeNodeKey vo.NodeKey
|
||||
}
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Context, error) {
|
||||
var storedCtx *Context
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
|
||||
var e error
|
||||
storedCtx, _, e = state.GetWorkflowCtx()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
if storedCtx == nil {
|
||||
return ctx, errors.New("stored workflow context is nil")
|
||||
}
|
||||
|
||||
storedCtx.ResumeEvent = h.resumeEvent
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
}
|
||||
|
||||
func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity.InterruptEvent,
|
||||
exactlyResuming bool) (context.Context, error) {
|
||||
var storedCtx *Context
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
var e error
|
||||
storedCtx, _, e = state.GetNodeCtx(nodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
if storedCtx == nil {
|
||||
return ctx, errors.New("stored node context is nil")
|
||||
}
|
||||
|
||||
if exactlyResuming {
|
||||
storedCtx.NodeCtx.ResumingEvent = resumeEvent
|
||||
} else {
|
||||
storedCtx.NodeCtx.ResumingEvent = nil
|
||||
}
|
||||
|
||||
existingC := GetExeCtx(ctx)
|
||||
if existingC != nil {
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.NodeCtx.CurrentRetryCount = 0
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
}
|
||||
|
||||
func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context, bool) {
|
||||
var storedCtx *Context
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
var e error
|
||||
storedCtx, _, e = state.GetNodeCtx(nodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil || storedCtx == nil {
|
||||
return ctx, false
|
||||
}
|
||||
|
||||
storedCtx.NodeCtx.ResumingEvent = nil
|
||||
|
||||
existingC := GetExeCtx(ctx)
|
||||
if existingC != nil {
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil && existingC != nil {
|
||||
currentTokenCollector := existingC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.NodeCtx.CurrentRetryCount = 0
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), true
|
||||
}
|
||||
|
||||
func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context, error) {
|
||||
var parentTokenCollector *TokenCollector
|
||||
if currentC := GetExeCtx(ctx); currentC != nil {
|
||||
parentTokenCollector = currentC.TokenCollector
|
||||
}
|
||||
|
||||
rootExeCtx := &Context{
|
||||
RootCtx: RootCtx{
|
||||
RootWorkflowBasic: h.rootWorkflowBasic,
|
||||
RootExecuteID: h.rootExecuteID,
|
||||
ResumeEvent: h.resumeEvent,
|
||||
ExeCfg: h.exeCfg,
|
||||
},
|
||||
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector),
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if h.requireCheckpoint {
|
||||
rootExeCtx.CheckPointID = strconv.FormatInt(h.rootExecuteID, 10)
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
return state.SetWorkflowCtx(rootExeCtx)
|
||||
})
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, rootExeCtx), nil
|
||||
}
|
||||
|
||||
func GetExeCtx(ctx context.Context) *Context {
|
||||
c := ctx.Value(contextKey{})
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.(*Context)
|
||||
}
|
||||
|
||||
func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireCheckpoint bool) (context.Context, error) {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
subExecuteID, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var newCheckpointID string
|
||||
if len(c.CheckPointID) > 0 {
|
||||
newCheckpointID = c.CheckPointID + "_" + strconv.FormatInt(subExecuteID, 10)
|
||||
}
|
||||
|
||||
newC := &Context{
|
||||
RootCtx: c.RootCtx,
|
||||
SubWorkflowCtx: &SubWorkflowCtx{
|
||||
SubWorkflowBasic: wb,
|
||||
SubExecuteID: subExecuteID,
|
||||
},
|
||||
NodeCtx: c.NodeCtx,
|
||||
BatchInfo: c.BatchInfo,
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector),
|
||||
CheckPointID: newCheckpointID,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if requireCheckpoint {
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
return state.SetWorkflowCtx(newC)
|
||||
})
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
}
|
||||
|
||||
newC.NodeCtx.SubWorkflowExeID = subExecuteID
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, newC), nil
|
||||
}
|
||||
|
||||
func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string, nodeType entity.NodeType, plan *vo.TerminatePlan) (context.Context, error) {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return ctx, nil
|
||||
}
|
||||
nodeExecuteID, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newC := &Context{
|
||||
RootCtx: c.RootCtx,
|
||||
SubWorkflowCtx: c.SubWorkflowCtx,
|
||||
NodeCtx: &NodeCtx{
|
||||
NodeKey: nodeKey,
|
||||
NodeExecuteID: nodeExecuteID,
|
||||
NodeName: nodeName,
|
||||
NodeType: nodeType,
|
||||
TerminatePlan: plan,
|
||||
},
|
||||
BatchInfo: c.BatchInfo,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
CheckPointID: c.CheckPointID,
|
||||
}
|
||||
|
||||
if c.NodeCtx == nil { // node within top level workflow, also not under composite node
|
||||
newC.NodeCtx.NodePath = []string{string(nodeKey)}
|
||||
} else {
|
||||
if c.BatchInfo == nil {
|
||||
newC.NodeCtx.NodePath = append(c.NodeCtx.NodePath, string(nodeKey))
|
||||
} else {
|
||||
newC.NodeCtx.NodePath = append(c.NodeCtx.NodePath, InterruptEventIndexPrefix+strconv.Itoa(c.BatchInfo.Index), string(nodeKey))
|
||||
}
|
||||
}
|
||||
|
||||
tc := c.TokenCollector
|
||||
if entity.NodeMetaByNodeType(nodeType).MayUseChatModel {
|
||||
tc = newTokenCollector(strings.Join(append([]string{string(newC.NodeType)}, newC.NodeCtx.NodePath...), "."), c.TokenCollector)
|
||||
}
|
||||
newC.TokenCollector = tc
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, newC), nil
|
||||
}
|
||||
|
||||
func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string]any) (context.Context, string) {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return ctx, ""
|
||||
}
|
||||
var newCheckpointID string
|
||||
if len(c.CheckPointID) > 0 {
|
||||
newCheckpointID = c.CheckPointID
|
||||
if c.SubWorkflowCtx != nil {
|
||||
newCheckpointID += "_" + strconv.Itoa(int(c.SubWorkflowCtx.SubExecuteID))
|
||||
}
|
||||
newCheckpointID += "_" + string(c.NodeCtx.NodeKey)
|
||||
newCheckpointID += "_" + strconv.Itoa(index)
|
||||
}
|
||||
return context.WithValue(ctx, contextKey{}, &Context{
|
||||
RootCtx: c.RootCtx,
|
||||
SubWorkflowCtx: c.SubWorkflowCtx,
|
||||
NodeCtx: c.NodeCtx,
|
||||
TokenCollector: c.TokenCollector,
|
||||
BatchInfo: &BatchInfo{
|
||||
Index: index,
|
||||
Items: items,
|
||||
CompositeNodeKey: c.NodeCtx.NodeKey,
|
||||
},
|
||||
CheckPointID: newCheckpointID,
|
||||
}), newCheckpointID
|
||||
}
|
||||
|
||||
type ExeContextStore interface {
|
||||
GetNodeCtx(key vo.NodeKey) (*Context, bool, error)
|
||||
SetNodeCtx(key vo.NodeKey, value *Context) error
|
||||
GetWorkflowCtx() (*Context, bool, error)
|
||||
SetWorkflowCtx(value *Context) error
|
||||
}
|
||||
Reference in New Issue
Block a user