/* * Copyright 2025 coze-dev Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package service import ( "context" "crypto/hmac" "crypto/rand" "crypto/sha256" "crypto/subtle" "encoding/base64" "encoding/json" "fmt" "strconv" "strings" "time" "unicode/utf8" "golang.org/x/crypto/argon2" uploadEntity "github.com/coze-dev/coze-studio/backend/domain/upload/entity" userEntity "github.com/coze-dev/coze-studio/backend/domain/user/entity" "github.com/coze-dev/coze-studio/backend/domain/user/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/user/repository" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/errno" ) type Components struct { IconOSS storage.Storage IDGen idgen.IDGenerator UserRepo repository.UserRepository SpaceRepo repository.SpaceRepository } func NewUserDomain(ctx context.Context, c *Components) User { return &userImpl{ Components: c, } } type userImpl struct { *Components } func (u *userImpl) Login(ctx context.Context, email, password string) (user *userEntity.User, err error) { userModel, exist, err := u.UserRepo.GetUsersByEmail(ctx, email) if err != nil { return nil, err } if !exist { return nil, errorx.New(errno.ErrUserInfoInvalidateCode) } // 验证密码,使用 Argon2id 算法 valid, err := verifyPassword(password, userModel.Password) if err != nil { return nil, err } if !valid { return nil, errorx.New(errno.ErrUserInfoInvalidateCode) } uniqueSessionID, err := u.IDGen.GenID(ctx) if err != nil { return nil, fmt.Errorf("failed to generate session id: %w", err) } sessionKey, err := generateSessionKey(uniqueSessionID) if err != nil { return nil, err } // 更新用户会话密钥 err = u.UserRepo.UpdateSessionKey(ctx, userModel.ID, sessionKey) if err != nil { return nil, err } userModel.SessionKey = sessionKey resURL, err := u.IconOSS.GetObjectUrl(ctx, userModel.IconURI) if err != nil { return nil, err } return userPo2Do(userModel, resURL), nil } func (u *userImpl) Logout(ctx context.Context, userID int64) (err error) { err = u.UserRepo.ClearSessionKey(ctx, userID) if err != nil { return err } return nil } func (u *userImpl) ResetPassword(ctx context.Context, email, password string) (err error) { // 使用 Argon2id 算法对密码进行哈希处理 hashedPassword, err := hashPassword(password) if err != nil { return err } err = u.UserRepo.UpdatePassword(ctx, email, hashedPassword) if err != nil { return err } return nil } func (u *userImpl) GetUserInfo(ctx context.Context, userID int64) (resp *userEntity.User, err error) { if userID <= 0 { return nil, errorx.New(errno.ErrUserInvalidParamCode, errorx.KVf("msg", "invalid user id : %d", userID)) } userModel, err := u.UserRepo.GetUserByID(ctx, userID) if err != nil { return nil, err } resURL, err := u.IconOSS.GetObjectUrl(ctx, userModel.IconURI) if err != nil { return nil, err } return userPo2Do(userModel, resURL), nil } func (u *userImpl) UpdateAvatar(ctx context.Context, userID int64, ext string, imagePayload []byte) (url string, err error) { avatarKey := "user_avatar/" + strconv.FormatInt(userID, 10) + "." + ext err = u.IconOSS.PutObject(ctx, avatarKey, imagePayload) if err != nil { return "", err } err = u.UserRepo.UpdateAvatar(ctx, userID, avatarKey) if err != nil { return "", err } url, err = u.IconOSS.GetObjectUrl(ctx, avatarKey) if err != nil { return "", err } return url, nil } func (u *userImpl) ValidateProfileUpdate(ctx context.Context, req *ValidateProfileUpdateRequest) ( resp *ValidateProfileUpdateResponse, err error, ) { if req.UniqueName == nil && req.Email == nil { return nil, errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", "missing parameter")) } if req.UniqueName != nil { uniqueName := ptr.From(req.UniqueName) charNum := utf8.RuneCountInString(uniqueName) if charNum < 4 || charNum > 20 { return &ValidateProfileUpdateResponse{ Code: UniqueNameTooShortOrTooLong, Msg: "unique name length should be between 4 and 20", }, nil } exist, err := u.UserRepo.CheckUniqueNameExist(ctx, uniqueName) if err != nil { return nil, err } if exist { return &ValidateProfileUpdateResponse{ Code: UniqueNameExist, Msg: "unique name existed", }, nil } } return &ValidateProfileUpdateResponse{ Code: ValidateSuccess, Msg: "success", }, nil } func (u *userImpl) UpdateProfile(ctx context.Context, req *UpdateProfileRequest) error { updates := map[string]interface{}{ "updated_at": time.Now().UnixMilli(), } if req.UniqueName != nil { resp, err := u.ValidateProfileUpdate(ctx, &ValidateProfileUpdateRequest{ UniqueName: req.UniqueName, }) if err != nil { return err } if resp.Code != ValidateSuccess { return errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", resp.Msg)) } updates["unique_name"] = ptr.From(req.UniqueName) } if req.Name != nil { updates["name"] = ptr.From(req.Name) } if req.Description != nil { updates["description"] = ptr.From(req.Description) } if req.Locale != nil { updates["locale"] = ptr.From(req.Locale) } err := u.UserRepo.UpdateProfile(ctx, req.UserID, updates) if err != nil { return err } return nil } func (u *userImpl) Create(ctx context.Context, req *CreateUserRequest) (user *userEntity.User, err error) { exist, err := u.UserRepo.CheckEmailExist(ctx, req.Email) if err != nil { return nil, err } if exist { return nil, errorx.New(errno.ErrUserEmailAlreadyExistCode, errorx.KV("email", req.Email)) } if req.UniqueName != "" { exist, err = u.UserRepo.CheckUniqueNameExist(ctx, req.UniqueName) if err != nil { return nil, err } if exist { return nil, errorx.New(errno.ErrUserUniqueNameAlreadyExistCode, errorx.KV("name", req.UniqueName)) } } // 使用 Argon2id 算法对密码进行哈希处理 hashedPassword, err := hashPassword(req.Password) if err != nil { return nil, err } name := req.Name if name == "" { name = strings.Split(req.Email, "@")[0] } userID, err := u.IDGen.GenID(ctx) if err != nil { return nil, fmt.Errorf("generate id error: %w", err) } now := time.Now().UnixMilli() spaceID := req.SpaceID if spaceID <= 0 { var sid int64 sid, err = u.IDGen.GenID(ctx) if err != nil { return nil, fmt.Errorf("gen space_id failed: %w", err) } err = u.SpaceRepo.CreateSpace(ctx, &model.Space{ ID: sid, Name: "Personal Space", Description: "This is your personal space", IconURI: uploadEntity.EnterpriseIconURI, OwnerID: userID, CreatorID: userID, CreatedAt: now, UpdatedAt: now, }) if err != nil { return nil, fmt.Errorf("create personal space failed: %w", err) } spaceID = sid } newUser := &model.User{ ID: userID, IconURI: uploadEntity.UserIconURI, Name: name, UniqueName: u.getUniqueNameFormEmail(ctx, req.Email), Email: req.Email, Password: hashedPassword, Description: req.Description, UserVerified: false, Locale: req.Locale, CreatedAt: now, UpdatedAt: now, } err = u.UserRepo.CreateUser(ctx, newUser) if err != nil { return nil, fmt.Errorf("insert user failed: %w", err) } err = u.SpaceRepo.AddSpaceUser(ctx, &model.SpaceUser{ SpaceID: spaceID, UserID: userID, RoleType: 1, CreatedAt: now, UpdatedAt: now, }) if err != nil { return nil, fmt.Errorf("add space user failed: %w", err) } iconURL, err := u.IconOSS.GetObjectUrl(ctx, newUser.IconURI) if err != nil { return nil, fmt.Errorf("get icon url failed: %w", err) } return userPo2Do(newUser, iconURL), nil } func (u *userImpl) getUniqueNameFormEmail(ctx context.Context, email string) string { arr := strings.Split(email, "@") if len(arr) != 2 { return email } username := arr[0] exist, err := u.UserRepo.CheckUniqueNameExist(ctx, username) if err != nil { logs.CtxWarnf(ctx, "check unique name exist failed: %v", err) return email } if exist { logs.CtxWarnf(ctx, "unique name %s already exist", username) return email } return username } func (u *userImpl) ValidateSession(ctx context.Context, sessionKey string) ( session *userEntity.Session, exist bool, err error, ) { // 验证会话密钥 sessionModel, err := verifySessionKey(sessionKey) if err != nil { return nil, false, errorx.New(errno.ErrUserAuthenticationFailed, errorx.KV("reason", "access denied")) } // 从数据库获取用户信息 userModel, exist, err := u.UserRepo.GetUserBySessionKey(ctx, sessionKey) if err != nil { return nil, false, err } if !exist { return nil, false, nil } return &userEntity.Session{ UserID: userModel.ID, Locale: userModel.Locale, CreatedAt: sessionModel.CreatedAt, ExpiresAt: sessionModel.ExpiresAt, }, true, nil } func (u *userImpl) MGetUserProfiles(ctx context.Context, userIDs []int64) (users []*userEntity.User, err error) { userModels, err := u.UserRepo.GetUsersByIDs(ctx, userIDs) if err != nil { return nil, err } users = make([]*userEntity.User, 0, len(userModels)) for _, um := range userModels { // 获取图片URL resURL, err := u.IconOSS.GetObjectUrl(ctx, um.IconURI) if err != nil { continue // 如果获取图片URL失败,跳过该用户 } users = append(users, userPo2Do(um, resURL)) } return users, nil } func (u *userImpl) GetUserProfiles(ctx context.Context, userID int64) (user *userEntity.User, err error) { userInfos, err := u.MGetUserProfiles(ctx, []int64{userID}) if err != nil { return nil, err } if len(userInfos) == 0 { return nil, errorx.New(errno.ErrUserResourceNotFound, errorx.KV("type", "user"), errorx.KV("id", conv.Int64ToStr(userID))) } return userInfos[0], nil } func (u *userImpl) GetUserSpaceList(ctx context.Context, userID int64) (spaces []*userEntity.Space, err error) { userSpaces, err := u.SpaceRepo.GetSpaceList(ctx, userID) if err != nil { return nil, err } spaceIDs := slices.Transform(userSpaces, func(us *model.SpaceUser) int64 { return us.SpaceID }) spaceModels, err := u.SpaceRepo.GetSpaceByIDs(ctx, spaceIDs) if err != nil { return nil, err } uris := slices.ToMap(spaceModels, func(sm *model.Space) (string, bool) { return sm.IconURI, false }) urls := make(map[string]string, len(uris)) for uri := range uris { url, err := u.IconOSS.GetObjectUrl(ctx, uri) if err != nil { return nil, err } urls[uri] = url } return slices.Transform(spaceModels, func(sm *model.Space) *userEntity.Space { return spacePo2Do(sm, urls[sm.IconURI]) }), nil } func spacePo2Do(space *model.Space, iconUrl string) *userEntity.Space { return &userEntity.Space{ ID: space.ID, Name: space.Name, Description: space.Description, IconURL: iconUrl, SpaceType: userEntity.SpaceTypePersonal, OwnerID: space.OwnerID, CreatorID: space.CreatorID, CreatedAt: space.CreatedAt, UpdatedAt: space.UpdatedAt, } } // Argon2id 参数 type argon2Params struct { memory uint32 iterations uint32 parallelism uint8 saltLength uint32 keyLength uint32 } // 默认的 Argon2id 参数 var defaultArgon2Params = &argon2Params{ memory: 64 * 1024, // 64MB iterations: 3, parallelism: 4, saltLength: 16, keyLength: 32, } // 使用 Argon2id 算法对密码进行哈希处理 func hashPassword(password string) (string, error) { p := defaultArgon2Params // 生成随机盐值 salt := make([]byte, p.saltLength) _, err := rand.Read(salt) if err != nil { return "", err } // 使用 Argon2id 算法计算哈希值 hash := argon2.IDKey( []byte(password), salt, p.iterations, p.memory, p.parallelism, p.keyLength, ) // 编码为 base64 格式 b64Salt := base64.RawStdEncoding.EncodeToString(salt) b64Hash := base64.RawStdEncoding.EncodeToString(hash) // 格式:$argon2id$v=19$m=65536,t=3,p=4$$ encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", p.memory, p.iterations, p.parallelism, b64Salt, b64Hash) return encoded, nil } // 验证密码是否匹配 func verifyPassword(password, encodedHash string) (bool, error) { // 解析编码后的哈希字符串 parts := strings.Split(encodedHash, "$") if len(parts) != 6 { return false, fmt.Errorf("invalid hash format") } var p argon2Params _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism) if err != nil { return false, err } salt, err := base64.RawStdEncoding.DecodeString(parts[4]) if err != nil { return false, err } p.saltLength = uint32(len(salt)) decodedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) if err != nil { return false, err } p.keyLength = uint32(len(decodedHash)) // 使用相同的参数和盐值计算哈希值 computedHash := argon2.IDKey( []byte(password), salt, p.iterations, p.memory, p.parallelism, p.keyLength, ) // 比较计算得到的哈希值与存储的哈希值 return subtle.ConstantTimeCompare(decodedHash, computedHash) == 1, nil } // Session 结构体,包含会话信息 type Session struct { ID int64 `json:"id"` // 会话唯一标识符 CreatedAt time.Time `json:"created_at"` // 创建时间 ExpiresAt time.Time `json:"expires_at"` // 过期时间 } // 用于签名的密钥(在实际应用中应从配置中读取或使用环境变量) var hmacSecret = []byte("opencoze-session-hmac-key") // 生成安全的会话密钥 func generateSessionKey(sessionID int64) (string, error) { // 创建默认会话结构(不包含用户ID,将在Login方法中设置) session := Session{ ID: sessionID, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(consts.DefaultSessionDuration), } // 序列化会话数据 sessionData, err := json.Marshal(session) if err != nil { return "", err } // 计算HMAC签名以确保完整性 h := hmac.New(sha256.New, hmacSecret) h.Write(sessionData) signature := h.Sum(nil) // 组合会话数据和签名 finalData := append(sessionData, signature...) // Base64编码最终结果 return base64.RawURLEncoding.EncodeToString(finalData), nil } // 验证会话密钥的有效性 func verifySessionKey(sessionKey string) (*Session, error) { // 解码会话数据 data, err := base64.RawURLEncoding.DecodeString(sessionKey) if err != nil { return nil, fmt.Errorf("invalid session format: %w", err) } // 确保数据长够长,至少包含会话数据和签名 if len(data) < 32 { // 简单检查,实际应该更严格 return nil, fmt.Errorf("session data too short") } // 分离会话数据和签名 sessionData := data[:len(data)-32] // 假设签名是32字节 signature := data[len(data)-32:] // 验证签名 h := hmac.New(sha256.New, hmacSecret) h.Write(sessionData) expectedSignature := h.Sum(nil) if !hmac.Equal(signature, expectedSignature) { return nil, fmt.Errorf("invalid session signature") } // 解析会话数据 var session Session if err := json.Unmarshal(sessionData, &session); err != nil { return nil, fmt.Errorf("invalid session data: %w", err) } // 检查会话是否过期 if time.Now().After(session.ExpiresAt) { return nil, fmt.Errorf("session expired") } return &session, nil } func userPo2Do(model *model.User, iconURL string) *userEntity.User { return &userEntity.User{ UserID: model.ID, Name: model.Name, UniqueName: model.UniqueName, Email: model.Email, Description: model.Description, IconURI: model.IconURI, IconURL: iconURL, UserVerified: model.UserVerified, Locale: model.Locale, SessionKey: model.SessionKey, CreatedAt: model.CreatedAt, UpdatedAt: model.UpdatedAt, } }