fix: context cancel not working during node runner execution (#819)

This commit is contained in:
Zhj 2025-08-21 17:59:01 +08:00 committed by GitHub
parent 09d00c26cb
commit 19c63a1150
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 161 additions and 121 deletions

View File

@ -106,10 +106,6 @@ import (
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
var (
publishPatcher *mockey.Mocker
)
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler()) callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
service.RegisterAllNodeAdaptors() service.RegisterAllNodeAdaptors()
@ -131,6 +127,7 @@ type wfTestRunner struct {
database *databasemock.MockDatabase database *databasemock.MockDatabase
pluginSrv *pluginmock.MockPluginService pluginSrv *pluginmock.MockPluginService
internalModel *testutil.UTChatModel internalModel *testutil.UTChatModel
publishPatcher *mockey.Mocker
ctx context.Context ctx context.Context
closeFn func() closeFn func()
} }
@ -256,7 +253,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel) workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build() mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build() mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() publishPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
mockCU := mockCrossUser.NewMockUser(ctrl) mockCU := mockCrossUser.NewMockUser(ctrl)
mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{ mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{
@ -305,9 +302,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
}, nil).Build() }, nil).Build()
f := func() { f := func() {
if publishPatcher != nil {
publishPatcher.UnPatch() publishPatcher.UnPatch()
}
m.UnPatch() m.UnPatch()
m1.UnPatch() m1.UnPatch()
m2.UnPatch() m2.UnPatch()
@ -336,6 +331,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
ctx: context.Background(), ctx: context.Background(),
closeFn: f, closeFn: f,
pluginSrv: mockPluginSrv, pluginSrv: mockPluginSrv,
publishPatcher: publishPatcher,
} }
} }
@ -4147,14 +4143,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
} }
if publishPatcher != nil { defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
publishPatcher.UnPatch()
}
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
defer func() {
localPatcher.UnPatch()
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
}()
appID := "7513788954458456064" appID := "7513788954458456064"
appIDInt64, _ := strconv.ParseInt(appID, 10, 64) appIDInt64, _ := strconv.ParseInt(appID, 10, 64)
@ -4265,14 +4254,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
return nil return nil
} }
if publishPatcher != nil {
publishPatcher.UnPatch() defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
}
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
defer func() {
localPatcher.UnPatch()
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
}()
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{ defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{
TargetKnowledgeID: 100100, TargetKnowledgeID: 100100,
@ -4313,6 +4296,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
func TestMoveWorkflowAppToLibrary(t *testing.T) { func TestMoveWorkflowAppToLibrary(t *testing.T) {
mockey.PatchConvey("test move workflow", t, func() { mockey.PatchConvey("test move workflow", t, func() {
r := newWfTestRunner(t) r := newWfTestRunner(t)
r.publishPatcher.UnPatch()
defer r.closeFn() defer r.closeFn()
vars := map[string]*vo.TypeInfo{ vars := map[string]*vo.TypeInfo{
"app_v1": { "app_v1": {
@ -4354,14 +4338,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
} }
if publishPatcher != nil { defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
publishPatcher.UnPatch()
}
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
defer func() {
localPatcher.UnPatch()
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
}()
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch() defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch()
defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch() defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch()
@ -4479,6 +4456,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
func TestDuplicateWorkflowsByAppID(t *testing.T) { func TestDuplicateWorkflowsByAppID(t *testing.T) {
mockey.PatchConvey("test duplicate work", t, func() { mockey.PatchConvey("test duplicate work", t, func() {
r := newWfTestRunner(t) r := newWfTestRunner(t)
r.publishPatcher.UnPatch()
defer r.closeFn() defer r.closeFn()
vars := map[string]*vo.TypeInfo{ vars := map[string]*vo.TypeInfo{
@ -4516,14 +4494,7 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
} }
if publishPatcher != nil { defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
publishPatcher.UnPatch()
}
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
defer func() {
localPatcher.UnPatch()
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
}()
appIDInt64 := int64(7513788954458456064) appIDInt64 := int64(7513788954458456064)

View File

@ -36,6 +36,7 @@ import (
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) { func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
var n int64 var n int64
for { var invokeOutput map[string]any
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
output, err = r.i(ctx, input, opts...) for {
err = exec.RunWithContextDone(ctx, func() error {
var invokeErr error
invokeOutput, invokeErr = r.i(ctx, input, opts...)
if invokeErr != nil {
return invokeErr
}
return nil
})
if err != nil { if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true r.interrupted = true
return nil, err return nil, err
} }
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n { if r.maxRetry > n {
n++ n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts .
} }
continue continue
} }
return nil, err return nil, err
} }
return invokeOutput, nil
return output, nil
} }
} }
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) { func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
var n int64 var n int64
for { var streamOutput *schema.StreamReader[map[string]any]
select {
case <-ctx.Done(): for {
return nil, ctx.Err() err = exec.RunWithContextDone(ctx, func() error {
default: var streamErr error
} streamOutput, streamErr = r.s(ctx, input, opts...)
if streamErr != nil {
return streamErr
}
return nil
})
output, err = r.s(ctx, input, opts...)
if err != nil { if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true r.interrupted = true
return nil, err return nil, err
} }
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n { if r.maxRetry > n {
n++ n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts .
} }
return nil, err return nil, err
} }
return streamOutput, nil
return output, nil
} }
} }
@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
} }
copied := input.Copy(int(r.maxRetry)) copied := input.Copy(int(r.maxRetry))
var n int64 var n int64
defer func() { defer func() {
for i := n + 1; i < r.maxRetry; i++ { for i := n + 1; i < r.maxRetry; i++ {
copied[i].Close() copied[i].Close()
} }
}() }()
var collectOutput map[string]any
for { for {
select { err = exec.RunWithContextDone(ctx, func() error {
case <-ctx.Done(): var collectErr error
return nil, ctx.Err() collectOutput, collectErr = r.c(ctx, copied[n], opts...)
default: if collectErr != nil {
return collectErr
} }
return nil
})
output, err = r.c(ctx, copied[n], opts...)
if err != nil { if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true r.interrupted = true
return nil, err return nil, err
} }
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n { if r.maxRetry > n {
n++ n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
} }
continue continue
} }
return nil, err return nil, err
} }
return collectOutput, nil
return output, nil
} }
} }
@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
} }
}() }()
var transformOutput *schema.StreamReader[map[string]any]
for { for {
select { err = exec.RunWithContextDone(ctx, func() error {
case <-ctx.Done(): var transformErr error
return nil, ctx.Err() transformOutput, transformErr = r.t(ctx, copied[n], opts...)
default: if transformErr != nil {
return transformErr
} }
return nil
output, err = r.t(ctx, copied[n], opts...) })
if err != nil { if err != nil {
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry if _, ok := compose.IsInterruptRerunError(err); ok {
r.interrupted = true r.interrupted = true
return nil, err return nil, err
} }
logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
if r.maxRetry > n { if r.maxRetry > n {
n++ n++
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
return nil, err return nil, err
} }
return output, nil return transformOutput, nil
} }
} }

View File

@ -57,17 +57,25 @@ func (t *TokenCollector) addTokenUsage(usage *model.TokenUsage) {
} }
func (t *TokenCollector) wait() *model.TokenUsage { func (t *TokenCollector) wait() *model.TokenUsage {
t.wg.Wait()
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock()
t.wg.Wait()
usage := &model.TokenUsage{ usage := &model.TokenUsage{
PromptTokens: t.Usage.PromptTokens, PromptTokens: t.Usage.PromptTokens,
CompletionTokens: t.Usage.CompletionTokens, CompletionTokens: t.Usage.CompletionTokens,
TotalTokens: t.Usage.TotalTokens, TotalTokens: t.Usage.TotalTokens,
} }
t.mu.Unlock()
return usage return usage
} }
func (t *TokenCollector) add(i int) {
t.mu.Lock()
defer t.mu.Unlock()
t.wg.Add(i)
return
}
func getTokenCollector(ctx context.Context) *TokenCollector { func getTokenCollector(ctx context.Context) *TokenCollector {
c := GetExeCtx(ctx) c := GetExeCtx(ctx)
if c == nil { if c == nil {
@ -83,7 +91,8 @@ func GetTokenCallbackHandler() callbacks.Handler {
if c == nil { if c == nil {
return ctx return ctx
} }
c.wg.Add(1) c.add(1)
//c.wg.Add(1)
return ctx return ctx
}, },
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
@ -122,12 +131,16 @@ func GetTokenCallbackHandler() callbacks.Handler {
if chunk.TokenUsage == nil { if chunk.TokenUsage == nil {
continue continue
} }
// 在goroutine内部累加避免并发访问
newC.PromptTokens += chunk.TokenUsage.PromptTokens newC.PromptTokens += chunk.TokenUsage.PromptTokens
newC.CompletionTokens += chunk.TokenUsage.CompletionTokens newC.CompletionTokens += chunk.TokenUsage.CompletionTokens
newC.TotalTokens += chunk.TokenUsage.TotalTokens newC.TotalTokens += chunk.TokenUsage.TotalTokens
} }
// 只在最后调用一次addTokenUsage减少锁竞争
if newC.TotalTokens > 0 {
c.addTokenUsage(newC) c.addTokenUsage(newC)
}
}) })
return ctx return ctx
}, },

View File

@ -56,6 +56,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0
github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09 github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09
github.com/cloudwego/eino-ext/components/model/gemini v0.1.2 github.com/cloudwego/eino-ext/components/model/gemini v0.1.2
@ -85,7 +86,6 @@ require (
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.37 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.37 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.5 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18 // indirect
github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e // indirect
github.com/cloudwego/gopkg v0.1.4 // indirect github.com/cloudwego/gopkg v0.1.4 // indirect
github.com/evanphx/json-patch v4.12.0+incompatible // indirect github.com/evanphx/json-patch v4.12.0+incompatible // indirect
github.com/extrame/ole2 v0.0.0-20160812065207-d69429661ad7 // indirect github.com/extrame/ole2 v0.0.0-20160812065207-d69429661ad7 // indirect

View File

@ -2614,8 +2614,6 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genai v1.13.0 h1:LRhwx5PU+bXhfnXyPEHu2kt9yc+MpvuYbajxSorOJjg=
google.golang.org/genai v1.13.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M=
google.golang.org/genai v1.18.0 h1:fTmK7y30CO0CL8xRyyFSjTkd1MNbYUeFUehvDyU/2gQ= google.golang.org/genai v1.18.0 h1:fTmK7y30CO0CL8xRyyFSjTkd1MNbYUeFUehvDyU/2gQ=
google.golang.org/genai v1.18.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= google.golang.org/genai v1.18.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M=
google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=

View File

@ -0,0 +1,44 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package execute
import (
"context"
"fmt"
"runtime/debug"
)
func RunWithContextDone(ctx context.Context, fn func() error) error {
errChan := make(chan error, 1)
go func() {
defer func() {
if err := recover(); err != nil {
errChan <- fmt.Errorf("exec func panic, %v \n %s", err, debug.Stack())
}
close(errChan)
}()
err := fn()
errChan <- err
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errChan:
return err
}
}