366 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			366 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package execute
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/cloudwego/eino/compose"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
 | 
						|
)
 | 
						|
 | 
						|
// Context is the execution context this package stores in a context.Context
// (retrieve it with GetExeCtx). The embedded pointers are nil until the
// corresponding scope is entered: SubWorkflowCtx is set by PrepareSubExeCtx,
// NodeCtx by PrepareNodeExeCtx (and inherited by PrepareSubExeCtx), and
// BatchInfo by InheritExeCtxWithBatchInfo.
type Context struct {
	RootCtx

	*SubWorkflowCtx

	*NodeCtx

	*BatchInfo

	// TokenCollector is the collector for the current scope. New collectors
	// are created for the root workflow, each sub-workflow, and each node
	// whose type may use a chat model, each with the enclosing collector as
	// its parent.
	TokenCollector *TokenCollector

	StartTime int64 // UnixMilli

	// CheckPointID is empty when the root run did not require a checkpoint.
	// Otherwise it starts as the decimal root execute ID and is extended with
	// "_"-separated suffixes for sub-workflows and batch iterations.
	CheckPointID string
}
 | 
						|
 | 
						|
// RootCtx holds information about the top-level (root) workflow execution.
// It is embedded by value, so every derived Context carries its own copy.
type RootCtx struct {
	RootWorkflowBasic *entity.WorkflowBasic
	RootExecuteID     int64
	// ResumeEvent is taken from the WorkflowHandler; presumably nil on a fresh
	// (non-resumed) run — TODO confirm.
	ResumeEvent *entity.InterruptEvent
	ExeCfg      vo.ExecuteConfig
}
 | 
						|
 | 
						|
// SubWorkflowCtx identifies the sub-workflow currently executing. It is nil
// while executing directly in the root workflow.
type SubWorkflowCtx struct {
	SubWorkflowBasic *entity.WorkflowBasic
	SubExecuteID     int64 // generated by the workflow repository's GenID in PrepareSubExeCtx
}
 | 
						|
 | 
						|
// NodeCtx describes the node currently executing. It is nil in a Context
// that is not scoped to a node, such as the root Context built by
// PrepareRootExeCtx.
type NodeCtx struct {
	NodeKey       vo.NodeKey
	NodeExecuteID int64 // generated by the workflow repository's GenID in PrepareNodeExeCtx
	NodeName      string
	NodeType      entity.NodeType
	// NodePath lists node keys from the top-level workflow down to this node.
	// When the node runs inside a batch iteration, an
	// InterruptEventIndexPrefix+index element precedes the node's own key.
	NodePath      []string
	TerminatePlan *vo.TerminatePlan

	// ResumingEvent is set (in this file) only by restoreNodeCtx when
	// exactlyResuming is true; the other restore paths clear it.
	ResumingEvent    *entity.InterruptEvent
	SubWorkflowExeID int64 // if this node is subworkflow node, the execute id of the sub workflow

	// CurrentRetryCount is reset to 0 whenever the node Context is restored
	// from graph state.
	CurrentRetryCount int
}
 | 
						|
 | 
						|
// BatchInfo describes one iteration of a composite node; it is attached to
// the Context by InheritExeCtxWithBatchInfo.
type BatchInfo struct {
	Index            int            // iteration index supplied by the composite node
	Items            map[string]any // per-iteration items supplied by the composite node
	CompositeNodeKey vo.NodeKey     // key of the composite node that owns this iteration
}
 | 
						|
 | 
						|
// contextKey is the key under which the *Context is stored in a
// context.Context. Because the type is unexported, no other package can
// construct it, so the value cannot collide with other context keys.
type contextKey struct{}
 | 
						|
 | 
						|
func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Context, error) {
 | 
						|
	var storedCtx *Context
 | 
						|
	err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
 | 
						|
		if state == nil {
 | 
						|
			return errors.New("state is nil")
 | 
						|
		}
 | 
						|
 | 
						|
		var e error
 | 
						|
		storedCtx, _, e = state.GetWorkflowCtx()
 | 
						|
		if e != nil {
 | 
						|
			return e
 | 
						|
		}
 | 
						|
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return ctx, err
 | 
						|
	}
 | 
						|
 | 
						|
	if storedCtx == nil {
 | 
						|
		return ctx, errors.New("stored workflow context is nil")
 | 
						|
	}
 | 
						|
 | 
						|
	storedCtx.ResumeEvent = h.resumeEvent
 | 
						|
 | 
						|
	// restore the parent-child relationship between token collectors
 | 
						|
	if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
 | 
						|
		currentC := GetExeCtx(ctx)
 | 
						|
		currentTokenCollector := currentC.TokenCollector
 | 
						|
		storedCtx.TokenCollector.Parent = currentTokenCollector
 | 
						|
	}
 | 
						|
 | 
						|
	return context.WithValue(ctx, contextKey{}, storedCtx), nil
 | 
						|
}
 | 
						|
 | 
						|
func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity.InterruptEvent,
 | 
						|
	exactlyResuming bool) (context.Context, error) {
 | 
						|
	var storedCtx *Context
 | 
						|
	err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
 | 
						|
		if state == nil {
 | 
						|
			return errors.New("state is nil")
 | 
						|
		}
 | 
						|
		var e error
 | 
						|
		storedCtx, _, e = state.GetNodeCtx(nodeKey)
 | 
						|
		if e != nil {
 | 
						|
			return e
 | 
						|
		}
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return ctx, err
 | 
						|
	}
 | 
						|
 | 
						|
	if storedCtx == nil {
 | 
						|
		return ctx, errors.New("stored node context is nil")
 | 
						|
	}
 | 
						|
 | 
						|
	if exactlyResuming {
 | 
						|
		storedCtx.NodeCtx.ResumingEvent = resumeEvent
 | 
						|
	} else {
 | 
						|
		storedCtx.NodeCtx.ResumingEvent = nil
 | 
						|
	}
 | 
						|
 | 
						|
	existingC := GetExeCtx(ctx)
 | 
						|
	if existingC != nil {
 | 
						|
		storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
 | 
						|
	}
 | 
						|
 | 
						|
	// restore the parent-child relationship between token collectors
 | 
						|
	if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
 | 
						|
		currentC := GetExeCtx(ctx)
 | 
						|
		currentTokenCollector := currentC.TokenCollector
 | 
						|
		storedCtx.TokenCollector.Parent = currentTokenCollector
 | 
						|
	}
 | 
						|
 | 
						|
	storedCtx.NodeCtx.CurrentRetryCount = 0
 | 
						|
 | 
						|
	return context.WithValue(ctx, contextKey{}, storedCtx), nil
 | 
						|
}
 | 
						|
 | 
						|
func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context, bool) {
 | 
						|
	var storedCtx *Context
 | 
						|
	err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
 | 
						|
		if state == nil {
 | 
						|
			return errors.New("state is nil")
 | 
						|
		}
 | 
						|
		var e error
 | 
						|
		storedCtx, _, e = state.GetNodeCtx(nodeKey)
 | 
						|
		if e != nil {
 | 
						|
			return e
 | 
						|
		}
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
	if err != nil || storedCtx == nil {
 | 
						|
		return ctx, false
 | 
						|
	}
 | 
						|
 | 
						|
	storedCtx.NodeCtx.ResumingEvent = nil
 | 
						|
 | 
						|
	existingC := GetExeCtx(ctx)
 | 
						|
	if existingC != nil {
 | 
						|
		storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
 | 
						|
	}
 | 
						|
 | 
						|
	// restore the parent-child relationship between token collectors
 | 
						|
	if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil && existingC != nil {
 | 
						|
		currentTokenCollector := existingC.TokenCollector
 | 
						|
		storedCtx.TokenCollector.Parent = currentTokenCollector
 | 
						|
	}
 | 
						|
 | 
						|
	storedCtx.NodeCtx.CurrentRetryCount = 0
 | 
						|
 | 
						|
	return context.WithValue(ctx, contextKey{}, storedCtx), true
 | 
						|
}
 | 
						|
 | 
						|
func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context, error) {
 | 
						|
	var parentTokenCollector *TokenCollector
 | 
						|
	if currentC := GetExeCtx(ctx); currentC != nil {
 | 
						|
		parentTokenCollector = currentC.TokenCollector
 | 
						|
	}
 | 
						|
 | 
						|
	rootExeCtx := &Context{
 | 
						|
		RootCtx: RootCtx{
 | 
						|
			RootWorkflowBasic: h.rootWorkflowBasic,
 | 
						|
			RootExecuteID:     h.rootExecuteID,
 | 
						|
			ResumeEvent:       h.resumeEvent,
 | 
						|
			ExeCfg:            h.exeCfg,
 | 
						|
		},
 | 
						|
 | 
						|
		TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector),
 | 
						|
		StartTime:      time.Now().UnixMilli(),
 | 
						|
	}
 | 
						|
 | 
						|
	if h.requireCheckpoint {
 | 
						|
		rootExeCtx.CheckPointID = strconv.FormatInt(h.rootExecuteID, 10)
 | 
						|
		err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
 | 
						|
			if state == nil {
 | 
						|
				return errors.New("state is nil")
 | 
						|
			}
 | 
						|
			return state.SetWorkflowCtx(rootExeCtx)
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			return ctx, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return context.WithValue(ctx, contextKey{}, rootExeCtx), nil
 | 
						|
}
 | 
						|
 | 
						|
func GetExeCtx(ctx context.Context) *Context {
 | 
						|
	c := ctx.Value(contextKey{})
 | 
						|
	if c == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	return c.(*Context)
 | 
						|
}
 | 
						|
 | 
						|
func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireCheckpoint bool) (context.Context, error) {
 | 
						|
	c := GetExeCtx(ctx)
 | 
						|
	if c == nil {
 | 
						|
		return ctx, nil
 | 
						|
	}
 | 
						|
 | 
						|
	subExecuteID, err := workflow.GetRepository().GenID(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	var newCheckpointID string
 | 
						|
	if len(c.CheckPointID) > 0 {
 | 
						|
		newCheckpointID = c.CheckPointID + "_" + strconv.FormatInt(subExecuteID, 10)
 | 
						|
	}
 | 
						|
 | 
						|
	newC := &Context{
 | 
						|
		RootCtx: c.RootCtx,
 | 
						|
		SubWorkflowCtx: &SubWorkflowCtx{
 | 
						|
			SubWorkflowBasic: wb,
 | 
						|
			SubExecuteID:     subExecuteID,
 | 
						|
		},
 | 
						|
		NodeCtx:        c.NodeCtx,
 | 
						|
		BatchInfo:      c.BatchInfo,
 | 
						|
		TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector),
 | 
						|
		CheckPointID:   newCheckpointID,
 | 
						|
		StartTime:      time.Now().UnixMilli(),
 | 
						|
	}
 | 
						|
 | 
						|
	if requireCheckpoint {
 | 
						|
		err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
 | 
						|
			if state == nil {
 | 
						|
				return errors.New("state is nil")
 | 
						|
			}
 | 
						|
			return state.SetWorkflowCtx(newC)
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			return ctx, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	newC.NodeCtx.SubWorkflowExeID = subExecuteID
 | 
						|
 | 
						|
	return context.WithValue(ctx, contextKey{}, newC), nil
 | 
						|
}
 | 
						|
 | 
						|
func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string, nodeType entity.NodeType, plan *vo.TerminatePlan) (context.Context, error) {
 | 
						|
	c := GetExeCtx(ctx)
 | 
						|
	if c == nil {
 | 
						|
		return ctx, nil
 | 
						|
	}
 | 
						|
	nodeExecuteID, err := workflow.GetRepository().GenID(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	newC := &Context{
 | 
						|
		RootCtx:        c.RootCtx,
 | 
						|
		SubWorkflowCtx: c.SubWorkflowCtx,
 | 
						|
		NodeCtx: &NodeCtx{
 | 
						|
			NodeKey:       nodeKey,
 | 
						|
			NodeExecuteID: nodeExecuteID,
 | 
						|
			NodeName:      nodeName,
 | 
						|
			NodeType:      nodeType,
 | 
						|
			TerminatePlan: plan,
 | 
						|
		},
 | 
						|
		BatchInfo:    c.BatchInfo,
 | 
						|
		StartTime:    time.Now().UnixMilli(),
 | 
						|
		CheckPointID: c.CheckPointID,
 | 
						|
	}
 | 
						|
 | 
						|
	if c.NodeCtx == nil { // node within top level workflow, also not under composite node
 | 
						|
		newC.NodeCtx.NodePath = []string{string(nodeKey)}
 | 
						|
	} else {
 | 
						|
		if c.BatchInfo == nil {
 | 
						|
			newC.NodeCtx.NodePath = append(c.NodeCtx.NodePath, string(nodeKey))
 | 
						|
		} else {
 | 
						|
			newC.NodeCtx.NodePath = append(c.NodeCtx.NodePath, InterruptEventIndexPrefix+strconv.Itoa(c.BatchInfo.Index), string(nodeKey))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	tc := c.TokenCollector
 | 
						|
	if entity.NodeMetaByNodeType(nodeType).MayUseChatModel {
 | 
						|
		tc = newTokenCollector(strings.Join(append([]string{string(newC.NodeType)}, newC.NodeCtx.NodePath...), "."), c.TokenCollector)
 | 
						|
	}
 | 
						|
	newC.TokenCollector = tc
 | 
						|
 | 
						|
	return context.WithValue(ctx, contextKey{}, newC), nil
 | 
						|
}
 | 
						|
 | 
						|
func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string]any) (context.Context, string) {
 | 
						|
	c := GetExeCtx(ctx)
 | 
						|
	if c == nil {
 | 
						|
		return ctx, ""
 | 
						|
	}
 | 
						|
	var newCheckpointID string
 | 
						|
	if len(c.CheckPointID) > 0 {
 | 
						|
		newCheckpointID = c.CheckPointID
 | 
						|
		if c.SubWorkflowCtx != nil {
 | 
						|
			newCheckpointID += "_" + strconv.Itoa(int(c.SubWorkflowCtx.SubExecuteID))
 | 
						|
		}
 | 
						|
		newCheckpointID += "_" + string(c.NodeCtx.NodeKey)
 | 
						|
		newCheckpointID += "_" + strconv.Itoa(index)
 | 
						|
	}
 | 
						|
	return context.WithValue(ctx, contextKey{}, &Context{
 | 
						|
		RootCtx:        c.RootCtx,
 | 
						|
		SubWorkflowCtx: c.SubWorkflowCtx,
 | 
						|
		NodeCtx:        c.NodeCtx,
 | 
						|
		TokenCollector: c.TokenCollector,
 | 
						|
		BatchInfo: &BatchInfo{
 | 
						|
			Index:            index,
 | 
						|
			Items:            items,
 | 
						|
			CompositeNodeKey: c.NodeCtx.NodeKey,
 | 
						|
		},
 | 
						|
		CheckPointID: newCheckpointID,
 | 
						|
	}), newCheckpointID
 | 
						|
}
 | 
						|
 | 
						|
type ExeContextStore interface {
 | 
						|
	GetNodeCtx(key vo.NodeKey) (*Context, bool, error)
 | 
						|
	SetNodeCtx(key vo.NodeKey, value *Context) error
 | 
						|
	GetWorkflowCtx() (*Context, bool, error)
 | 
						|
	SetWorkflowCtx(value *Context) error
 | 
						|
}
 |