fix: workflow tool closes stream writer correctly (#1839)

This commit is contained in:
shentongmartin
2025-08-27 16:29:42 +08:00
committed by GitHub
parent 263a75b1c0
commit 5562800958
19 changed files with 742 additions and 620 deletions

View File

@@ -23,7 +23,6 @@ import (
"strconv"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
@@ -47,7 +46,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
workflowSC = r.schema
eventChan = r.eventChan
resumedEvent = r.interruptEvent
sw = r.streamWriter
sw = r.container
)
if wb.AppID != nil && exeCfg.AppID == nil {
@@ -148,7 +147,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
var (
resumeEvent = r.interruptEvent
eventChan = r.eventChan
sw = r.streamWriter
container = r.container
)
subHandler := execute.NewSubWorkflowHandler(
parentHandler,
@@ -186,7 +185,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
opts = append(opts, WrapOpt(subO, ns.Key))
}
} else if subNS.Type == entity.NodeTypeLLM {
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container)
if err != nil {
return nil, err
}
@@ -209,7 +208,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
}
} else if subNS.Type == entity.NodeTypeLLM {
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container)
if err != nil {
return nil, err
}
@@ -224,7 +223,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
}
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
sw *schema.StreamWriter[*entity.Message]) (
container *execute.StreamContainer) (
opts []einoCompose.Option, err error) {
// this is a LLM node.
// check if it has any tools, if no tools, then no callback options needed
@@ -280,6 +279,12 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
if container != nil {
toolMsgOpt := llm.WithToolWorkflowStreamContainer(container)
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
}
if fcParams.PluginFCParam != nil {
for _, p := range fcParams.PluginFCParam.PluginList {
@@ -321,11 +326,5 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh
}
}
if sw != nil {
toolMsgOpt := llm.WithToolWorkflowMessageWriter(sw)
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
return opts, nil
}

View File

@@ -38,15 +38,17 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type WorkflowRunner struct {
basic *entity.WorkflowBasic
input string
resumeReq *entity.ResumeRequest
schema *schema2.WorkflowSchema
streamWriter *schema.StreamWriter[*entity.Message]
config model.ExecuteConfig
basic *entity.WorkflowBasic
input string
resumeReq *entity.ResumeRequest
schema *schema2.WorkflowSchema
sw *schema.StreamWriter[*entity.Message]
container *execute.StreamContainer
config model.ExecuteConfig
executeID int64
eventChan chan *execute.Event
@@ -84,13 +86,19 @@ func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, conf
opt(options)
}
var container *execute.StreamContainer
if options.streamWriter != nil {
container = execute.NewStreamContainer(options.streamWriter)
}
return &WorkflowRunner{
basic: b,
input: options.input,
resumeReq: options.resumeReq,
schema: sc,
streamWriter: options.streamWriter,
config: config,
basic: b,
input: options.input,
resumeReq: options.resumeReq,
schema: sc,
sw: options.streamWriter,
container: container,
config: config,
}
}
@@ -108,14 +116,16 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
resumeReq = r.resumeReq
wb = r.basic
sc = r.schema
sw = r.streamWriter
sw = r.sw
container = r.container
config = r.config
)
if r.resumeReq == nil {
executeID, err = repo.GenID(ctx)
if err != nil {
return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err)
return ctx, 0, nil, nil, vo.WrapError(errno.ErrIDGenError,
fmt.Errorf("failed to generate workflow execute ID: %w", err))
}
} else {
executeID = resumeReq.ExecuteID
@@ -148,6 +158,15 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
r.eventChan = eventChan
r.interruptEvent = interruptEvent
if container != nil {
go container.PipeAll()
defer func() {
if err != nil {
container.Done()
}
}()
}
ctx, composeOpts, err := r.designateOptions(ctx)
if err != nil {
return ctx, 0, nil, nil, err
@@ -277,8 +296,8 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
}
}()
defer func() {
if sw != nil {
sw.Close()
if container != nil {
container.Done()
}
}()

View File

@@ -33,17 +33,22 @@ import (
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
const answerKey = "output"
type invokableWorkflow struct {
workflowTool
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
}
type workflowTool struct {
info *schema.ToolInfo
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *schema2.WorkflowSchema
repo wf.Repository
terminatePlan vo.TerminatePlan
}
func NewInvokableWorkflow(info *schema.ToolInfo,
@@ -54,12 +59,14 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
repo wf.Repository,
) wf.ToolFromWorkflow {
return &invokableWorkflow{
info: info,
invoke: invoke,
terminatePlan: terminatePlan,
wfEntity: wfEntity,
sc: sc,
repo: repo,
workflowTool: workflowTool{
info: info,
wfEntity: wfEntity,
sc: sc,
repo: repo,
terminatePlan: terminatePlan,
},
invoke: invoke,
}
}
@@ -77,6 +84,52 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e
}
}
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) (
cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) {
cfg := execute.GetExecuteConfig(opts...)
var runOpts []WorkflowRunnerOption
if rInfo != nil && !rInfo.Resumed {
runOpts = append(runOpts, WithResumeReq(rInfo))
} else {
runOpts = append(runOpts, WithInput(argumentsInJSON))
}
if container := execute.GetParentStreamContainer(opts...); container != nil {
sr, sw := schema.Pipe[*entity.Message](10)
container.AddChild(sr)
runOpts = append(runOpts, WithStreamWriter(sw))
}
var ws *nodes.ConversionWarnings
if (rInfo == nil || rInfo.Resumed) && len(wt.wfEntity.InputParams) > 0 {
if err = sonic.UnmarshalString(argumentsInJSON, &input); err != nil {
err = vo.WrapError(errno.ErrSerializationDeserializationFail, err)
return
}
var entryNode *schema2.NodeSchema
for _, node := range wt.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
break
}
}
if entryNode == nil {
panic("entry node not found in tool workflow")
}
input, ws, err = nodes.ConvertInputs(ctx, input, entryNode.OutputTypes)
if err != nil {
return
} else if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
}
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
return
}
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
rInfo, allIEs := execute.GetResumeRequest(opts...)
var (
@@ -97,52 +150,9 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
return "", einoCompose.InterruptAndRerun
}
cfg := execute.GetExecuteConfig(opts...)
defer resumeOnce(rInfo, callID, allIEs)
var runOpts []WorkflowRunnerOption
if rInfo != nil && !rInfo.Resumed {
runOpts = append(runOpts, WithResumeReq(rInfo))
} else {
runOpts = append(runOpts, WithInput(argumentsInJSON))
}
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
runOpts = append(runOpts, WithStreamWriter(sw))
}
var (
cancelCtx context.Context
executeID int64
callOpts []einoCompose.Option
in map[string]any
err error
ws *nodes.ConversionWarnings
)
if rInfo == nil && len(i.wfEntity.InputParams) > 0 {
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
return "", err
}
var entryNode *schema2.NodeSchema
for _, node := range i.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
break
}
}
if entryNode == nil {
panic("entry node not found in tool workflow")
}
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
if err != nil {
return "", err
} else if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
}
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(i.wfEntity.GetBasic(), i.sc, cfg, runOpts...).Prepare(ctx)
cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
if err != nil {
return "", err
}
@@ -198,12 +208,8 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
}
type streamableWorkflow struct {
info *schema.ToolInfo
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *schema2.WorkflowSchema
repo wf.Repository
workflowTool
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
}
func NewStreamableWorkflow(info *schema.ToolInfo,
@@ -214,12 +220,14 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
repo wf.Repository,
) wf.ToolFromWorkflow {
return &streamableWorkflow{
info: info,
stream: stream,
terminatePlan: terminatePlan,
wfEntity: wfEntity,
sc: sc,
repo: repo,
workflowTool: workflowTool{
info: info,
wfEntity: wfEntity,
sc: sc,
repo: repo,
terminatePlan: terminatePlan,
},
stream: stream,
}
}
@@ -247,52 +255,9 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
return nil, einoCompose.InterruptAndRerun
}
cfg := execute.GetExecuteConfig(opts...)
defer resumeOnce(rInfo, callID, allIEs)
var runOpts []WorkflowRunnerOption
if rInfo != nil && !rInfo.Resumed {
runOpts = append(runOpts, WithResumeReq(rInfo))
} else {
runOpts = append(runOpts, WithInput(argumentsInJSON))
}
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
runOpts = append(runOpts, WithStreamWriter(sw))
}
var (
cancelCtx context.Context
executeID int64
callOpts []einoCompose.Option
in map[string]any
err error
ws *nodes.ConversionWarnings
)
if rInfo == nil && len(s.wfEntity.InputParams) > 0 {
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
return nil, err
}
var entryNode *schema2.NodeSchema
for _, node := range s.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
break
}
}
if entryNode == nil {
panic("entry node not found in tool workflow")
}
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
if err != nil {
return nil, err
} else if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
}
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(s.wfEntity.GetBasic(), s.sc, cfg, runOpts...).Prepare(ctx)
cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
if err != nil {
return nil, err
}