// coze-studio/backend/pkg/taskgroup/taskgroup.go
// (listing metadata at capture time: 82 lines, 1.9 KiB, Go)

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package taskgroup
import (
	"context"
	"fmt"
	"sync/atomic"

	"golang.org/x/sync/errgroup"

	"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// TaskGroup is a minimal handle for running a batch of error-returning tasks
// and waiting for the batch to finish. The implementation in this file is
// backed by golang.org/x/sync/errgroup.
type TaskGroup interface {
	// Go submits f to the group for execution.
	Go(f func() error)
	// Wait blocks until the group is done and returns the group's error, if any.
	Wait() error
}
// taskGroup is the errgroup-backed implementation of TaskGroup.
type taskGroup struct {
	// errGroup schedules the tasks and produces the error returned by Wait.
	errGroup *errgroup.Group
	// ctx is the context returned by errgroup.WithContext. It is consulted
	// before a task starts (interruptible mode only) and passed to the logger.
	ctx context.Context
	// execAllTask selects the mode: when true, Go skips the cancellation check
	// and always invokes the task; when false, a task that starts after ctx is
	// done returns ctx.Err() without running.
	execAllTask atomic.Bool
}
// NewTaskGroup returns an interruptible TaskGroup whose context is derived
// from ctx via errgroup.WithContext.
//
// In this mode every task checks the derived context right before it runs:
// if the context is already done (errgroup cancels it once a task returns a
// non-nil error), the task is skipped and reports the context error instead.
// Tasks that are already running are not interrupted by the group itself.
//
// concurrentCount is handed unchanged to errgroup.Group.SetLimit; refer to
// that method's documentation for how zero and negative limits behave.
func NewTaskGroup(ctx context.Context, concurrentCount int) TaskGroup {
	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(concurrentCount)

	tg := &taskGroup{
		errGroup: group,
		ctx:      groupCtx,
	}
	// Interruptible mode: pending tasks honour cancellation.
	tg.execAllTask.Store(false)
	return tg
}
// NewUninterruptibleTaskGroup returns a TaskGroup that keeps executing every
// submitted task even after another task has returned an error.
//
// The group is still created with errgroup.WithContext, but tasks skip the
// "is the context done?" check before running, so none of them is dropped
// because of an earlier failure. Wait still reports the group's error.
//
// concurrentCount is handed unchanged to errgroup.Group.SetLimit; refer to
// that method's documentation for how zero and negative limits behave.
func NewUninterruptibleTaskGroup(ctx context.Context, concurrentCount int) TaskGroup {
	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(concurrentCount)

	tg := &taskGroup{
		errGroup: group,
		ctx:      groupCtx,
	}
	// Uninterruptible mode: run all tasks regardless of cancellation.
	tg.execAllTask.Store(true)
	return tg
}
// Go submits f to the underlying errgroup.
//
// In interruptible mode (execAllTask == false) the task first checks the
// group context and, if it is already done, returns the context error without
// calling f. In uninterruptible mode f is always called.
//
// A panic raised by f is recovered and logged, and is then reported as the
// task's error. Previously the recovered panic left the task result nil, so
// Wait reported success and sibling tasks were never cancelled even though
// the task had crashed.
func (t *taskGroup) Go(f func() error) {
	// The named result lets the deferred recover handler replace the return
	// value after a panic has unwound f.
	t.errGroup.Go(func() (err error) {
		defer func() {
			if r := recover(); r != nil {
				logs.CtxErrorf(t.ctx, "[TaskGroup] exec panic recover:%+v", r)
				err = fmt.Errorf("taskgroup: task panicked: %v", r)
			}
		}()

		if !t.execAllTask.Load() {
			// Non-blocking probe: skip the task if the group was already
			// cancelled, otherwise fall through and run it.
			select {
			case <-t.ctx.Done():
				return t.ctx.Err()
			default:
			}
		}

		return f()
	})
}
// Wait delegates to the underlying errgroup: it blocks until the group's
// tasks have returned and yields whatever error errgroup.Group.Wait reports
// (nil when there is none).
func (t *taskGroup) Wait() error {
	groupErr := t.errGroup.Wait()
	return groupErr
}