293 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			293 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package compose
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"runtime/debug"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	einoCompose "github.com/cloudwego/eino/compose"
 | 
						|
	"github.com/cloudwego/eino/schema"
 | 
						|
 | 
						|
	wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/safego"
 | 
						|
)
 | 
						|
 | 
						|
type WorkflowRunner struct {
 | 
						|
	basic        *entity.WorkflowBasic
 | 
						|
	input        string
 | 
						|
	resumeReq    *entity.ResumeRequest
 | 
						|
	schema       *WorkflowSchema
 | 
						|
	streamWriter *schema.StreamWriter[*entity.Message]
 | 
						|
	config       vo.ExecuteConfig
 | 
						|
 | 
						|
	executeID      int64
 | 
						|
	eventChan      chan *execute.Event
 | 
						|
	interruptEvent *entity.InterruptEvent
 | 
						|
}
 | 
						|
 | 
						|
type workflowRunOptions struct {
 | 
						|
	input              string
 | 
						|
	resumeReq          *entity.ResumeRequest
 | 
						|
	streamWriter       *schema.StreamWriter[*entity.Message]
 | 
						|
	rootTokenCollector *execute.TokenCollector
 | 
						|
}
 | 
						|
 | 
						|
// WorkflowRunnerOption customizes the option set that NewWorkflowRunner
// collects before constructing a WorkflowRunner.
type WorkflowRunnerOption func(*workflowRunOptions)
 | 
						|
 | 
						|
func WithInput(input string) WorkflowRunnerOption {
 | 
						|
	return func(opts *workflowRunOptions) {
 | 
						|
		opts.input = input
 | 
						|
	}
 | 
						|
}
 | 
						|
func WithResumeReq(resumeReq *entity.ResumeRequest) WorkflowRunnerOption {
 | 
						|
	return func(opts *workflowRunOptions) {
 | 
						|
		opts.resumeReq = resumeReq
 | 
						|
	}
 | 
						|
}
 | 
						|
func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOption {
 | 
						|
	return func(opts *workflowRunOptions) {
 | 
						|
		opts.streamWriter = sw
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
 | 
						|
	options := &workflowRunOptions{}
 | 
						|
	for _, opt := range opts {
 | 
						|
		opt(options)
 | 
						|
	}
 | 
						|
 | 
						|
	return &WorkflowRunner{
 | 
						|
		basic:        b,
 | 
						|
		input:        options.input,
 | 
						|
		resumeReq:    options.resumeReq,
 | 
						|
		schema:       sc,
 | 
						|
		streamWriter: options.streamWriter,
 | 
						|
		config:       config,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (r *WorkflowRunner) Prepare(ctx context.Context) (
 | 
						|
	context.Context,
 | 
						|
	int64,
 | 
						|
	[]einoCompose.Option,
 | 
						|
	<-chan *execute.Event,
 | 
						|
	error,
 | 
						|
) {
 | 
						|
	var (
 | 
						|
		err       error
 | 
						|
		executeID int64
 | 
						|
		repo      = wf.GetRepository()
 | 
						|
		resumeReq = r.resumeReq
 | 
						|
		wb        = r.basic
 | 
						|
		sc        = r.schema
 | 
						|
		sw        = r.streamWriter
 | 
						|
		config    = r.config
 | 
						|
	)
 | 
						|
 | 
						|
	if r.resumeReq == nil {
 | 
						|
		executeID, err = repo.GenID(ctx)
 | 
						|
		if err != nil {
 | 
						|
			return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err)
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		executeID = resumeReq.ExecuteID
 | 
						|
	}
 | 
						|
 | 
						|
	eventChan := make(chan *execute.Event)
 | 
						|
 | 
						|
	var (
 | 
						|
		interruptEvent *entity.InterruptEvent
 | 
						|
		found          bool
 | 
						|
	)
 | 
						|
 | 
						|
	if resumeReq != nil {
 | 
						|
		interruptEvent, found, err = repo.GetFirstInterruptEvent(ctx, executeID)
 | 
						|
		if err != nil {
 | 
						|
			return ctx, 0, nil, nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		if !found {
 | 
						|
			return ctx, 0, nil, nil, fmt.Errorf("interrupt event does not exist, id: %d", resumeReq.EventID)
 | 
						|
		}
 | 
						|
 | 
						|
		if interruptEvent.ID != resumeReq.EventID {
 | 
						|
			return ctx, 0, nil, nil, fmt.Errorf("interrupt event id mismatch, expect: %d, actual: %d", resumeReq.EventID, interruptEvent.ID)
 | 
						|
		}
 | 
						|
 | 
						|
	}
 | 
						|
 | 
						|
	r.executeID = executeID
 | 
						|
	r.eventChan = eventChan
 | 
						|
	r.interruptEvent = interruptEvent
 | 
						|
 | 
						|
	ctx, composeOpts, err := r.designateOptions(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return ctx, 0, nil, nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if interruptEvent != nil {
 | 
						|
		var stateOpt einoCompose.Option
 | 
						|
		stateModifier := GenStateModifierByEventType(interruptEvent.EventType,
 | 
						|
			interruptEvent.NodeKey, resumeReq.ResumeData, r.config)
 | 
						|
 | 
						|
		if len(interruptEvent.NodePath) == 1 {
 | 
						|
			// this interrupt event is within the top level workflow
 | 
						|
			stateOpt = einoCompose.WithStateModifier(stateModifier)
 | 
						|
		} else {
 | 
						|
			currentI := len(interruptEvent.NodePath) - 2
 | 
						|
			path := interruptEvent.NodePath[currentI]
 | 
						|
			if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
 | 
						|
				// this interrupt event is within a composite node
 | 
						|
				indexStr := path[len(execute.InterruptEventIndexPrefix):]
 | 
						|
				index, err := strconv.Atoi(indexStr)
 | 
						|
				if err != nil {
 | 
						|
					return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
 | 
						|
				}
 | 
						|
 | 
						|
				currentI--
 | 
						|
				parentNodeKey := interruptEvent.NodePath[currentI]
 | 
						|
				stateOpt = einoCompose.WithLambdaOption(
 | 
						|
					nodes.WithResumeIndex(index, stateModifier)).DesignateNode(parentNodeKey)
 | 
						|
			} else { // this interrupt event is within a sub workflow
 | 
						|
				subWorkflowNodeKey := interruptEvent.NodePath[currentI]
 | 
						|
				stateOpt = einoCompose.WithLambdaOption(
 | 
						|
					nodes.WithResumeIndex(0, stateModifier)).DesignateNode(subWorkflowNodeKey)
 | 
						|
			}
 | 
						|
 | 
						|
			for i := currentI - 1; i >= 0; i-- {
 | 
						|
				path := interruptEvent.NodePath[i]
 | 
						|
				if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
 | 
						|
					indexStr := path[len(execute.InterruptEventIndexPrefix):]
 | 
						|
					index, err := strconv.Atoi(indexStr)
 | 
						|
					if err != nil {
 | 
						|
						return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
 | 
						|
					}
 | 
						|
 | 
						|
					i--
 | 
						|
					parentNodeKey := interruptEvent.NodePath[i]
 | 
						|
					stateOpt = WrapOptWithIndex(stateOpt, vo.NodeKey(parentNodeKey), index)
 | 
						|
				} else {
 | 
						|
					stateOpt = WrapOpt(stateOpt, vo.NodeKey(path))
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		composeOpts = append(composeOpts, stateOpt)
 | 
						|
 | 
						|
		if interruptEvent.EventType == entity.InterruptEventQuestion {
 | 
						|
			modifiedData, err := qa.AppendInterruptData(interruptEvent.InterruptData, resumeReq.ResumeData)
 | 
						|
			if err != nil {
 | 
						|
				return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data: %w", err)
 | 
						|
			}
 | 
						|
			interruptEvent.InterruptData = modifiedData
 | 
						|
			if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
 | 
						|
				return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
 | 
						|
			}
 | 
						|
		} else if interruptEvent.EventType == entity.InterruptEventLLM &&
 | 
						|
			interruptEvent.ToolInterruptEvent.EventType == entity.InterruptEventQuestion {
 | 
						|
			modifiedData, err := qa.AppendInterruptData(interruptEvent.ToolInterruptEvent.InterruptData, resumeReq.ResumeData)
 | 
						|
			if err != nil {
 | 
						|
				return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data for LLM node: %w", err)
 | 
						|
			}
 | 
						|
			interruptEvent.ToolInterruptEvent.InterruptData = modifiedData
 | 
						|
			if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
 | 
						|
				return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		success, currentStatus, err := repo.TryLockWorkflowExecution(ctx, executeID, resumeReq.EventID)
 | 
						|
		if err != nil {
 | 
						|
			return ctx, 0, nil, nil, fmt.Errorf("try lock workflow execution unexpected err: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		if !success {
 | 
						|
			return ctx, 0, nil, nil, fmt.Errorf("workflow execution lock failed, current status is %v, executeID: %d", currentStatus, executeID)
 | 
						|
		}
 | 
						|
 | 
						|
		logs.CtxInfof(ctx, "resuming with eventID: %d, executeID: %d, nodeKey: %s", interruptEvent.ID,
 | 
						|
			executeID, interruptEvent.NodeKey)
 | 
						|
	}
 | 
						|
 | 
						|
	if interruptEvent == nil {
 | 
						|
		var logID string
 | 
						|
		logID, _ = ctx.Value("log-id").(string)
 | 
						|
 | 
						|
		wfExec := &entity.WorkflowExecution{
 | 
						|
			ID:                     executeID,
 | 
						|
			WorkflowID:             wb.ID,
 | 
						|
			Version:                wb.Version,
 | 
						|
			SpaceID:                wb.SpaceID,
 | 
						|
			ExecuteConfig:          config,
 | 
						|
			Status:                 entity.WorkflowRunning,
 | 
						|
			Input:                  ptr.Of(r.input),
 | 
						|
			RootExecutionID:        executeID,
 | 
						|
			NodeCount:              sc.NodeCount(),
 | 
						|
			CurrentResumingEventID: ptr.Of(int64(0)),
 | 
						|
			CommitID:               wb.CommitID,
 | 
						|
			LogID:                  logID,
 | 
						|
		}
 | 
						|
 | 
						|
		if err = repo.CreateWorkflowExecution(ctx, wfExec); err != nil {
 | 
						|
			return ctx, 0, nil, nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	cancelCtx, cancelFn := context.WithCancel(ctx)
 | 
						|
	var timeoutFn context.CancelFunc
 | 
						|
	if s := execute.GetStaticConfig(); s != nil {
 | 
						|
		timeout := ternary.IFElse(config.TaskType == vo.TaskTypeBackground, s.BackgroundRunTimeout, s.ForegroundRunTimeout)
 | 
						|
		if timeout > 0 {
 | 
						|
			cancelCtx, timeoutFn = context.WithTimeout(cancelCtx, timeout)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	cancelCtx = execute.InitExecutedNodesCounter(cancelCtx)
 | 
						|
 | 
						|
	lastEventChan := make(chan *execute.Event, 1)
 | 
						|
	go func() {
 | 
						|
		defer func() {
 | 
						|
			if panicErr := recover(); panicErr != nil {
 | 
						|
				logs.CtxErrorf(ctx, "panic when handling execute event: %v", safego.NewPanicErr(panicErr, debug.Stack()))
 | 
						|
			}
 | 
						|
		}()
 | 
						|
		defer func() {
 | 
						|
			if sw != nil {
 | 
						|
				sw.Close()
 | 
						|
			}
 | 
						|
		}()
 | 
						|
 | 
						|
		// this goroutine should not use the cancelCtx because it needs to be alive to receive workflow cancel events
 | 
						|
		lastEventChan <- execute.HandleExecuteEvent(ctx, executeID, eventChan, cancelFn, timeoutFn,
 | 
						|
			repo, sw, config)
 | 
						|
		close(lastEventChan)
 | 
						|
	}()
 | 
						|
 | 
						|
	return cancelCtx, executeID, composeOpts, lastEventChan, nil
 | 
						|
}
 |