fix: context cancel not working during node runner execution (#819)

This commit is contained in:
Zhj
2025-08-21 17:59:01 +08:00
committed by GitHub
parent 09d00c26cb
commit 19c63a1150
6 changed files with 161 additions and 121 deletions

View File

@@ -36,6 +36,7 @@ import (
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
var n int64
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var invokeOutput map[string]any
output, err = r.i(ctx, input, opts...)
for {
err = exec.RunWithContextDone(ctx, func() error {
var invokeErr error
invokeOutput, invokeErr = r.i(ctx, input, opts...)
if invokeErr != nil {
return invokeErr
}
return nil
})
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts .
}
continue
}
return nil, err
}
return invokeOutput, nil
return output, nil
}
}
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
var n int64
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var streamOutput *schema.StreamReader[map[string]any]
for {
err = exec.RunWithContextDone(ctx, func() error {
var streamErr error
streamOutput, streamErr = r.s(ctx, input, opts...)
if streamErr != nil {
return streamErr
}
return nil
})
output, err = r.s(ctx, input, opts...)
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts .
}
return nil, err
}
return streamOutput, nil
return output, nil
}
}
@@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
}
copied := input.Copy(int(r.maxRetry))
var n int64
defer func() {
for i := n + 1; i < r.maxRetry; i++ {
copied[i].Close()
}
}()
var collectOutput map[string]any
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
err = exec.RunWithContextDone(ctx, func() error {
var collectErr error
collectOutput, collectErr = r.c(ctx, copied[n], opts...)
if collectErr != nil {
return collectErr
}
return nil
})
output, err = r.c(ctx, copied[n], opts...)
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
}
continue
}
return nil, err
}
return output, nil
return collectOutput, nil
}
}
@@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
}
}()
var transformOutput *schema.StreamReader[map[string]any]
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
output, err = r.t(ctx, copied[n], opts...)
err = exec.RunWithContextDone(ctx, func() error {
var transformErr error
transformOutput, transformErr = r.t(ctx, copied[n], opts...)
if transformErr != nil {
return transformErr
}
return nil
})
if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true
return nil, err
}
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n {
n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
return nil, err
}
return output, nil
return transformOutput, nil
}
}

View File

@@ -57,17 +57,25 @@ func (t *TokenCollector) addTokenUsage(usage *model.TokenUsage) {
}
// wait blocks on t.wg and returns a newly allocated TokenUsage whose
// PromptTokens, CompletionTokens and TotalTokens are copied from t.Usage.
// The copy is made after t.mu.Lock() has been called.
//
// NOTE(review): this listing looks like a diff rendered without +/- markers, so
// removed and added lines are interleaved: t.wg.Wait() appears both before and
// after t.mu.Lock(), and the mutex is released both by a defer and by an
// explicit Unlock. Presumably only one of each belongs to the real function;
// confirm against the checked-out file before relying on the order shown here.
func (t *TokenCollector) wait() *model.TokenUsage {
t.wg.Wait()
t.mu.Lock()
defer t.mu.Unlock()
t.wg.Wait()
usage := &model.TokenUsage{
PromptTokens: t.Usage.PromptTokens,
CompletionTokens: t.Usage.CompletionTokens,
TotalTokens: t.Usage.TotalTokens,
}
t.mu.Unlock()
return usage
}
// add increases the collector's WaitGroup counter by i. The update to t.wg is
// performed while holding t.mu, which is released when the method returns.
//
// NOTE(review): if wait is blocked in t.wg.Wait() while it holds t.mu, a
// concurrent call to add would block until wait returns; confirm the intended
// lock ordering against wait.
func (t *TokenCollector) add(i int) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.wg.Add(i)
}
func getTokenCollector(ctx context.Context) *TokenCollector {
c := GetExeCtx(ctx)
if c == nil {
@@ -83,7 +91,8 @@ func GetTokenCallbackHandler() callbacks.Handler {
if c == nil {
return ctx
}
c.wg.Add(1)
c.add(1)
//c.wg.Add(1)
return ctx
},
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
@@ -122,12 +131,16 @@ func GetTokenCallbackHandler() callbacks.Handler {
if chunk.TokenUsage == nil {
continue
}
// Accumulate inside the goroutine to avoid concurrent access
newC.PromptTokens += chunk.TokenUsage.PromptTokens
newC.CompletionTokens += chunk.TokenUsage.CompletionTokens
newC.TotalTokens += chunk.TokenUsage.TotalTokens
}
c.addTokenUsage(newC)
// Call addTokenUsage only once, at the end, to reduce lock contention
if newC.TotalTokens > 0 {
c.addTokenUsage(newC)
}
})
return ctx
},