feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,663 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"unicode/utf8"
"golang.org/x/crypto/argon2"
uploadEntity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
userEntity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
"github.com/coze-dev/coze-studio/backend/domain/user/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/user/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Components struct {
IconOSS storage.Storage
IDGen idgen.IDGenerator
UserRepo repository.UserRepository
SpaceRepo repository.SpaceRepository
}
func NewUserDomain(ctx context.Context, c *Components) User {
return &userImpl{
Components: c,
}
}
type userImpl struct {
*Components
}
func (u *userImpl) Login(ctx context.Context, email, password string) (user *userEntity.User, err error) {
userModel, exist, err := u.UserRepo.GetUsersByEmail(ctx, email)
if err != nil {
return nil, err
}
if !exist {
return nil, errorx.New(errno.ErrUserInfoInvalidateCode)
}
// 验证密码,使用 Argon2id 算法
valid, err := verifyPassword(password, userModel.Password)
if err != nil {
return nil, err
}
if !valid {
return nil, errorx.New(errno.ErrUserInfoInvalidateCode)
}
uniqueSessionID, err := u.IDGen.GenID(ctx)
if err != nil {
return nil, fmt.Errorf("failed to generate session id: %w", err)
}
sessionKey, err := generateSessionKey(uniqueSessionID)
if err != nil {
return nil, err
}
// 更新用户会话密钥
err = u.UserRepo.UpdateSessionKey(ctx, userModel.ID, sessionKey)
if err != nil {
return nil, err
}
userModel.SessionKey = sessionKey
resURL, err := u.IconOSS.GetObjectUrl(ctx, userModel.IconURI)
if err != nil {
return nil, err
}
return userPo2Do(userModel, resURL), nil
}
func (u *userImpl) Logout(ctx context.Context, userID int64) (err error) {
err = u.UserRepo.ClearSessionKey(ctx, userID)
if err != nil {
return err
}
return nil
}
func (u *userImpl) ResetPassword(ctx context.Context, email, password string) (err error) {
// 使用 Argon2id 算法对密码进行哈希处理
hashedPassword, err := hashPassword(password)
if err != nil {
return err
}
err = u.UserRepo.UpdatePassword(ctx, email, hashedPassword)
if err != nil {
return err
}
return nil
}
func (u *userImpl) GetUserInfo(ctx context.Context, userID int64) (resp *userEntity.User, err error) {
if userID <= 0 {
return nil, errorx.New(errno.ErrUserInvalidParamCode,
errorx.KVf("msg", "invalid user id : %d", userID))
}
userModel, err := u.UserRepo.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
resURL, err := u.IconOSS.GetObjectUrl(ctx, userModel.IconURI)
if err != nil {
return nil, err
}
return userPo2Do(userModel, resURL), nil
}
func (u *userImpl) UpdateAvatar(ctx context.Context, userID int64, ext string, imagePayload []byte) (url string, err error) {
avatarKey := "user_avatar/" + strconv.FormatInt(userID, 10) + "." + ext
err = u.IconOSS.PutObject(ctx, avatarKey, imagePayload)
if err != nil {
return "", err
}
err = u.UserRepo.UpdateAvatar(ctx, userID, avatarKey)
if err != nil {
return "", err
}
url, err = u.IconOSS.GetObjectUrl(ctx, avatarKey)
if err != nil {
return "", err
}
return url, nil
}
func (u *userImpl) ValidateProfileUpdate(ctx context.Context, req *ValidateProfileUpdateRequest) (
resp *ValidateProfileUpdateResponse, err error,
) {
if req.UniqueName == nil && req.Email == nil {
return nil, errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", "missing parameter"))
}
if req.UniqueName != nil {
uniqueName := ptr.From(req.UniqueName)
charNum := utf8.RuneCountInString(uniqueName)
if charNum < 4 || charNum > 20 {
return &ValidateProfileUpdateResponse{
Code: UniqueNameTooShortOrTooLong,
Msg: "unique name length should be between 4 and 20",
}, nil
}
exist, err := u.UserRepo.CheckUniqueNameExist(ctx, uniqueName)
if err != nil {
return nil, err
}
if exist {
return &ValidateProfileUpdateResponse{
Code: UniqueNameExist,
Msg: "unique name existed",
}, nil
}
}
return &ValidateProfileUpdateResponse{
Code: ValidateSuccess,
Msg: "success",
}, nil
}
func (u *userImpl) UpdateProfile(ctx context.Context, req *UpdateProfileRequest) error {
updates := map[string]interface{}{
"updated_at": time.Now().UnixMilli(),
}
if req.UniqueName != nil {
resp, err := u.ValidateProfileUpdate(ctx, &ValidateProfileUpdateRequest{
UniqueName: req.UniqueName,
})
if err != nil {
return err
}
if resp.Code != ValidateSuccess {
return errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", resp.Msg))
}
updates["unique_name"] = ptr.From(req.UniqueName)
}
if req.Name != nil {
updates["name"] = ptr.From(req.Name)
}
if req.Description != nil {
updates["description"] = ptr.From(req.Description)
}
if req.Locale != nil {
updates["locale"] = ptr.From(req.Locale)
}
err := u.UserRepo.UpdateProfile(ctx, req.UserID, updates)
if err != nil {
return err
}
return nil
}
func (u *userImpl) Create(ctx context.Context, req *CreateUserRequest) (user *userEntity.User, err error) {
exist, err := u.UserRepo.CheckEmailExist(ctx, req.Email)
if err != nil {
return nil, err
}
if exist {
return nil, errorx.New(errno.ErrUserEmailAlreadyExistCode, errorx.KV("email", req.Email))
}
if req.UniqueName != "" {
exist, err = u.UserRepo.CheckUniqueNameExist(ctx, req.UniqueName)
if err != nil {
return nil, err
}
if exist {
return nil, errorx.New(errno.ErrUserUniqueNameAlreadyExistCode, errorx.KV("name", req.UniqueName))
}
}
// 使用 Argon2id 算法对密码进行哈希处理
hashedPassword, err := hashPassword(req.Password)
if err != nil {
return nil, err
}
name := req.Name
if name == "" {
name = strings.Split(req.Email, "@")[0]
}
userID, err := u.IDGen.GenID(ctx)
if err != nil {
return nil, fmt.Errorf("generate id error: %w", err)
}
now := time.Now().UnixMilli()
spaceID := req.SpaceID
if spaceID <= 0 {
var sid int64
sid, err = u.IDGen.GenID(ctx)
if err != nil {
return nil, fmt.Errorf("gen space_id failed: %w", err)
}
err = u.SpaceRepo.CreateSpace(ctx, &model.Space{
ID: sid,
Name: "Personal Space",
Description: "This is your personal space",
IconURI: uploadEntity.EnterpriseIconURI,
OwnerID: userID,
CreatorID: userID,
CreatedAt: now,
UpdatedAt: now,
})
if err != nil {
return nil, fmt.Errorf("create personal space failed: %w", err)
}
spaceID = sid
}
newUser := &model.User{
ID: userID,
IconURI: uploadEntity.UserIconURI,
Name: name,
UniqueName: u.getUniqueNameFormEmail(ctx, req.Email),
Email: req.Email,
Password: hashedPassword,
Description: req.Description,
UserVerified: false,
Locale: req.Locale,
CreatedAt: now,
UpdatedAt: now,
}
err = u.UserRepo.CreateUser(ctx, newUser)
if err != nil {
return nil, fmt.Errorf("insert user failed: %w", err)
}
err = u.SpaceRepo.AddSpaceUser(ctx, &model.SpaceUser{
SpaceID: spaceID,
UserID: userID,
RoleType: 1,
CreatedAt: now,
UpdatedAt: now,
})
if err != nil {
return nil, fmt.Errorf("add space user failed: %w", err)
}
iconURL, err := u.IconOSS.GetObjectUrl(ctx, newUser.IconURI)
if err != nil {
return nil, fmt.Errorf("get icon url failed: %w", err)
}
return userPo2Do(newUser, iconURL), nil
}
func (u *userImpl) getUniqueNameFormEmail(ctx context.Context, email string) string {
arr := strings.Split(email, "@")
if len(arr) != 2 {
return email
}
username := arr[0]
exist, err := u.UserRepo.CheckUniqueNameExist(ctx, username)
if err != nil {
logs.CtxWarnf(ctx, "check unique name exist failed: %v", err)
return email
}
if exist {
logs.CtxWarnf(ctx, "unique name %s already exist", username)
return email
}
return username
}
func (u *userImpl) ValidateSession(ctx context.Context, sessionKey string) (
session *userEntity.Session, exist bool, err error,
) {
// 验证会话密钥
sessionModel, err := verifySessionKey(sessionKey)
if err != nil {
return nil, false, errorx.New(errno.ErrUserAuthenticationFailed, errorx.KV("reason", "access denied"))
}
// 从数据库获取用户信息
userModel, exist, err := u.UserRepo.GetUserBySessionKey(ctx, sessionKey)
if err != nil {
return nil, false, err
}
if !exist {
return nil, false, nil
}
return &userEntity.Session{
UserID: userModel.ID,
Locale: userModel.Locale,
CreatedAt: sessionModel.CreatedAt,
ExpiresAt: sessionModel.ExpiresAt,
}, true, nil
}
func (u *userImpl) MGetUserProfiles(ctx context.Context, userIDs []int64) (users []*userEntity.User, err error) {
userModels, err := u.UserRepo.GetUsersByIDs(ctx, userIDs)
if err != nil {
return nil, err
}
users = make([]*userEntity.User, 0, len(userModels))
for _, um := range userModels {
// 获取图片URL
resURL, err := u.IconOSS.GetObjectUrl(ctx, um.IconURI)
if err != nil {
continue // 如果获取图片URL失败跳过该用户
}
users = append(users, userPo2Do(um, resURL))
}
return users, nil
}
func (u *userImpl) GetUserProfiles(ctx context.Context, userID int64) (user *userEntity.User, err error) {
userInfos, err := u.MGetUserProfiles(ctx, []int64{userID})
if err != nil {
return nil, err
}
if len(userInfos) == 0 {
return nil, errorx.New(errno.ErrUserResourceNotFound, errorx.KV("type", "user"),
errorx.KV("id", conv.Int64ToStr(userID)))
}
return userInfos[0], nil
}
func (u *userImpl) GetUserSpaceList(ctx context.Context, userID int64) (spaces []*userEntity.Space, err error) {
userSpaces, err := u.SpaceRepo.GetSpaceList(ctx, userID)
if err != nil {
return nil, err
}
spaceIDs := slices.Transform(userSpaces, func(us *model.SpaceUser) int64 {
return us.SpaceID
})
spaceModels, err := u.SpaceRepo.GetSpaceByIDs(ctx, spaceIDs)
if err != nil {
return nil, err
}
uris := slices.ToMap(spaceModels, func(sm *model.Space) (string, bool) {
return sm.IconURI, false
})
urls := make(map[string]string, len(uris))
for uri := range uris {
url, err := u.IconOSS.GetObjectUrl(ctx, uri)
if err != nil {
return nil, err
}
urls[uri] = url
}
return slices.Transform(spaceModels, func(sm *model.Space) *userEntity.Space {
return spacePo2Do(sm, urls[sm.IconURI])
}), nil
}
func spacePo2Do(space *model.Space, iconUrl string) *userEntity.Space {
return &userEntity.Space{
ID: space.ID,
Name: space.Name,
Description: space.Description,
IconURL: iconUrl,
SpaceType: userEntity.SpaceTypePersonal,
OwnerID: space.OwnerID,
CreatorID: space.CreatorID,
CreatedAt: space.CreatedAt,
UpdatedAt: space.UpdatedAt,
}
}
// Argon2id 参数
type argon2Params struct {
memory uint32
iterations uint32
parallelism uint8
saltLength uint32
keyLength uint32
}
// 默认的 Argon2id 参数
var defaultArgon2Params = &argon2Params{
memory: 64 * 1024, // 64MB
iterations: 3,
parallelism: 4,
saltLength: 16,
keyLength: 32,
}
// 使用 Argon2id 算法对密码进行哈希处理
func hashPassword(password string) (string, error) {
p := defaultArgon2Params
// 生成随机盐值
salt := make([]byte, p.saltLength)
_, err := rand.Read(salt)
if err != nil {
return "", err
}
// 使用 Argon2id 算法计算哈希值
hash := argon2.IDKey(
[]byte(password),
salt,
p.iterations,
p.memory,
p.parallelism,
p.keyLength,
)
// 编码为 base64 格式
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
// 格式:$argon2id$v=19$m=65536,t=3,p=4$<salt>$<hash>
encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
p.memory, p.iterations, p.parallelism, b64Salt, b64Hash)
return encoded, nil
}
// 验证密码是否匹配
func verifyPassword(password, encodedHash string) (bool, error) {
// 解析编码后的哈希字符串
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
var p argon2Params
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
if err != nil {
return false, err
}
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false, err
}
p.saltLength = uint32(len(salt))
decodedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false, err
}
p.keyLength = uint32(len(decodedHash))
// 使用相同的参数和盐值计算哈希值
computedHash := argon2.IDKey(
[]byte(password),
salt,
p.iterations,
p.memory,
p.parallelism,
p.keyLength,
)
// 比较计算得到的哈希值与存储的哈希值
return subtle.ConstantTimeCompare(decodedHash, computedHash) == 1, nil
}
// Session 结构体,包含会话信息
type Session struct {
ID int64 `json:"id"` // 会话唯一标识符
CreatedAt time.Time `json:"created_at"` // 创建时间
ExpiresAt time.Time `json:"expires_at"` // 过期时间
}
// 用于签名的密钥(在实际应用中应从配置中读取或使用环境变量)
var hmacSecret = []byte("opencoze-session-hmac-key")
// 生成安全的会话密钥
func generateSessionKey(sessionID int64) (string, error) {
// 创建默认会话结构不包含用户ID将在Login方法中设置
session := Session{
ID: sessionID,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(consts.DefaultSessionDuration),
}
// 序列化会话数据
sessionData, err := json.Marshal(session)
if err != nil {
return "", err
}
// 计算HMAC签名以确保完整性
h := hmac.New(sha256.New, hmacSecret)
h.Write(sessionData)
signature := h.Sum(nil)
// 组合会话数据和签名
finalData := append(sessionData, signature...)
// Base64编码最终结果
return base64.RawURLEncoding.EncodeToString(finalData), nil
}
// 验证会话密钥的有效性
func verifySessionKey(sessionKey string) (*Session, error) {
// 解码会话数据
data, err := base64.RawURLEncoding.DecodeString(sessionKey)
if err != nil {
return nil, fmt.Errorf("invalid session format: %w", err)
}
// 确保数据长够长,至少包含会话数据和签名
if len(data) < 32 { // 简单检查,实际应该更严格
return nil, fmt.Errorf("session data too short")
}
// 分离会话数据和签名
sessionData := data[:len(data)-32] // 假设签名是32字节
signature := data[len(data)-32:]
// 验证签名
h := hmac.New(sha256.New, hmacSecret)
h.Write(sessionData)
expectedSignature := h.Sum(nil)
if !hmac.Equal(signature, expectedSignature) {
return nil, fmt.Errorf("invalid session signature")
}
// 解析会话数据
var session Session
if err := json.Unmarshal(sessionData, &session); err != nil {
return nil, fmt.Errorf("invalid session data: %w", err)
}
// 检查会话是否过期
if time.Now().After(session.ExpiresAt) {
return nil, fmt.Errorf("session expired")
}
return &session, nil
}
func userPo2Do(model *model.User, iconURL string) *userEntity.User {
return &userEntity.User{
UserID: model.ID,
Name: model.Name,
UniqueName: model.UniqueName,
Email: model.Email,
Description: model.Description,
IconURI: model.IconURI,
IconURL: iconURL,
UserVerified: model.UserVerified,
Locale: model.Locale,
SessionKey: model.SessionKey,
CreatedAt: model.CreatedAt,
UpdatedAt: model.UpdatedAt,
}
}