fix: context cancel not working during node runner execution (#819)
This commit is contained in:
@@ -36,6 +36,7 @@ import (
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
@@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
|
||||
|
||||
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
var n int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
var invokeOutput map[string]any
|
||||
|
||||
output, err = r.i(ctx, input, opts...)
|
||||
for {
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var invokeErr error
|
||||
invokeOutput, invokeErr = r.i(ctx, input, opts...)
|
||||
if invokeErr != nil {
|
||||
return invokeErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts .
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
return invokeOutput, nil
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
var n int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
var streamOutput *schema.StreamReader[map[string]any]
|
||||
|
||||
for {
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var streamErr error
|
||||
streamOutput, streamErr = r.s(ctx, input, opts...)
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
output, err = r.s(ctx, input, opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts .
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return streamOutput, nil
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
|
||||
}
|
||||
|
||||
copied := input.Copy(int(r.maxRetry))
|
||||
|
||||
var n int64
|
||||
|
||||
defer func() {
|
||||
for i := n + 1; i < r.maxRetry; i++ {
|
||||
copied[i].Close()
|
||||
}
|
||||
}()
|
||||
|
||||
var collectOutput map[string]any
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var collectErr error
|
||||
collectOutput, collectErr = r.c(ctx, copied[n], opts...)
|
||||
if collectErr != nil {
|
||||
return collectErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
output, err = r.c(ctx, copied[n], opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return collectOutput, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
|
||||
}
|
||||
}()
|
||||
|
||||
var transformOutput *schema.StreamReader[map[string]any]
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
output, err = r.t(ctx, copied[n], opts...)
|
||||
err = exec.RunWithContextDone(ctx, func() error {
|
||||
var transformErr error
|
||||
transformOutput, transformErr = r.t(ctx, copied[n], opts...)
|
||||
if transformErr != nil {
|
||||
return transformErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok {
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
@@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return transformOutput, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,17 +57,25 @@ func (t *TokenCollector) addTokenUsage(usage *model.TokenUsage) {
|
||||
}
|
||||
|
||||
func (t *TokenCollector) wait() *model.TokenUsage {
|
||||
t.wg.Wait()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.wg.Wait()
|
||||
usage := &model.TokenUsage{
|
||||
PromptTokens: t.Usage.PromptTokens,
|
||||
CompletionTokens: t.Usage.CompletionTokens,
|
||||
TotalTokens: t.Usage.TotalTokens,
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
func (t *TokenCollector) add(i int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.wg.Add(i)
|
||||
return
|
||||
}
|
||||
|
||||
func getTokenCollector(ctx context.Context) *TokenCollector {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
@@ -83,7 +91,8 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
if c == nil {
|
||||
return ctx
|
||||
}
|
||||
c.wg.Add(1)
|
||||
c.add(1)
|
||||
//c.wg.Add(1)
|
||||
return ctx
|
||||
},
|
||||
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
|
||||
@@ -122,12 +131,16 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
if chunk.TokenUsage == nil {
|
||||
continue
|
||||
}
|
||||
// 在goroutine内部累加,避免并发访问
|
||||
newC.PromptTokens += chunk.TokenUsage.PromptTokens
|
||||
newC.CompletionTokens += chunk.TokenUsage.CompletionTokens
|
||||
newC.TotalTokens += chunk.TokenUsage.TotalTokens
|
||||
}
|
||||
|
||||
c.addTokenUsage(newC)
|
||||
// 只在最后调用一次addTokenUsage,减少锁竞争
|
||||
if newC.TotalTokens > 0 {
|
||||
c.addTokenUsage(newC)
|
||||
}
|
||||
})
|
||||
return ctx
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user