fix: context cancel not working during node runner execution (#819)
This commit is contained in:
@@ -36,6 +36,7 @@ import (
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
@@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
|
||||
|
||||
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
var n int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
var invokeOutput map[string]any
|
||||
|
||||
output, err = r.i(ctx, input, opts...)
|
||||
for {
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var invokeErr error
|
||||
invokeOutput, invokeErr = r.i(ctx, input, opts...)
|
||||
if invokeErr != nil {
|
||||
return invokeErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts .
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
return invokeOutput, nil
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
var n int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
var streamOutput *schema.StreamReader[map[string]any]
|
||||
|
||||
for {
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var streamErr error
|
||||
streamOutput, streamErr = r.s(ctx, input, opts...)
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
output, err = r.s(ctx, input, opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts .
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return streamOutput, nil
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
|
||||
}
|
||||
|
||||
copied := input.Copy(int(r.maxRetry))
|
||||
|
||||
var n int64
|
||||
|
||||
defer func() {
|
||||
for i := n + 1; i < r.maxRetry; i++ {
|
||||
copied[i].Close()
|
||||
}
|
||||
}()
|
||||
|
||||
var collectOutput map[string]any
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var collectErr error
|
||||
collectOutput, collectErr = r.c(ctx, copied[n], opts...)
|
||||
if collectErr != nil {
|
||||
return collectErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
output, err = r.c(ctx, copied[n], opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return collectOutput, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
|
||||
}
|
||||
}()
|
||||
|
||||
var transformOutput *schema.StreamReader[map[string]any]
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
output, err = r.t(ctx, copied[n], opts...)
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var transformErr error
|
||||
transformOutput, transformErr = r.t(ctx, copied[n], opts...)
|
||||
if transformErr != nil {
|
||||
return transformErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return transformOutput, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user