diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index e52fd1c9..adb87485 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -106,10 +106,6 @@ import ( "github.com/coze-dev/coze-studio/backend/types/errno" ) -var ( - publishPatcher *mockey.Mocker -) - func TestMain(m *testing.M) { callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler()) service.RegisterAllNodeAdaptors() @@ -117,22 +113,23 @@ func TestMain(m *testing.M) { } type wfTestRunner struct { - t *testing.T - h *server.Hertz - ctrl *gomock.Controller - idGen *mock.MockIDGenerator - appVarS *mockvar.MockStore - userVarS *mockvar.MockStore - varGetter *mockvar.MockVariablesMetaGetter - modelManage *mockmodel.MockManager - plugin *mockPlugin.MockPluginService - tos *storageMock.MockStorage - knowledge *knowledgemock.MockKnowledge - database *databasemock.MockDatabase - pluginSrv *pluginmock.MockPluginService - internalModel *testutil.UTChatModel - ctx context.Context - closeFn func() + t *testing.T + h *server.Hertz + ctrl *gomock.Controller + idGen *mock.MockIDGenerator + appVarS *mockvar.MockStore + userVarS *mockvar.MockStore + varGetter *mockvar.MockVariablesMetaGetter + modelManage *mockmodel.MockManager + plugin *mockPlugin.MockPluginService + tos *storageMock.MockStorage + knowledge *knowledgemock.MockKnowledge + database *databasemock.MockDatabase + pluginSrv *pluginmock.MockPluginService + internalModel *testutil.UTChatModel + publishPatcher *mockey.Mocker + ctx context.Context + closeFn func() } var req2URL = map[reflect.Type]string{ @@ -256,7 +253,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel) mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build() mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build() - publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() + publishPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() mockCU := mockCrossUser.NewMockUser(ctrl) mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{ @@ -305,9 +302,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { }, nil).Build() f := func() { - if publishPatcher != nil { - publishPatcher.UnPatch() - } + publishPatcher.UnPatch() m.UnPatch() m1.UnPatch() m2.UnPatch() @@ -320,22 +315,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { } return &wfTestRunner{ - t: t, - h: h, - ctrl: ctrl, - idGen: mockIDGen, - appVarS: mockGlobalAppVarStore, - userVarS: mockGlobalUserVarStore, - varGetter: mockVarGetter, - modelManage: mockModelManage, - plugin: mPlugin, - tos: mockTos, - knowledge: mockKwOperator, - database: mockDatabaseOperator, - internalModel: utChatModel, - ctx: context.Background(), - closeFn: f, - pluginSrv: mockPluginSrv, + t: t, + h: h, + ctrl: ctrl, + idGen: mockIDGen, + appVarS: mockGlobalAppVarStore, + userVarS: mockGlobalUserVarStore, + varGetter: mockVarGetter, + modelManage: mockModelManage, + plugin: mPlugin, + tos: mockTos, + knowledge: mockKwOperator, + database: mockDatabaseOperator, + internalModel: utChatModel, + ctx: context.Background(), + closeFn: f, + pluginSrv: mockPluginSrv, + publishPatcher: publishPatcher, } } @@ -4147,14 +4143,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { } - if publishPatcher != nil { - publishPatcher.UnPatch() - } - localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build() - defer func() { - localPatcher.UnPatch() - publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() - }() + defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch() appID := "7513788954458456064" appIDInt64, _ := strconv.ParseInt(appID, 10, 64) @@ -4265,14 +4254,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { return nil } - if publishPatcher != nil { - publishPatcher.UnPatch() - } - localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build() - defer func() { - localPatcher.UnPatch() - publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() - }() + + defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch() defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{ TargetKnowledgeID: 100100, @@ -4313,6 +4296,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { func TestMoveWorkflowAppToLibrary(t *testing.T) { mockey.PatchConvey("test move workflow", t, func() { r := newWfTestRunner(t) + r.publishPatcher.UnPatch() defer r.closeFn() vars := map[string]*vo.TypeInfo{ "app_v1": { @@ -4354,14 +4338,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { } - if publishPatcher != nil { - publishPatcher.UnPatch() - } - localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build() - defer func() { - localPatcher.UnPatch() - publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() - }() + defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch() defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch() defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch() @@ -4479,6 +4456,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { func TestDuplicateWorkflowsByAppID(t *testing.T) { mockey.PatchConvey("test duplicate work", t, func() { r := newWfTestRunner(t) + r.publishPatcher.UnPatch() defer r.closeFn() vars := map[string]*vo.TypeInfo{ @@ -4516,14 +4494,7 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) { } - if publishPatcher != nil { - publishPatcher.UnPatch() - } - localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build() - defer func() { - localPatcher.UnPatch() - publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build() - }() + defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch() appIDInt64 := int64(7513788954458456064) diff --git a/backend/domain/workflow/internal/compose/node_runner.go b/backend/domain/workflow/internal/compose/node_runner.go index 6a0a8921..fa3f4ebe 100644 --- a/backend/domain/workflow/internal/compose/node_runner.go +++ b/backend/domain/workflow/internal/compose/node_runner.go @@ -36,6 +36,7 @@ import ( schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/errorx" + exec "github.com/coze-dev/coze-studio/backend/pkg/execute" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/sonic" @@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any) func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) { var n int64 - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } + var invokeOutput map[string]any - output, err = r.i(ctx, input, opts...) + for { + err = exec.RunWithContextDone(ctx, func() error { + var invokeErr error + invokeOutput, invokeErr = r.i(ctx, input, opts...) + if invokeErr != nil { + return invokeErr + } + return nil + }) if err != nil { - if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry + if _, ok := compose.IsInterruptRerunError(err); ok { r.interrupted = true return nil, err } logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) + if r.maxRetry > n { n++ if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { @@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts . } continue } + return nil, err } + return invokeOutput, nil - return output, nil } } func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) { var n int64 - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } + var streamOutput *schema.StreamReader[map[string]any] + + for { + err = exec.RunWithContextDone(ctx, func() error { + var streamErr error + streamOutput, streamErr = r.s(ctx, input, opts...) + if streamErr != nil { + return streamErr + } + return nil + }) - output, err = r.s(ctx, input, opts...) if err != nil { - if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry + if _, ok := compose.IsInterruptRerunError(err); ok { r.interrupted = true return nil, err } - logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) + logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) if r.maxRetry > n { n++ if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { @@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts . } return nil, err } + return streamOutput, nil - return output, nil } } @@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[ } copied := input.Copy(int(r.maxRetry)) - var n int64 + defer func() { for i := n + 1; i < r.maxRetry; i++ { copied[i].Close() } }() - + var collectOutput map[string]any for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } + err = exec.RunWithContextDone(ctx, func() error { + var collectErr error + collectOutput, collectErr = r.c(ctx, copied[n], opts...) + if collectErr != nil { + return collectErr + } + return nil + }) - output, err = r.c(ctx, copied[n], opts...) if err != nil { - if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry + if _, ok := compose.IsInterruptRerunError(err); ok { r.interrupted = true return nil, err } - logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) + logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) if r.maxRetry > n { n++ if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { @@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[ } continue } + return nil, err } - - return output, nil + return collectOutput, nil } } @@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade } }() + var transformOutput *schema.StreamReader[map[string]any] for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - output, err = r.t(ctx, copied[n], opts...) + err = exec.RunWithContextDone(ctx, func() error { + var transformErr error + transformOutput, transformErr = r.t(ctx, copied[n], opts...) + if transformErr != nil { + return transformErr + } + return nil + }) if err != nil { - if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry + if _, ok := compose.IsInterruptRerunError(err); ok { r.interrupted = true return nil, err } - - logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) + logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err) if r.maxRetry > n { n++ if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil { @@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade return nil, err } - return output, nil + return transformOutput, nil + } } diff --git a/backend/domain/workflow/internal/execute/collect_token.go b/backend/domain/workflow/internal/execute/collect_token.go index 01dae852..80023fcd 100644 --- a/backend/domain/workflow/internal/execute/collect_token.go +++ b/backend/domain/workflow/internal/execute/collect_token.go @@ -57,17 +57,25 @@ func (t *TokenCollector) addTokenUsage(usage *model.TokenUsage) { } func (t *TokenCollector) wait() *model.TokenUsage { - t.wg.Wait() t.mu.Lock() + defer t.mu.Unlock() + t.wg.Wait() usage := &model.TokenUsage{ PromptTokens: t.Usage.PromptTokens, CompletionTokens: t.Usage.CompletionTokens, TotalTokens: t.Usage.TotalTokens, } - t.mu.Unlock() + return usage } +func (t *TokenCollector) add(i int) { + t.mu.Lock() + defer t.mu.Unlock() + t.wg.Add(i) + return +} + func getTokenCollector(ctx context.Context) *TokenCollector { c := GetExeCtx(ctx) if c == nil { @@ -83,7 +91,8 @@ func GetTokenCallbackHandler() callbacks.Handler { if c == nil { return ctx } - c.wg.Add(1) + c.add(1) + //c.wg.Add(1) return ctx }, OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { @@ -122,12 +131,16 @@ func GetTokenCallbackHandler() callbacks.Handler { if chunk.TokenUsage == nil { continue } + // 在goroutine内部累加,避免并发访问 newC.PromptTokens += chunk.TokenUsage.PromptTokens newC.CompletionTokens += chunk.TokenUsage.CompletionTokens newC.TotalTokens += chunk.TokenUsage.TotalTokens } - c.addTokenUsage(newC) + // 只在最后调用一次addTokenUsage,减少锁竞争 + if newC.TotalTokens > 0 { + c.addTokenUsage(newC) + } }) return ctx }, diff --git a/backend/go.mod b/backend/go.mod index 118e8f48..6c431c20 100755 --- a/backend/go.mod +++ b/backend/go.mod @@ -56,6 +56,7 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1 github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 + github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09 github.com/cloudwego/eino-ext/components/model/gemini v0.1.2 @@ -85,7 +86,6 @@ require ( github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.37 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.5 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18 // indirect - github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e // indirect github.com/cloudwego/gopkg v0.1.4 // indirect github.com/evanphx/json-patch v4.12.0+incompatible // indirect github.com/extrame/ole2 v0.0.0-20160812065207-d69429661ad7 // indirect diff --git a/backend/go.sum b/backend/go.sum index 31515acd..b043f531 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -2614,8 +2614,6 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genai v1.13.0 h1:LRhwx5PU+bXhfnXyPEHu2kt9yc+MpvuYbajxSorOJjg= -google.golang.org/genai v1.13.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= google.golang.org/genai v1.18.0 h1:fTmK7y30CO0CL8xRyyFSjTkd1MNbYUeFUehvDyU/2gQ= google.golang.org/genai v1.18.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= diff --git a/backend/pkg/execute/execute.go b/backend/pkg/execute/execute.go new file mode 100644 index 00000000..cb5a3167 --- /dev/null +++ b/backend/pkg/execute/execute.go @@ -0,0 +1,44 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package execute + +import ( + "context" + "fmt" + "runtime/debug" +) + +func RunWithContextDone(ctx context.Context, fn func() error) error { + errChan := make(chan error, 1) + go func() { + defer func() { + if err := recover(); err != nil { + errChan <- fmt.Errorf("exec func panic, %v \n %s", err, debug.Stack()) + } + close(errChan) + }() + err := fn() + errChan <- err + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errChan: + return err + } +}