331 lines
7.9 KiB
Go
331 lines
7.9 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package dal
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
|
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
|
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
|
)
|
|
|
|
func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAuthDAO {
|
|
return &PluginOAuthAuthDAO{
|
|
idGen: idGen,
|
|
query: query.Use(db),
|
|
}
|
|
}
|
|
|
|
type pluginOAuthAuthPO model.PluginOauthAuth
|
|
|
|
func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo {
|
|
secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
|
|
if secret == "" {
|
|
secret = encrypt.DefaultOAuthTokenSecret
|
|
}
|
|
|
|
if p.RefreshToken != "" {
|
|
refreshToken, err := encrypt.DecryptByAES(p.RefreshToken, secret)
|
|
if err == nil {
|
|
p.RefreshToken = string(refreshToken)
|
|
}
|
|
}
|
|
if p.AccessToken != "" {
|
|
accessToken, err := encrypt.DecryptByAES(p.AccessToken, secret)
|
|
if err == nil {
|
|
p.AccessToken = string(accessToken)
|
|
}
|
|
}
|
|
|
|
return &entity.AuthorizationCodeInfo{
|
|
RecordID: p.ID,
|
|
Meta: &entity.AuthorizationCodeMeta{
|
|
UserID: p.UserID,
|
|
PluginID: p.PluginID,
|
|
IsDraft: p.IsDraft,
|
|
},
|
|
Config: p.OauthConfig,
|
|
AccessToken: p.AccessToken,
|
|
RefreshToken: p.RefreshToken,
|
|
TokenExpiredAtMS: p.TokenExpiredAt,
|
|
NextTokenRefreshAtMS: &p.NextTokenRefreshAt,
|
|
LastActiveAtMS: p.LastActiveAt,
|
|
}
|
|
}
|
|
|
|
type PluginOAuthAuthDAO struct {
|
|
idGen idgen.IDGenerator
|
|
query *query.Query
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) Get(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error) {
|
|
table := p.query.PluginOauthAuth
|
|
res, err := table.WithContext(ctx).
|
|
Where(
|
|
table.UserID.Eq(meta.UserID),
|
|
table.PluginID.Eq(meta.PluginID),
|
|
table.IsDraft.Is(meta.IsDraft),
|
|
).
|
|
First()
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, false, nil
|
|
}
|
|
return nil, false, err
|
|
}
|
|
|
|
info = pluginOAuthAuthPO(*res).ToDO()
|
|
|
|
return info, true, nil
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) {
|
|
if info.Meta == nil || info.Meta.UserID == "" || info.Meta.PluginID <= 0 {
|
|
return fmt.Errorf("meta info is required")
|
|
}
|
|
|
|
meta := info.Meta
|
|
secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
|
|
if secret == "" {
|
|
secret = encrypt.DefaultOAuthTokenSecret
|
|
}
|
|
|
|
var accessToken, refreshToken string
|
|
if info.AccessToken != "" {
|
|
accessToken, err = encrypt.EncryptByAES([]byte(info.AccessToken), secret)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if info.RefreshToken != "" {
|
|
refreshToken, err = encrypt.EncryptByAES([]byte(info.RefreshToken), secret)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
table := p.query.PluginOauthAuth
|
|
_, err = table.WithContext(ctx).
|
|
Select(table.ID).
|
|
Where(
|
|
table.UserID.Eq(meta.UserID),
|
|
table.PluginID.Eq(meta.PluginID),
|
|
table.IsDraft.Is(meta.IsDraft),
|
|
).First()
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
|
|
id, err := p.idGen.GenID(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
po := &model.PluginOauthAuth{
|
|
ID: id,
|
|
UserID: meta.UserID,
|
|
PluginID: meta.PluginID,
|
|
IsDraft: meta.IsDraft,
|
|
AccessToken: accessToken,
|
|
RefreshToken: refreshToken,
|
|
TokenExpiredAt: info.TokenExpiredAtMS,
|
|
NextTokenRefreshAt: info.GetNextTokenRefreshAtMS(),
|
|
OauthConfig: info.Config,
|
|
LastActiveAt: info.LastActiveAtMS,
|
|
}
|
|
|
|
return table.WithContext(ctx).Create(po)
|
|
}
|
|
|
|
updateMap := map[string]any{}
|
|
if accessToken != "" {
|
|
updateMap[table.AccessToken.ColumnName().String()] = accessToken
|
|
}
|
|
if refreshToken != "" {
|
|
updateMap[table.RefreshToken.ColumnName().String()] = refreshToken
|
|
}
|
|
if info.NextTokenRefreshAtMS != nil {
|
|
updateMap[table.NextTokenRefreshAt.ColumnName().String()] = *info.NextTokenRefreshAtMS
|
|
}
|
|
if info.TokenExpiredAtMS > 0 {
|
|
updateMap[table.TokenExpiredAt.ColumnName().String()] = info.TokenExpiredAtMS
|
|
}
|
|
if info.LastActiveAtMS > 0 {
|
|
updateMap[table.LastActiveAt.ColumnName().String()] = info.LastActiveAtMS
|
|
}
|
|
if info.Config != nil {
|
|
b, err := json.Marshal(info.Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
updateMap[table.OauthConfig.ColumnName().String()] = b
|
|
}
|
|
|
|
_, err = table.WithContext(ctx).
|
|
Where(
|
|
table.UserID.Eq(meta.UserID),
|
|
table.PluginID.Eq(meta.PluginID),
|
|
table.IsDraft.Is(meta.IsDraft),
|
|
).
|
|
Updates(updateMap)
|
|
|
|
return err
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) UpdateLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) {
|
|
po := &model.PluginOauthAuth{
|
|
LastActiveAt: lastActiveAtMs,
|
|
}
|
|
|
|
table := p.query.PluginOauthAuth
|
|
_, err = table.WithContext(ctx).
|
|
Where(
|
|
table.UserID.Eq(meta.UserID),
|
|
table.PluginID.Eq(meta.PluginID),
|
|
table.IsDraft.Is(meta.IsDraft),
|
|
).
|
|
Updates(po)
|
|
|
|
return err
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) GetRefreshTokenList(ctx context.Context, nextRefreshAt int64, limit int) (infos []*entity.AuthorizationCodeInfo, err error) {
|
|
const size = 50
|
|
table := p.query.PluginOauthAuth
|
|
|
|
infos = make([]*entity.AuthorizationCodeInfo, 0, limit)
|
|
|
|
for limit > 0 {
|
|
res, err := table.WithContext(ctx).
|
|
Where(
|
|
table.NextTokenRefreshAt.Gt(0),
|
|
table.NextTokenRefreshAt.Lt(nextRefreshAt),
|
|
).
|
|
Order(table.NextTokenRefreshAt.Asc()).
|
|
Limit(size).
|
|
Find()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
infos = make([]*entity.AuthorizationCodeInfo, 0, len(res))
|
|
for _, v := range res {
|
|
infos = append(infos, pluginOAuthAuthPO(*v).ToDO())
|
|
}
|
|
|
|
limit -= size
|
|
|
|
if len(res) < size {
|
|
break
|
|
}
|
|
}
|
|
|
|
return infos, nil
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) BatchDeleteByIDs(ctx context.Context, ids []int64) (err error) {
|
|
table := p.query.PluginOauthAuth
|
|
|
|
chunks := slices.Chunks(ids, 20)
|
|
|
|
for _, chunk := range chunks {
|
|
_, err = table.WithContext(ctx).
|
|
Where(table.ID.In(chunk...)).
|
|
Delete()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) Delete(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) {
|
|
table := p.query.PluginOauthAuth
|
|
_, err = table.WithContext(ctx).
|
|
Where(
|
|
table.UserID.Eq(meta.UserID),
|
|
table.PluginID.Eq(meta.PluginID),
|
|
table.IsDraft.Is(meta.IsDraft),
|
|
).
|
|
Delete()
|
|
return err
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) DeleteExpiredTokens(ctx context.Context, expireAt int64, limit int) (err error) {
|
|
const size = 50
|
|
table := p.query.PluginOauthAuth
|
|
|
|
for limit > 0 {
|
|
res, err := table.WithContext(ctx).
|
|
Where(
|
|
table.TokenExpiredAt.Gt(0),
|
|
table.TokenExpiredAt.Lt(expireAt),
|
|
).
|
|
Limit(size).
|
|
Delete()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
limit -= size
|
|
|
|
if res.RowsAffected < size {
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *PluginOAuthAuthDAO) DeleteInactiveTokens(ctx context.Context, lastActiveAt int64, limit int) (err error) {
|
|
const size = 50
|
|
table := p.query.PluginOauthAuth
|
|
|
|
for limit > 0 {
|
|
res, err := table.WithContext(ctx).
|
|
Where(
|
|
table.LastActiveAt.Gt(0),
|
|
table.LastActiveAt.Lt(lastActiveAt),
|
|
).
|
|
Limit(size).
|
|
Delete()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
limit -= size
|
|
|
|
if res.RowsAffected < size {
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|