fix: context cancel not working during node runner execution (#819)

This commit is contained in:
Zhj
2025-08-21 17:59:01 +08:00
committed by GitHub
parent 09d00c26cb
commit 19c63a1150
6 changed files with 161 additions and 121 deletions

View File

@@ -36,6 +36,7 @@ import (
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
var n int64
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var invokeOutput map[string]any
output, err = r.i(ctx, input, opts...)
for {
err = exec.RunWithContextDone(ctx, func() error {
var invokeErr error
invokeOutput, invokeErr = r.i(ctx, input, opts...)
if invokeErr != nil {
return invokeErr
}
return nil
})
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts .
}
continue
}
return nil, err
}
return invokeOutput, nil
return output, nil
}
}
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
var n int64
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var streamOutput *schema.StreamReader[map[string]any]
for {
err = exec.RunWithContextDone(ctx, func() error {
var streamErr error
streamOutput, streamErr = r.s(ctx, input, opts...)
if streamErr != nil {
return streamErr
}
return nil
})
output, err = r.s(ctx, input, opts...)
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts .
}
return nil, err
}
return streamOutput, nil
return output, nil
}
}
@@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
}
copied := input.Copy(int(r.maxRetry))
var n int64
defer func() {
for i := n + 1; i < r.maxRetry; i++ {
copied[i].Close()
}
}()
var collectOutput map[string]any
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
err = exec.RunWithContextDone(ctx, func() error {
var collectErr error
collectOutput, collectErr = r.c(ctx, copied[n], opts...)
if collectErr != nil {
return collectErr
}
return nil
})
output, err = r.c(ctx, copied[n], opts...)
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
}
continue
}
return nil, err
}
return output, nil
return collectOutput, nil
}
}
@@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
}
}()
var transformOutput *schema.StreamReader[map[string]any]
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
output, err = r.t(ctx, copied[n], opts...)
err = exec.RunWithContextDone(ctx, func() error {
var transformErr error
transformOutput, transformErr = r.t(ctx, copied[n], opts...)
if transformErr != nil {
return transformErr
}
return nil
})
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
return nil, err
}
return output, nil
return transformOutput, nil
}
}