82 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			82 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | |
| 
 | |
| package taskgroup
 | |
| 
 | |
import (
	"context"
	"fmt"
	"sync/atomic"

	"golang.org/x/sync/errgroup"

	"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
 | |
| 
 | |
// TaskGroup runs a set of error-returning tasks concurrently and reports the
// first error any of them produced.
type TaskGroup interface {
	// Go schedules f to run on its own goroutine. If the implementation
	// enforces a concurrency limit, Go may block until a slot becomes free.
	Go(f func() error)

	// Wait blocks until every task passed to Go has finished and returns the
	// first non-nil error among them, or nil if none failed.
	Wait() error
}
 | |
| 
 | |
| type taskGroup struct {
 | |
| 	errGroup    *errgroup.Group
 | |
| 	ctx         context.Context
 | |
| 	execAllTask atomic.Bool
 | |
| }
 | |
| 
 | |
| // NewTaskGroup if one task return error, the rest task will stop
 | |
| func NewTaskGroup(ctx context.Context, concurrentCount int) TaskGroup {
 | |
| 	t := &taskGroup{}
 | |
| 	t.errGroup, t.ctx = errgroup.WithContext(ctx)
 | |
| 	t.errGroup.SetLimit(concurrentCount)
 | |
| 	t.execAllTask.Store(false)
 | |
| 
 | |
| 	return t
 | |
| }
 | |
| 
 | |
| // NewUninterruptibleTaskGroup if one task return error, the rest task will continue
 | |
| func NewUninterruptibleTaskGroup(ctx context.Context, concurrentCount int) TaskGroup {
 | |
| 	t := &taskGroup{}
 | |
| 	t.errGroup, t.ctx = errgroup.WithContext(ctx)
 | |
| 	t.errGroup.SetLimit(concurrentCount)
 | |
| 	t.execAllTask.Store(true)
 | |
| 
 | |
| 	return t
 | |
| }
 | |
| 
 | |
| func (t *taskGroup) Go(f func() error) {
 | |
| 	t.errGroup.Go(func() error {
 | |
| 		defer func() {
 | |
| 			if err := recover(); err != nil {
 | |
| 				logs.CtxErrorf(t.ctx, "[TaskGroup] exec panic recover:%+v", err)
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		if !t.execAllTask.Load() {
 | |
| 			select {
 | |
| 			case <-t.ctx.Done():
 | |
| 				return t.ctx.Err()
 | |
| 			default:
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		return f()
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func (t *taskGroup) Wait() error {
 | |
| 	return t.errGroup.Wait()
 | |
| }
 |