coze-studio/backend/domain/user/service/user_impl.go

664 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"unicode/utf8"
"golang.org/x/crypto/argon2"
uploadEntity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
userEntity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
"github.com/coze-dev/coze-studio/backend/domain/user/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/user/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// Components bundles the external dependencies used by the user domain service.
type Components struct {
	// IconOSS stores uploaded avatars and resolves icon URIs to URLs.
	IconOSS storage.Storage
	// IDGen issues IDs for users, personal spaces and login sessions.
	IDGen idgen.IDGenerator
	// UserRepo reads and writes user records.
	UserRepo repository.UserRepository
	// SpaceRepo reads and writes spaces and space memberships.
	SpaceRepo repository.SpaceRepository
}
// NewUserDomain returns the User domain service backed by the given
// components. The ctx argument is not used during construction.
func NewUserDomain(ctx context.Context, c *Components) User {
	svc := &userImpl{Components: c}
	return svc
}
// userImpl is the implementation returned by NewUserDomain. It reaches its
// repositories, ID generator and object storage through the embedded
// Components.
type userImpl struct {
	*Components
}
// Login authenticates a user by email and password.
//
// On success it issues a fresh session key, stores it on the user record and
// returns the user (including the new session key) with its avatar URL.
// An unknown email and a wrong password both yield ErrUserInfoInvalidateCode,
// so callers cannot tell which one failed.
func (u *userImpl) Login(ctx context.Context, email, password string) (user *userEntity.User, err error) {
	record, found, err := u.UserRepo.GetUsersByEmail(ctx, email)
	if err != nil {
		return nil, err
	}
	if !found {
		return nil, errorx.New(errno.ErrUserInfoInvalidateCode)
	}

	// The stored password is an Argon2id hash.
	matched, err := verifyPassword(password, record.Password)
	switch {
	case err != nil:
		return nil, err
	case !matched:
		return nil, errorx.New(errno.ErrUserInfoInvalidateCode)
	}

	sessionID, err := u.IDGen.GenID(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to generate session id: %w", err)
	}
	key, err := generateSessionKey(sessionID)
	if err != nil {
		return nil, err
	}

	// Persist the key: ValidateSession resolves the user by this exact key.
	if err = u.UserRepo.UpdateSessionKey(ctx, record.ID, key); err != nil {
		return nil, err
	}
	record.SessionKey = key

	avatarURL, err := u.IconOSS.GetObjectUrl(ctx, record.IconURI)
	if err != nil {
		return nil, err
	}
	return userPo2Do(record, avatarURL), nil
}
// Logout ends the user's current session by clearing the stored session key.
func (u *userImpl) Logout(ctx context.Context, userID int64) (err error) {
	return u.UserRepo.ClearSessionKey(ctx, userID)
}
// ResetPassword sets a new password for the account with the given email.
// Only an Argon2id hash of password is stored.
func (u *userImpl) ResetPassword(ctx context.Context, email, password string) (err error) {
	hashed, hashErr := hashPassword(password)
	if hashErr != nil {
		return hashErr
	}
	return u.UserRepo.UpdatePassword(ctx, email, hashed)
}
// GetUserInfo loads one user by ID and resolves its avatar URL.
// A non-positive userID is rejected with ErrUserInvalidParamCode.
func (u *userImpl) GetUserInfo(ctx context.Context, userID int64) (resp *userEntity.User, err error) {
	if userID <= 0 {
		detail := errorx.KVf("msg", "invalid user id : %d", userID)
		return nil, errorx.New(errno.ErrUserInvalidParamCode, detail)
	}

	record, err := u.UserRepo.GetUserByID(ctx, userID)
	if err != nil {
		return nil, err
	}

	iconURL, err := u.IconOSS.GetObjectUrl(ctx, record.IconURI)
	if err != nil {
		return nil, err
	}
	return userPo2Do(record, iconURL), nil
}
// UpdateAvatar uploads imagePayload under the object key
// "user_avatar/<userID>.<ext>", records that key as the user's icon URI and
// returns the URL resolved for it.
//
// NOTE(review): ext is used verbatim in the object key; confirm callers
// restrict it to a known set of image extensions.
func (u *userImpl) UpdateAvatar(ctx context.Context, userID int64, ext string, imagePayload []byte) (url string, err error) {
	objectKey := "user_avatar/" + strconv.FormatInt(userID, 10) + "." + ext

	if err = u.IconOSS.PutObject(ctx, objectKey, imagePayload); err != nil {
		return "", err
	}
	if err = u.UserRepo.UpdateAvatar(ctx, userID, objectKey); err != nil {
		return "", err
	}

	avatarURL, err := u.IconOSS.GetObjectUrl(ctx, objectKey)
	if err != nil {
		return "", err
	}
	return avatarURL, nil
}
// ValidateProfileUpdate checks a prospective profile change without applying it.
//
// At least one of req.UniqueName or req.Email must be set; otherwise an
// ErrUserInvalidParamCode error is returned. When UniqueName is set it must be
// 4 to 20 characters long (counted in runes) and not already taken. Failed
// checks are reported through resp.Code and resp.Msg, not as an error.
//
// NOTE(review): req.Email only takes part in the presence check above; no
// email validation is performed here.
func (u *userImpl) ValidateProfileUpdate(ctx context.Context, req *ValidateProfileUpdateRequest) (
	resp *ValidateProfileUpdateResponse, err error,
) {
	if req.UniqueName == nil {
		if req.Email == nil {
			return nil, errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", "missing parameter"))
		}
		return &ValidateProfileUpdateResponse{Code: ValidateSuccess, Msg: "success"}, nil
	}

	candidate := ptr.From(req.UniqueName)
	if runes := utf8.RuneCountInString(candidate); runes < 4 || runes > 20 {
		return &ValidateProfileUpdateResponse{
			Code: UniqueNameTooShortOrTooLong,
			Msg:  "unique name length should be between 4 and 20",
		}, nil
	}

	taken, err := u.UserRepo.CheckUniqueNameExist(ctx, candidate)
	if err != nil {
		return nil, err
	}
	if taken {
		return &ValidateProfileUpdateResponse{Code: UniqueNameExist, Msg: "unique name existed"}, nil
	}
	return &ValidateProfileUpdateResponse{Code: ValidateSuccess, Msg: "success"}, nil
}
// UpdateProfile writes the non-nil fields of req to the user's profile and
// always refreshes updated_at (Unix milliseconds).
//
// A new unique name is first checked with ValidateProfileUpdate. If that check
// does not succeed, an ErrUserInvalidParamCode error carrying the validation
// message is returned and nothing is written.
func (u *userImpl) UpdateProfile(ctx context.Context, req *UpdateProfileRequest) error {
	changes := map[string]any{"updated_at": time.Now().UnixMilli()}

	if req.UniqueName != nil {
		verdict, err := u.ValidateProfileUpdate(ctx, &ValidateProfileUpdateRequest{UniqueName: req.UniqueName})
		if err != nil {
			return err
		}
		if verdict.Code != ValidateSuccess {
			return errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", verdict.Msg))
		}
		changes["unique_name"] = ptr.From(req.UniqueName)
	}
	if req.Name != nil {
		changes["name"] = ptr.From(req.Name)
	}
	if req.Description != nil {
		changes["description"] = ptr.From(req.Description)
	}
	if req.Locale != nil {
		changes["locale"] = ptr.From(req.Locale)
	}

	return u.UserRepo.UpdateProfile(ctx, req.UserID, changes)
}
// Create registers a new user account and returns it with its avatar URL.
//
// It runs these steps in order:
//   - reject an email that is already registered (ErrUserEmailAlreadyExistCode);
//   - when req.UniqueName is set, reject it if already taken
//     (ErrUserUniqueNameAlreadyExistCode);
//   - hash the password with Argon2id;
//   - when req.SpaceID is not positive, create a personal space owned by the
//     new user; otherwise use req.SpaceID;
//   - insert the user and add it to that space with RoleType 1.
//
// Defaults: the display name falls back to the local part of the email. The
// unique name is req.UniqueName when provided; otherwise it is derived from
// the email by getUniqueNameFormEmail.
//
// NOTE(review): the space, user and membership writes are not wrapped in a
// transaction. A failure part-way can leave an orphaned personal space or a
// user without a space membership.
func (u *userImpl) Create(ctx context.Context, req *CreateUserRequest) (user *userEntity.User, err error) {
	exist, err := u.UserRepo.CheckEmailExist(ctx, req.Email)
	if err != nil {
		return nil, err
	}
	if exist {
		return nil, errorx.New(errno.ErrUserEmailAlreadyExistCode, errorx.KV("email", req.Email))
	}

	if req.UniqueName != "" {
		exist, err = u.UserRepo.CheckUniqueNameExist(ctx, req.UniqueName)
		if err != nil {
			return nil, err
		}
		if exist {
			return nil, errorx.New(errno.ErrUserUniqueNameAlreadyExistCode, errorx.KV("name", req.UniqueName))
		}
	}

	// Passwords are only ever persisted as Argon2id hashes.
	hashedPassword, err := hashPassword(req.Password)
	if err != nil {
		return nil, err
	}

	name := req.Name
	if name == "" {
		name = strings.Split(req.Email, "@")[0]
	}

	userID, err := u.IDGen.GenID(ctx)
	if err != nil {
		return nil, fmt.Errorf("generate id error: %w", err)
	}

	now := time.Now().UnixMilli()
	spaceID := req.SpaceID
	if spaceID <= 0 {
		var sid int64
		sid, err = u.IDGen.GenID(ctx)
		if err != nil {
			return nil, fmt.Errorf("gen space_id failed: %w", err)
		}
		err = u.SpaceRepo.CreateSpace(ctx, &model.Space{
			ID:          sid,
			Name:        "Personal Space",
			Description: "This is your personal space",
			IconURI:     uploadEntity.EnterpriseIconURI,
			OwnerID:     userID,
			CreatorID:   userID,
			CreatedAt:   now,
			UpdatedAt:   now,
		})
		if err != nil {
			return nil, fmt.Errorf("create personal space failed: %w", err)
		}
		spaceID = sid
	}

	// Use the caller-supplied unique name, whose availability was checked
	// above. Derive one from the email only when none was given.
	uniqueName := req.UniqueName
	if uniqueName == "" {
		uniqueName = u.getUniqueNameFormEmail(ctx, req.Email)
	}

	newUser := &model.User{
		ID:           userID,
		IconURI:      uploadEntity.UserIconURI,
		Name:         name,
		UniqueName:   uniqueName,
		Email:        req.Email,
		Password:     hashedPassword,
		Description:  req.Description,
		UserVerified: false,
		Locale:       req.Locale,
		CreatedAt:    now,
		UpdatedAt:    now,
	}
	err = u.UserRepo.CreateUser(ctx, newUser)
	if err != nil {
		return nil, fmt.Errorf("insert user failed: %w", err)
	}

	err = u.SpaceRepo.AddSpaceUser(ctx, &model.SpaceUser{
		SpaceID:   spaceID,
		UserID:    userID,
		RoleType:  1, // NOTE(review): magic role value, presumably the owner role; confirm.
		CreatedAt: now,
		UpdatedAt: now,
	})
	if err != nil {
		return nil, fmt.Errorf("add space user failed: %w", err)
	}

	iconURL, err := u.IconOSS.GetObjectUrl(ctx, newUser.IconURI)
	if err != nil {
		return nil, fmt.Errorf("get icon url failed: %w", err)
	}
	return userPo2Do(newUser, iconURL), nil
}
// getUniqueNameFormEmail derives a default unique name from an email address.
//
// It returns the local part (the text before "@") when all of these hold:
// the email contains exactly one "@", the local part is non-empty, and no
// user already has that unique name. In every other case it falls back to the
// full email address, including when the existence check itself fails.
func (u *userImpl) getUniqueNameFormEmail(ctx context.Context, email string) string {
	parts := strings.Split(email, "@")
	// Require exactly one "@" and a non-empty local part. An input such as
	// "@example.com" could otherwise produce an empty unique name.
	if len(parts) != 2 || parts[0] == "" {
		return email
	}
	username := parts[0]

	exist, err := u.UserRepo.CheckUniqueNameExist(ctx, username)
	if err != nil {
		// Best-effort: fall back to the full email rather than failing.
		logs.CtxWarnf(ctx, "check unique name exist failed: %v", err)
		return email
	}
	if exist {
		logs.CtxWarnf(ctx, "unique name %s already exist", username)
		return email
	}
	return username
}
// ValidateSession checks a client-presented session key.
//
// The key's signature and expiry are verified first. A key failing that check
// yields ErrUserAuthenticationFailed. A valid key that is not currently stored
// on any user returns exist=false with a nil error. Otherwise the owning
// user's ID and locale are returned with the key's creation and expiry times.
func (u *userImpl) ValidateSession(ctx context.Context, sessionKey string) (
	session *userEntity.Session, exist bool, err error,
) {
	claims, verifyErr := verifySessionKey(sessionKey)
	if verifyErr != nil {
		return nil, false, errorx.New(errno.ErrUserAuthenticationFailed, errorx.KV("reason", "access denied"))
	}

	owner, found, err := u.UserRepo.GetUserBySessionKey(ctx, sessionKey)
	if err != nil || !found {
		// err is nil when the key simply is not on record.
		return nil, false, err
	}

	return &userEntity.Session{
		UserID:    owner.ID,
		Locale:    owner.Locale,
		CreatedAt: claims.CreatedAt,
		ExpiresAt: claims.ExpiresAt,
	}, true, nil
}
// MGetUserProfiles loads the users with the given IDs and resolves each
// avatar URL.
//
// The call is best-effort per user. A user whose avatar URL cannot be resolved
// is left out of the result, with a warning logged, instead of failing the
// whole batch. The result may therefore be shorter than userIDs.
func (u *userImpl) MGetUserProfiles(ctx context.Context, userIDs []int64) (users []*userEntity.User, err error) {
	userModels, err := u.UserRepo.GetUsersByIDs(ctx, userIDs)
	if err != nil {
		return nil, err
	}

	users = make([]*userEntity.User, 0, len(userModels))
	for _, um := range userModels {
		resURL, urlErr := u.IconOSS.GetObjectUrl(ctx, um.IconURI)
		if urlErr != nil {
			// Skipping is deliberate; log it so the missing user is traceable
			// instead of vanishing silently.
			logs.CtxWarnf(ctx, "get icon url failed, skip user %d: %v", um.ID, urlErr)
			continue
		}
		users = append(users, userPo2Do(um, resURL))
	}
	return users, nil
}
// GetUserProfiles returns the profile of one user by delegating to
// MGetUserProfiles. If that call yields no profile, it returns an
// ErrUserResourceNotFound error.
func (u *userImpl) GetUserProfiles(ctx context.Context, userID int64) (user *userEntity.User, err error) {
	profiles, err := u.MGetUserProfiles(ctx, []int64{userID})
	switch {
	case err != nil:
		return nil, err
	case len(profiles) == 0:
		return nil, errorx.New(errno.ErrUserResourceNotFound,
			errorx.KV("type", "user"),
			errorx.KV("id", conv.Int64ToStr(userID)))
	default:
		return profiles[0], nil
	}
}
// GetUserSpaceList returns every space the user is a member of, with each
// space's icon URI resolved to a URL.
//
// Each distinct icon URI is resolved only once. Any resolution failure aborts
// the call with that error.
func (u *userImpl) GetUserSpaceList(ctx context.Context, userID int64) (spaces []*userEntity.Space, err error) {
	memberships, err := u.SpaceRepo.GetSpaceList(ctx, userID)
	if err != nil {
		return nil, err
	}
	ids := slices.Transform(memberships, func(m *model.SpaceUser) int64 { return m.SpaceID })

	spaceModels, err := u.SpaceRepo.GetSpaceByIDs(ctx, ids)
	if err != nil {
		return nil, err
	}

	// Spaces often share an icon, so resolve each distinct URI once.
	iconURLs := make(map[string]string, len(spaceModels))
	for _, sm := range spaceModels {
		if _, done := iconURLs[sm.IconURI]; done {
			continue
		}
		iconURL, urlErr := u.IconOSS.GetObjectUrl(ctx, sm.IconURI)
		if urlErr != nil {
			return nil, urlErr
		}
		iconURLs[sm.IconURI] = iconURL
	}

	return slices.Transform(spaceModels, func(sm *model.Space) *userEntity.Space {
		return spacePo2Do(sm, iconURLs[sm.IconURI])
	}), nil
}
// spacePo2Do converts a persisted space record into the domain Space entity,
// attaching an already-resolved icon URL.
//
// NOTE(review): SpaceType is always SpaceTypePersonal, regardless of the
// record; confirm this is intended for spaces joined via a non-personal ID.
func spacePo2Do(space *model.Space, iconURL string) *userEntity.Space {
	return &userEntity.Space{
		ID:          space.ID,
		Name:        space.Name,
		Description: space.Description,
		IconURL:     iconURL,
		SpaceType:   userEntity.SpaceTypePersonal,
		OwnerID:     space.OwnerID,
		CreatorID:   space.CreatorID,
		CreatedAt:   space.CreatedAt,
		UpdatedAt:   space.UpdatedAt,
	}
}
// argon2Params holds the Argon2id cost parameters and output sizes used for
// password hashing.
type argon2Params struct {
	memory      uint32 // memory cost in KiB
	iterations  uint32 // number of passes over memory (time cost)
	parallelism uint8  // number of lanes
	saltLength  uint32 // salt size in bytes
	keyLength   uint32 // derived hash size in bytes
}

// defaultArgon2Params are the parameters used when hashing new passwords:
// 64 MiB of memory, 3 passes, 4 lanes, a 16-byte salt and a 32-byte hash.
var defaultArgon2Params = &argon2Params{
	keyLength:   32,
	saltLength:  16,
	parallelism: 4,
	iterations:  3,
	memory:      64 << 10, // 64 MiB expressed in KiB
}
// hashPassword hashes password with Argon2id, using defaultArgon2Params and a
// fresh random salt from crypto/rand.
//
// The result has the textual form parsed by verifyPassword:
//
//	$argon2id$v=19$m=<memory>,t=<iterations>,p=<parallelism>$<salt>$<hash>
//
// where salt and hash are unpadded standard base64.
func hashPassword(password string) (string, error) {
	cfg := defaultArgon2Params

	salt := make([]byte, cfg.saltLength)
	if _, err := rand.Read(salt); err != nil {
		return "", err
	}

	digest := argon2.IDKey([]byte(password), salt, cfg.iterations, cfg.memory, cfg.parallelism, cfg.keyLength)

	enc := base64.RawStdEncoding
	return fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
		cfg.memory, cfg.iterations, cfg.parallelism,
		enc.EncodeToString(salt), enc.EncodeToString(digest)), nil
}
// 验证密码是否匹配
func verifyPassword(password, encodedHash string) (bool, error) {
// 解析编码后的哈希字符串
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
var p argon2Params
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
if err != nil {
return false, err
}
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false, err
}
p.saltLength = uint32(len(salt))
decodedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false, err
}
p.keyLength = uint32(len(decodedHash))
// 使用相同的参数和盐值计算哈希值
computedHash := argon2.IDKey(
[]byte(password),
salt,
p.iterations,
p.memory,
p.parallelism,
p.keyLength,
)
// 比较计算得到的哈希值与存储的哈希值
return subtle.ConstantTimeCompare(decodedHash, computedHash) == 1, nil
}
// Session is the payload carried inside a session key. generateSessionKey
// JSON-encodes and HMAC-signs it, and verifySessionKey decodes it. It holds no
// user ID; the owning user is found by looking up the whole key
// (see ValidateSession).
type Session struct {
	ID        int64     `json:"id"`         // unique session identifier
	CreatedAt time.Time `json:"created_at"` // time the session was issued
	ExpiresAt time.Time `json:"expires_at"` // time after which verifySessionKey rejects it
}
// hmacSecret is the key used to sign and verify session keys (in a real
// deployment it should be read from configuration or an environment variable).
//
// NOTE(review): this is a constant in the source, so anyone with the source can
// compute valid signatures. A forged key must still match a key stored on a
// user (ValidateSession looks it up), but this value should be a deployment
// secret. Changing it invalidates every outstanding session key.
var hmacSecret = []byte("opencoze-session-hmac-key")
// generateSessionKey builds a signed, URL-safe session key for sessionID.
//
// The key is the unpadded base64url encoding of the JSON-encoded Session,
// followed by its HMAC-SHA256 signature under hmacSecret. The session expires
// consts.DefaultSessionDuration after creation. The key does not embed a user
// ID; callers bind it to a user by storing it on the user record (see Login).
func generateSessionKey(sessionID int64) (string, error) {
	// Take one timestamp so ExpiresAt - CreatedAt is exactly the configured
	// session duration.
	now := time.Now()
	session := Session{
		ID:        sessionID,
		CreatedAt: now,
		ExpiresAt: now.Add(consts.DefaultSessionDuration),
	}

	payload, err := json.Marshal(session)
	if err != nil {
		return "", err
	}

	// Sign the payload so tampering is detected in verifySessionKey.
	mac := hmac.New(sha256.New, hmacSecret)
	mac.Write(payload) // hash.Hash.Write never returns an error
	signature := mac.Sum(nil)

	// Layout: payload || signature. verifySessionKey splits off the trailing
	// SHA-256-sized signature.
	token := make([]byte, 0, len(payload)+len(signature))
	token = append(token, payload...)
	token = append(token, signature...)

	return base64.RawURLEncoding.EncodeToString(token), nil
}
// verifySessionKey decodes and authenticates a key produced by
// generateSessionKey and returns its Session payload.
//
// It returns an error in any of these cases: the key is not unpadded base64url;
// it is too short to hold a payload plus an HMAC-SHA256 signature; the
// signature does not match hmacSecret; the payload does not decode to a
// Session; or the session has expired.
func verifySessionKey(sessionKey string) (*Session, error) {
	data, err := base64.RawURLEncoding.DecodeString(sessionKey)
	if err != nil {
		return nil, fmt.Errorf("invalid session format: %w", err)
	}

	// The key is payload || signature, where the signature is exactly one
	// SHA-256 digest. Require at least one payload byte in front of it.
	const sigLen = sha256.Size
	if len(data) <= sigLen {
		return nil, fmt.Errorf("session data too short")
	}
	payload, signature := data[:len(data)-sigLen], data[len(data)-sigLen:]

	// hmac.Equal compares in constant time.
	mac := hmac.New(sha256.New, hmacSecret)
	mac.Write(payload)
	if !hmac.Equal(signature, mac.Sum(nil)) {
		return nil, fmt.Errorf("invalid session signature")
	}

	var session Session
	if err := json.Unmarshal(payload, &session); err != nil {
		return nil, fmt.Errorf("invalid session data: %w", err)
	}

	// A missing expires_at decodes to the zero time and is treated as expired.
	if time.Now().After(session.ExpiresAt) {
		return nil, fmt.Errorf("session expired")
	}
	return &session, nil
}
// userPo2Do converts a persisted user record into the domain User entity,
// attaching an already-resolved avatar URL. The password hash is not copied;
// the session key is.
//
// The parameter is named po (not model) so it does not shadow the imported
// model package.
func userPo2Do(po *model.User, iconURL string) *userEntity.User {
	return &userEntity.User{
		UserID:       po.ID,
		Name:         po.Name,
		UniqueName:   po.UniqueName,
		Email:        po.Email,
		Description:  po.Description,
		IconURI:      po.IconURI,
		IconURL:      iconURL,
		UserVerified: po.UserVerified,
		Locale:       po.Locale,
		SessionKey:   po.SessionKey,
		CreatedAt:    po.CreatedAt,
		UpdatedAt:    po.UpdatedAt,
	}
}