diff --git a/backend/application/user/user.go b/backend/application/user/user.go index 002eb7a7..a05678bf 100644 --- a/backend/application/user/user.go +++ b/backend/application/user/user.go @@ -19,7 +19,10 @@ package user import ( "context" "net/mail" + "os" + "slices" "strconv" + "strings" "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api" "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground" @@ -30,7 +33,8 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/storage" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" - "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" + langSlices "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" + "github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -56,6 +60,11 @@ func (u *UserApplicationService) PassportWebEmailRegisterV2(ctx context.Context, return nil, "", errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", "Invalid email")) } + // Allow Register Checker + if !u.allowRegisterChecker(req.GetEmail()) { + return nil, "", errorx.New(errno.ErrNotAllowedRegisterCode) + } + userInfo, err := u.DomainSVC.Create(ctx, &user.CreateUserRequest{ Email: req.GetEmail(), Password: req.GetPassword(), @@ -77,6 +86,20 @@ func (u *UserApplicationService) PassportWebEmailRegisterV2(ctx context.Context, }, userInfo.SessionKey, nil } +func (u *UserApplicationService) allowRegisterChecker(email string) bool { + disableUserRegistration := os.Getenv(consts.DisableUserRegistration) + if strings.ToLower(disableUserRegistration) != "true" { + return true + } + + allowedEmails := os.Getenv(consts.AllowRegistrationEmail) + if allowedEmails == "" { + return false + } + + return slices.Contains(strings.Split(allowedEmails, ","), strings.ToLower(email)) +} + // PassportWebLogoutGet 处理用户登出请求 func (u *UserApplicationService) PassportWebLogoutGet(ctx context.Context, req *passport.PassportWebLogoutGetRequest) ( resp *passport.PassportWebLogoutGetResponse, err error, @@ -204,7 +227,7 @@ func (u *UserApplicationService) GetSpaceListV2(ctx context.Context, req *playgr return nil, err } - botSpaces := slices.Transform(spaces, func(space *entity.Space) *playground.BotSpaceV2 { + botSpaces := langSlices.Transform(spaces, func(space *entity.Space) *playground.BotSpaceV2 { return &playground.BotSpaceV2{ ID: space.ID, Name: space.Name, @@ -230,7 +253,7 @@ func (u *UserApplicationService) GetSpaceListV2(ctx context.Context, req *playgr func (u *UserApplicationService) MGetUserBasicInfo(ctx context.Context, req *playground.MGetUserBasicInfoRequest) ( resp *playground.MGetUserBasicInfoResponse, err error, ) { - userIDs, err := slices.TransformWithErrorCheck(req.GetUserIds(), func(s string) (int64, error) { + userIDs, err := langSlices.TransformWithErrorCheck(req.GetUserIds(), func(s string) (int64, error) { return strconv.ParseInt(s, 10, 64) }) if err != nil { @@ -243,7 +266,7 @@ func (u *UserApplicationService) MGetUserBasicInfo(ctx context.Context, req *pla } return &playground.MGetUserBasicInfoResponse{ - UserBasicInfoMap: slices.ToMap(userInfos, func(userInfo *entity.User) (string, *playground.UserBasicInfo) { + UserBasicInfoMap: langSlices.ToMap(userInfos, func(userInfo *entity.User) (string, *playground.UserBasicInfo) { return strconv.FormatInt(userInfo.UserID, 10), userDo2PlaygroundTo(userInfo) }), Code: 0, diff --git a/backend/types/consts/consts.go b/backend/types/consts/consts.go index b4a6597d..f1fb0403 100644 --- a/backend/types/consts/consts.go +++ b/backend/types/consts/consts.go @@ -112,3 +112,8 @@ const ( ApplyUploadActionURI = "/api/common/upload/apply_upload_action" UploadURI = "/api/common/upload" ) + +const ( + DisableUserRegistration = "DISABLE_USER_REGISTRATION" + AllowRegistrationEmail = "ALLOW_REGISTRATION_EMAIL" +) diff --git a/backend/types/errno/user.go b/backend/types/errno/user.go index 4b24d59b..61e96928 100644 --- a/backend/types/errno/user.go +++ b/backend/types/errno/user.go @@ -31,9 +31,17 @@ const ( ErrUserResourceNotFound = 700000005 ErrUserInvalidParamCode = 700000006 ErrUserPermissionCode = 700000007 + ErrNotAllowedRegisterCode = 700000008 ) func init() { + + code.Register( + ErrNotAllowedRegisterCode, + "The user registration has been disabled by the administrator. Please contact the administrator!", + code.WithAffectStability(false), + ) + code.Register( ErrUserPermissionCode, "unauthorized access : {msg}", diff --git a/docker/.env.example b/docker/.env.example index ffa7599c..a82a892a 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -198,3 +198,9 @@ export CODE_RUNNER_NODE_MODULES_DIR="" export CODE_RUNNER_TIMEOUT_SECONDS="" # Code execution memory limit, default 100MB. e.g. "256" export CODE_RUNNER_MEMORY_LIMIT_MB="" + +# The function of registration controller +# If you want to disable the registration feature, set DISABLE_USER_REGISTRATION to true. You can then control allowed registrations via a whitelist with ALLOW_REGISTRATION_EMAIL. +export DISABLE_USER_REGISTRATION="" # default "", if you want to disable, set to true +export ALLOW_REGISTRATION_EMAIL="" # is a list of email addresses, separated by ",". Example: "11@example.com,22@example.com" +