294 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			294 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package compose
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"runtime/debug"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	einoCompose "github.com/cloudwego/eino/compose"
 | |
| 	"github.com/cloudwego/eino/schema"
 | |
| 
 | |
| 	wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
 | |
| 	schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/safego"
 | |
| )
 | |
| 
 | |
| type WorkflowRunner struct {
 | |
| 	basic        *entity.WorkflowBasic
 | |
| 	input        string
 | |
| 	resumeReq    *entity.ResumeRequest
 | |
| 	schema       *schema2.WorkflowSchema
 | |
| 	streamWriter *schema.StreamWriter[*entity.Message]
 | |
| 	config       vo.ExecuteConfig
 | |
| 
 | |
| 	executeID      int64
 | |
| 	eventChan      chan *execute.Event
 | |
| 	interruptEvent *entity.InterruptEvent
 | |
| }
 | |
| 
 | |
| type workflowRunOptions struct {
 | |
| 	input              string
 | |
| 	resumeReq          *entity.ResumeRequest
 | |
| 	streamWriter       *schema.StreamWriter[*entity.Message]
 | |
| 	rootTokenCollector *execute.TokenCollector
 | |
| }
 | |
| 
 | |
| type WorkflowRunnerOption func(*workflowRunOptions)
 | |
| 
 | |
| func WithInput(input string) WorkflowRunnerOption {
 | |
| 	return func(opts *workflowRunOptions) {
 | |
| 		opts.input = input
 | |
| 	}
 | |
| }
 | |
| func WithResumeReq(resumeReq *entity.ResumeRequest) WorkflowRunnerOption {
 | |
| 	return func(opts *workflowRunOptions) {
 | |
| 		opts.resumeReq = resumeReq
 | |
| 	}
 | |
| }
 | |
| func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOption {
 | |
| 	return func(opts *workflowRunOptions) {
 | |
| 		opts.streamWriter = sw
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
 | |
| 	options := &workflowRunOptions{}
 | |
| 	for _, opt := range opts {
 | |
| 		opt(options)
 | |
| 	}
 | |
| 
 | |
| 	return &WorkflowRunner{
 | |
| 		basic:        b,
 | |
| 		input:        options.input,
 | |
| 		resumeReq:    options.resumeReq,
 | |
| 		schema:       sc,
 | |
| 		streamWriter: options.streamWriter,
 | |
| 		config:       config,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (r *WorkflowRunner) Prepare(ctx context.Context) (
 | |
| 	context.Context,
 | |
| 	int64,
 | |
| 	[]einoCompose.Option,
 | |
| 	<-chan *execute.Event,
 | |
| 	error,
 | |
| ) {
 | |
| 	var (
 | |
| 		err       error
 | |
| 		executeID int64
 | |
| 		repo      = wf.GetRepository()
 | |
| 		resumeReq = r.resumeReq
 | |
| 		wb        = r.basic
 | |
| 		sc        = r.schema
 | |
| 		sw        = r.streamWriter
 | |
| 		config    = r.config
 | |
| 	)
 | |
| 
 | |
| 	if r.resumeReq == nil {
 | |
| 		executeID, err = repo.GenID(ctx)
 | |
| 		if err != nil {
 | |
| 			return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err)
 | |
| 		}
 | |
| 	} else {
 | |
| 		executeID = resumeReq.ExecuteID
 | |
| 	}
 | |
| 
 | |
| 	eventChan := make(chan *execute.Event)
 | |
| 
 | |
| 	var (
 | |
| 		interruptEvent *entity.InterruptEvent
 | |
| 		found          bool
 | |
| 	)
 | |
| 
 | |
| 	if resumeReq != nil {
 | |
| 		interruptEvent, found, err = repo.GetFirstInterruptEvent(ctx, executeID)
 | |
| 		if err != nil {
 | |
| 			return ctx, 0, nil, nil, err
 | |
| 		}
 | |
| 
 | |
| 		if !found {
 | |
| 			return ctx, 0, nil, nil, fmt.Errorf("interrupt event does not exist, id: %d", resumeReq.EventID)
 | |
| 		}
 | |
| 
 | |
| 		if interruptEvent.ID != resumeReq.EventID {
 | |
| 			return ctx, 0, nil, nil, fmt.Errorf("interrupt event id mismatch, expect: %d, actual: %d", resumeReq.EventID, interruptEvent.ID)
 | |
| 		}
 | |
| 
 | |
| 	}
 | |
| 
 | |
| 	r.executeID = executeID
 | |
| 	r.eventChan = eventChan
 | |
| 	r.interruptEvent = interruptEvent
 | |
| 
 | |
| 	ctx, composeOpts, err := r.designateOptions(ctx)
 | |
| 	if err != nil {
 | |
| 		return ctx, 0, nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	if interruptEvent != nil {
 | |
| 		var stateOpt einoCompose.Option
 | |
| 		stateModifier := GenStateModifierByEventType(interruptEvent.EventType,
 | |
| 			interruptEvent.NodeKey, resumeReq.ResumeData, r.config)
 | |
| 
 | |
| 		if len(interruptEvent.NodePath) == 1 {
 | |
| 			// this interrupt event is within the top level workflow
 | |
| 			stateOpt = einoCompose.WithStateModifier(stateModifier)
 | |
| 		} else {
 | |
| 			currentI := len(interruptEvent.NodePath) - 2
 | |
| 			path := interruptEvent.NodePath[currentI]
 | |
| 			if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
 | |
| 				// this interrupt event is within a composite node
 | |
| 				indexStr := path[len(execute.InterruptEventIndexPrefix):]
 | |
| 				index, err := strconv.Atoi(indexStr)
 | |
| 				if err != nil {
 | |
| 					return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
 | |
| 				}
 | |
| 
 | |
| 				currentI--
 | |
| 				parentNodeKey := interruptEvent.NodePath[currentI]
 | |
| 				stateOpt = einoCompose.WithLambdaOption(
 | |
| 					nodes.WithResumeIndex(index, stateModifier)).DesignateNode(parentNodeKey)
 | |
| 			} else { // this interrupt event is within a sub workflow
 | |
| 				subWorkflowNodeKey := interruptEvent.NodePath[currentI]
 | |
| 				stateOpt = einoCompose.WithLambdaOption(
 | |
| 					nodes.WithResumeIndex(0, stateModifier)).DesignateNode(subWorkflowNodeKey)
 | |
| 			}
 | |
| 
 | |
| 			for i := currentI - 1; i >= 0; i-- {
 | |
| 				path := interruptEvent.NodePath[i]
 | |
| 				if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
 | |
| 					indexStr := path[len(execute.InterruptEventIndexPrefix):]
 | |
| 					index, err := strconv.Atoi(indexStr)
 | |
| 					if err != nil {
 | |
| 						return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
 | |
| 					}
 | |
| 
 | |
| 					i--
 | |
| 					parentNodeKey := interruptEvent.NodePath[i]
 | |
| 					stateOpt = WrapOptWithIndex(stateOpt, vo.NodeKey(parentNodeKey), index)
 | |
| 				} else {
 | |
| 					stateOpt = WrapOpt(stateOpt, vo.NodeKey(path))
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		composeOpts = append(composeOpts, stateOpt)
 | |
| 
 | |
| 		if interruptEvent.EventType == entity.InterruptEventQuestion {
 | |
| 			modifiedData, err := qa.AppendInterruptData(interruptEvent.InterruptData, resumeReq.ResumeData)
 | |
| 			if err != nil {
 | |
| 				return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data: %w", err)
 | |
| 			}
 | |
| 			interruptEvent.InterruptData = modifiedData
 | |
| 			if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
 | |
| 				return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
 | |
| 			}
 | |
| 		} else if interruptEvent.EventType == entity.InterruptEventLLM &&
 | |
| 			interruptEvent.ToolInterruptEvent.EventType == entity.InterruptEventQuestion {
 | |
| 			modifiedData, err := qa.AppendInterruptData(interruptEvent.ToolInterruptEvent.InterruptData, resumeReq.ResumeData)
 | |
| 			if err != nil {
 | |
| 				return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data for LLM node: %w", err)
 | |
| 			}
 | |
| 			interruptEvent.ToolInterruptEvent.InterruptData = modifiedData
 | |
| 			if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
 | |
| 				return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		success, currentStatus, err := repo.TryLockWorkflowExecution(ctx, executeID, resumeReq.EventID)
 | |
| 		if err != nil {
 | |
| 			return ctx, 0, nil, nil, fmt.Errorf("try lock workflow execution unexpected err: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		if !success {
 | |
| 			return ctx, 0, nil, nil, fmt.Errorf("workflow execution lock failed, current status is %v, executeID: %d", currentStatus, executeID)
 | |
| 		}
 | |
| 
 | |
| 		logs.CtxInfof(ctx, "resuming with eventID: %d, executeID: %d, nodeKey: %s", interruptEvent.ID,
 | |
| 			executeID, interruptEvent.NodeKey)
 | |
| 	}
 | |
| 
 | |
| 	if interruptEvent == nil {
 | |
| 		var logID string
 | |
| 		logID, _ = ctx.Value("log-id").(string)
 | |
| 
 | |
| 		wfExec := &entity.WorkflowExecution{
 | |
| 			ID:                     executeID,
 | |
| 			WorkflowID:             wb.ID,
 | |
| 			Version:                wb.Version,
 | |
| 			SpaceID:                wb.SpaceID,
 | |
| 			ExecuteConfig:          config,
 | |
| 			Status:                 entity.WorkflowRunning,
 | |
| 			Input:                  ptr.Of(r.input),
 | |
| 			RootExecutionID:        executeID,
 | |
| 			NodeCount:              sc.NodeCount(),
 | |
| 			CurrentResumingEventID: ptr.Of(int64(0)),
 | |
| 			CommitID:               wb.CommitID,
 | |
| 			LogID:                  logID,
 | |
| 		}
 | |
| 
 | |
| 		if err = repo.CreateWorkflowExecution(ctx, wfExec); err != nil {
 | |
| 			return ctx, 0, nil, nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	cancelCtx, cancelFn := context.WithCancel(ctx)
 | |
| 	var timeoutFn context.CancelFunc
 | |
| 	if s := execute.GetStaticConfig(); s != nil {
 | |
| 		timeout := ternary.IFElse(config.TaskType == vo.TaskTypeBackground, s.BackgroundRunTimeout, s.ForegroundRunTimeout)
 | |
| 		if timeout > 0 {
 | |
| 			cancelCtx, timeoutFn = context.WithTimeout(cancelCtx, timeout)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	cancelCtx = execute.InitExecutedNodesCounter(cancelCtx)
 | |
| 
 | |
| 	lastEventChan := make(chan *execute.Event, 1)
 | |
| 	go func() {
 | |
| 		defer func() {
 | |
| 			if panicErr := recover(); panicErr != nil {
 | |
| 				logs.CtxErrorf(ctx, "panic when handling execute event: %v", safego.NewPanicErr(panicErr, debug.Stack()))
 | |
| 			}
 | |
| 		}()
 | |
| 		defer func() {
 | |
| 			if sw != nil {
 | |
| 				sw.Close()
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		// this goroutine should not use the cancelCtx because it needs to be alive to receive workflow cancel events
 | |
| 		lastEventChan <- execute.HandleExecuteEvent(ctx, executeID, eventChan, cancelFn, timeoutFn,
 | |
| 			repo, sw, config)
 | |
| 		close(lastEventChan)
 | |
| 	}()
 | |
| 
 | |
| 	return cancelCtx, executeID, composeOpts, lastEventChan, nil
 | |
| }
 |