fix: context cancel not working during node runner execution (#819)

parent 09d00c26cb
commit 19c63a1150
@@ -106,10 +106,6 @@ import (
     "github.com/coze-dev/coze-studio/backend/types/errno"
 )
 
-var (
-    publishPatcher *mockey.Mocker
-)
-
 func TestMain(m *testing.M) {
     callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
     service.RegisterAllNodeAdaptors()
@@ -117,22 +113,23 @@ func TestMain(m *testing.M) {
 }
 
 type wfTestRunner struct {
     t             *testing.T
     h             *server.Hertz
     ctrl          *gomock.Controller
     idGen         *mock.MockIDGenerator
     appVarS       *mockvar.MockStore
     userVarS      *mockvar.MockStore
     varGetter     *mockvar.MockVariablesMetaGetter
     modelManage   *mockmodel.MockManager
     plugin        *mockPlugin.MockPluginService
     tos           *storageMock.MockStorage
     knowledge     *knowledgemock.MockKnowledge
     database      *databasemock.MockDatabase
     pluginSrv     *pluginmock.MockPluginService
     internalModel *testutil.UTChatModel
+    publishPatcher *mockey.Mocker
     ctx           context.Context
     closeFn       func()
 }
 
 var req2URL = map[reflect.Type]string{
@@ -256,7 +253,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
     workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
     mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
     mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
-    publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
+    publishPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
 
     mockCU := mockCrossUser.NewMockUser(ctrl)
     mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{
@@ -305,9 +302,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
     }, nil).Build()
 
     f := func() {
-        if publishPatcher != nil {
-            publishPatcher.UnPatch()
-        }
+        publishPatcher.UnPatch()
         m.UnPatch()
         m1.UnPatch()
         m2.UnPatch()
@@ -320,22 +315,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
     }
 
     return &wfTestRunner{
         t:             t,
         h:             h,
         ctrl:          ctrl,
         idGen:         mockIDGen,
         appVarS:       mockGlobalAppVarStore,
         userVarS:      mockGlobalUserVarStore,
         varGetter:     mockVarGetter,
         modelManage:   mockModelManage,
         plugin:        mPlugin,
         tos:           mockTos,
         knowledge:     mockKwOperator,
         database:      mockDatabaseOperator,
         internalModel: utChatModel,
         ctx:           context.Background(),
         closeFn:       f,
         pluginSrv:     mockPluginSrv,
+        publishPatcher: publishPatcher,
     }
 }
 
@@ -4147,14 +4143,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
 
         }
 
-        if publishPatcher != nil {
-            publishPatcher.UnPatch()
-        }
-        localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
-        defer func() {
-            localPatcher.UnPatch()
-            publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
-        }()
+        defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
 
         appID := "7513788954458456064"
         appIDInt64, _ := strconv.ParseInt(appID, 10, 64)
@@ -4265,14 +4254,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
             return nil
 
         }
-        if publishPatcher != nil {
-            publishPatcher.UnPatch()
-        }
-        localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
-        defer func() {
-            localPatcher.UnPatch()
-            publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
-        }()
+
+        defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
 
         defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{
             TargetKnowledgeID: 100100,
@@ -4313,6 +4296,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
 func TestMoveWorkflowAppToLibrary(t *testing.T) {
     mockey.PatchConvey("test move workflow", t, func() {
         r := newWfTestRunner(t)
+        r.publishPatcher.UnPatch()
         defer r.closeFn()
         vars := map[string]*vo.TypeInfo{
             "app_v1": {
@@ -4354,14 +4338,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
 
         }
 
-        if publishPatcher != nil {
-            publishPatcher.UnPatch()
-        }
-        localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
-        defer func() {
-            localPatcher.UnPatch()
-            publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
-        }()
+        defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
 
         defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch()
         defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch()
@@ -4479,6 +4456,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
 func TestDuplicateWorkflowsByAppID(t *testing.T) {
     mockey.PatchConvey("test duplicate work", t, func() {
         r := newWfTestRunner(t)
+        r.publishPatcher.UnPatch()
         defer r.closeFn()
 
         vars := map[string]*vo.TypeInfo{
@@ -4516,14 +4494,7 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
 
         }
 
-        if publishPatcher != nil {
-            publishPatcher.UnPatch()
-        }
-        localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
-        defer func() {
-            localPatcher.UnPatch()
-            publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
-        }()
+        defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
 
         appIDInt64 := int64(7513788954458456064)
 
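The test hunks above all apply the same refactor: the package-level `publishPatcher` variable is replaced by a handle stored on `wfTestRunner`, and tests that need custom `PublishWorkflowResource` behaviour first unpatch that default and then install a one-shot mock whose `UnPatch` is deferred. A minimal sketch of that pattern with mockey is shown below; `publishResource`, `runner`, and `newRunner` are hypothetical stand-ins, not names from this repository.

```go
package example_test

import (
    "testing"

    "github.com/bytedance/mockey"
)

// publishResource is a hypothetical stand-in for appworkflow.PublishWorkflowResource.
func publishResource(id int64) error { return nil }

type runner struct {
    publishPatcher *mockey.Mocker // handle kept on the runner instead of a package-level var
}

func newRunner() *runner {
    // install a default stub and keep the handle so a test can drop it later
    return &runner{publishPatcher: mockey.Mock(publishResource).Return(nil).Build()}
}

// Run with `go test -gcflags="all=-l"`, which mockey needs to patch functions at runtime.
func TestOverridePublish(t *testing.T) {
    mockey.PatchConvey("override publish stub", t, func() {
        r := newRunner()
        // drop the shared default stub, then install a test-specific mock for this test only
        r.publishPatcher.UnPatch()
        defer mockey.Mock(publishResource).To(func(id int64) error { return nil }).Build().UnPatch()

        _ = publishResource(42)
    })
}
```

Keeping the patcher on the runner avoids the old pattern in which every test had to nil-check a shared global, unpatch it, and re-install it on the way out.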
@@ -36,6 +36,7 @@ import (
     schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
     "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
     "github.com/coze-dev/coze-studio/backend/pkg/errorx"
+    exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
     "github.com/coze-dev/coze-studio/backend/pkg/logs"
     "github.com/coze-dev/coze-studio/backend/pkg/safego"
     "github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -614,21 +615,25 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
 
 func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
     var n int64
-    for {
-        select {
-        case <-ctx.Done():
-            return nil, ctx.Err()
-        default:
-        }
-        output, err = r.i(ctx, input, opts...)
+    var invokeOutput map[string]any
+
+    for {
+        err = exec.RunWithContextDone(ctx, func() error {
+            var invokeErr error
+            invokeOutput, invokeErr = r.i(ctx, input, opts...)
+            if invokeErr != nil {
+                return invokeErr
+            }
+            return nil
+        })
         if err != nil {
-            if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
+            if _, ok := compose.IsInterruptRerunError(err); ok {
                 r.interrupted = true
                 return nil, err
             }
 
             logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
 
             if r.maxRetry > n {
                 n++
                 if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -636,30 +641,35 @@ func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts .
                 }
                 continue
             }
 
             return nil, err
         }
-
-        return output, nil
+        return invokeOutput, nil
+
     }
 }
 
 func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
     var n int64
-    for {
-        select {
-        case <-ctx.Done():
-            return nil, ctx.Err()
-        default:
-        }
-        output, err = r.s(ctx, input, opts...)
+    var streamOutput *schema.StreamReader[map[string]any]
+
+    for {
+        err = exec.RunWithContextDone(ctx, func() error {
+            var streamErr error
+            streamOutput, streamErr = r.s(ctx, input, opts...)
+            if streamErr != nil {
+                return streamErr
+            }
+            return nil
+        })
+
         if err != nil {
-            if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
+            if _, ok := compose.IsInterruptRerunError(err); ok {
                 r.interrupted = true
                 return nil, err
             }
 
-            logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
+            logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
             if r.maxRetry > n {
                 n++
                 if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -680,29 +690,31 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
     }
 
     copied := input.Copy(int(r.maxRetry))
 
     var n int64
 
     defer func() {
         for i := n + 1; i < r.maxRetry; i++ {
             copied[i].Close()
         }
     }()
+    var collectOutput map[string]any
     for {
-        select {
-        case <-ctx.Done():
-            return nil, ctx.Err()
-        default:
-        }
-
-        output, err = r.c(ctx, copied[n], opts...)
+        err = exec.RunWithContextDone(ctx, func() error {
+            var collectErr error
+            collectOutput, collectErr = r.c(ctx, copied[n], opts...)
+            if collectErr != nil {
+                return collectErr
+            }
+            return nil
+        })
         if err != nil {
-            if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
+            if _, ok := compose.IsInterruptRerunError(err); ok {
                 r.interrupted = true
                 return nil, err
             }
 
-            logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
+            logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
             if r.maxRetry > n {
                 n++
                 if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -710,10 +722,10 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
                 }
                 continue
             }
 
             return nil, err
         }
-        return output, nil
+        return collectOutput, nil
     }
 }
 
@@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
         }
     }()
 
+    var transformOutput *schema.StreamReader[map[string]any]
     for {
-        select {
-        case <-ctx.Done():
-            return nil, ctx.Err()
-        default:
-        }
-
-        output, err = r.t(ctx, copied[n], opts...)
+        err = exec.RunWithContextDone(ctx, func() error {
+            var transformErr error
+            transformOutput, transformErr = r.t(ctx, copied[n], opts...)
+            if transformErr != nil {
+                return transformErr
+            }
+            return nil
+        })
         if err != nil {
-            if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
+            if _, ok := compose.IsInterruptRerunError(err); ok {
                 r.interrupted = true
                 return nil, err
             }
-
-            logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
+            logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
             if r.maxRetry > n {
                 n++
                 if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
             return nil, err
         }
 
-        return output, nil
+        return transformOutput, nil
 
     }
 }
+
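All four runner paths (invoke, stream, collect, transform) now route the node call through `exec.RunWithContextDone`, so a cancelled workflow context unblocks the caller while the node is still executing, instead of only being noticed by the `select` at the top of the next retry iteration. A rough sketch of the call-site shape, with a hypothetical `slowNode` standing in for `r.i`/`r.s`/`r.c`/`r.t` and assuming the helper is importable from `backend/pkg/execute` (the path the runner imports), behaves like this:

```go
package main

import (
    "context"
    "errors"
    "fmt"
    "time"

    exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
)

// slowNode is a hypothetical node invocation that ignores the context.
func slowNode() (map[string]any, error) {
    time.Sleep(5 * time.Second)
    return map[string]any{"ok": true}, nil
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
    defer cancel()

    // same shape as the runner: the closure writes into an outer variable so the
    // surrounding retry loop can still return the output on success
    var out map[string]any
    err := exec.RunWithContextDone(ctx, func() error {
        var invokeErr error
        out, invokeErr = slowNode()
        return invokeErr
    })

    // with the old loop, cancellation was only observed before the next attempt;
    // with the helper, the caller returns as soon as the context is done
    fmt.Println(out == nil, errors.Is(err, context.DeadlineExceeded)) // true true
}
```

One trade-off worth noting: the node goroutine itself is not stopped on cancellation and keeps running to completion in the background; only the caller is released.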
@@ -57,17 +57,25 @@ func (t *TokenCollector) addTokenUsage(usage *model.TokenUsage) {
 }
 
 func (t *TokenCollector) wait() *model.TokenUsage {
-    t.wg.Wait()
     t.mu.Lock()
+    defer t.mu.Unlock()
+    t.wg.Wait()
     usage := &model.TokenUsage{
         PromptTokens:     t.Usage.PromptTokens,
         CompletionTokens: t.Usage.CompletionTokens,
         TotalTokens:      t.Usage.TotalTokens,
     }
-    t.mu.Unlock()
     return usage
 }
 
+func (t *TokenCollector) add(i int) {
+    t.mu.Lock()
+    defer t.mu.Unlock()
+    t.wg.Add(i)
+    return
+}
+
 func getTokenCollector(ctx context.Context) *TokenCollector {
     c := GetExeCtx(ctx)
     if c == nil {
@@ -83,7 +91,8 @@ func GetTokenCallbackHandler() callbacks.Handler {
             if c == nil {
                 return ctx
             }
-            c.wg.Add(1)
+            c.add(1)
+            //c.wg.Add(1)
             return ctx
         },
         OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
@@ -122,12 +131,16 @@ func GetTokenCallbackHandler() callbacks.Handler {
                 if chunk.TokenUsage == nil {
                     continue
                 }
+                // accumulate inside the goroutine to avoid concurrent access
                 newC.PromptTokens += chunk.TokenUsage.PromptTokens
                 newC.CompletionTokens += chunk.TokenUsage.CompletionTokens
                 newC.TotalTokens += chunk.TokenUsage.TotalTokens
             }
 
-            c.addTokenUsage(newC)
+            // call addTokenUsage only once at the end to reduce lock contention
+            if newC.TotalTokens > 0 {
+                c.addTokenUsage(newC)
+            }
         })
         return ctx
     },
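The token-collector change has two parts: `wg.Add` is funneled through an `add` method that takes the same mutex as `wait`, and the streaming callback accumulates chunk usage in a local value and publishes it once at the end, and only when there is anything to publish. The sketch below illustrates the accumulate-locally part with a hypothetical `usage` struct and `collector` type in place of `model.TokenUsage` and `TokenCollector`:

```go
package main

import (
    "fmt"
    "sync"
)

// usage is a hypothetical stand-in for model.TokenUsage.
type usage struct{ Prompt, Completion, Total int }

type collector struct {
    mu  sync.Mutex
    sum usage
}

func (c *collector) addUsage(u usage) {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.sum.Prompt += u.Prompt
    c.sum.Completion += u.Completion
    c.sum.Total += u.Total
}

// consumeStream mirrors the callback change: accumulate inside the consuming
// goroutine and take the collector lock only once, at the end.
func consumeStream(c *collector, chunks <-chan *usage, done *sync.WaitGroup) {
    defer done.Done()
    var local usage
    for chunk := range chunks {
        if chunk == nil {
            continue
        }
        local.Prompt += chunk.Prompt
        local.Completion += chunk.Completion
        local.Total += chunk.Total
    }
    if local.Total > 0 {
        c.addUsage(local)
    }
}

func main() {
    c := &collector{}
    chunks := make(chan *usage, 3)
    chunks <- &usage{Prompt: 10, Completion: 5, Total: 15}
    chunks <- nil
    chunks <- &usage{Prompt: 2, Completion: 1, Total: 3}
    close(chunks)

    var wg sync.WaitGroup
    wg.Add(1)
    go consumeStream(c, chunks, &wg)
    wg.Wait()
    fmt.Printf("%+v\n", c.sum) // {Prompt:12 Completion:6 Total:18}
}
```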
@@ -56,6 +56,7 @@ require (
     github.com/DATA-DOG/go-sqlmock v1.5.2
     github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1
     github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0
+    github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e
     github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8
     github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09
     github.com/cloudwego/eino-ext/components/model/gemini v0.1.2
@@ -85,7 +86,6 @@ require (
     github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.37 // indirect
     github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.5 // indirect
     github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18 // indirect
-    github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e // indirect
     github.com/cloudwego/gopkg v0.1.4 // indirect
     github.com/evanphx/json-patch v4.12.0+incompatible // indirect
     github.com/extrame/ole2 v0.0.0-20160812065207-d69429661ad7 // indirect
@@ -2614,8 +2614,6 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww
 google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
 google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
 google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/genai v1.13.0 h1:LRhwx5PU+bXhfnXyPEHu2kt9yc+MpvuYbajxSorOJjg=
-google.golang.org/genai v1.13.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M=
 google.golang.org/genai v1.18.0 h1:fTmK7y30CO0CL8xRyyFSjTkd1MNbYUeFUehvDyU/2gQ=
 google.golang.org/genai v1.18.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M=
 google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2025 coze-dev Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package execute
+
+import (
+    "context"
+    "fmt"
+    "runtime/debug"
+)
+
+func RunWithContextDone(ctx context.Context, fn func() error) error {
+    errChan := make(chan error, 1)
+    go func() {
+        defer func() {
+            if err := recover(); err != nil {
+                errChan <- fmt.Errorf("exec func panic, %v \n %s", err, debug.Stack())
+            }
+            close(errChan)
+        }()
+        err := fn()
+        errChan <- err
+    }()
+
+    select {
+    case <-ctx.Done():
+        return ctx.Err()
+    case err := <-errChan:
+        return err
+    }
+}
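Two properties of the new helper are worth pinning down in tests: a cancelled context returns `ctx.Err()` even though the wrapped function keeps running in its goroutine, and a panic inside the function comes back as an ordinary error instead of crashing the process. A sketch of such tests, assuming the file lives under `backend/pkg/execute` (the import path the node runner uses):

```go
package execute_test

import (
    "context"
    "errors"
    "strings"
    "testing"
    "time"

    "github.com/coze-dev/coze-studio/backend/pkg/execute"
)

// A cancelled context unblocks the caller even while fn is still sleeping.
func TestRunWithContextDoneCancel(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    cancel()

    err := execute.RunWithContextDone(ctx, func() error {
        time.Sleep(time.Second) // keeps running in the background after the helper returns
        return nil
    })
    if !errors.Is(err, context.Canceled) {
        t.Fatalf("expected context.Canceled, got %v", err)
    }
}

// A panic inside fn is recovered and surfaced as an error.
func TestRunWithContextDonePanic(t *testing.T) {
    err := execute.RunWithContextDone(context.Background(), func() error {
        panic("boom")
    })
    if err == nil || !strings.Contains(err.Error(), "panic") {
        t.Fatalf("expected a panic error, got %v", err)
    }
}
```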