coze-studio/backend/domain/plugin/internal/dal/plugin_oauth_auth.go

331 lines
7.9 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dal
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
// NewPluginOAuthAuthDAO builds a PluginOAuthAuthDAO that issues its queries
// through db and obtains new record IDs from idGen.
func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAuthDAO {
	dao := new(PluginOAuthAuthDAO)
	dao.query = query.Use(db)
	dao.idGen = idGen
	return dao
}
// pluginOAuthAuthPO is a package-local named type over model.PluginOauthAuth,
// declared so that conversion methods (see ToDO) can be attached to the
// persistence model.
type pluginOAuthAuthPO model.PluginOauthAuth
// ToDO converts the persistence object into its domain representation.
//
// The stored access and refresh tokens are decrypted with the secret read
// from the utils.OAuthTokenSecretEnv environment variable, falling back to
// utils.DefaultOAuthTokenSecret when that variable is empty. Decryption is
// best-effort: an empty token is left empty, and a token that fails to
// decrypt is passed through exactly as stored.
func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo {
	secret := utils.DefaultOAuthTokenSecret
	if fromEnv := os.Getenv(utils.OAuthTokenSecretEnv); fromEnv != "" {
		secret = fromEnv
	}

	// reveal returns the decrypted form of stored, or stored itself when it is
	// empty or cannot be decrypted.
	reveal := func(stored string) string {
		if stored == "" {
			return stored
		}
		plain, err := utils.DecryptByAES(stored, secret)
		if err != nil {
			return stored
		}
		return string(plain)
	}

	refreshToken := reveal(p.RefreshToken)
	accessToken := reveal(p.AccessToken)

	// Take the address of a local copy so the returned pointer does not alias
	// any caller-owned record.
	nextRefreshAt := p.NextTokenRefreshAt

	meta := &entity.AuthorizationCodeMeta{
		UserID:   p.UserID,
		PluginID: p.PluginID,
		IsDraft:  p.IsDraft,
	}

	return &entity.AuthorizationCodeInfo{
		RecordID:             p.ID,
		Meta:                 meta,
		Config:               p.OauthConfig,
		AccessToken:          accessToken,
		RefreshToken:         refreshToken,
		TokenExpiredAtMS:     p.TokenExpiredAt,
		NextTokenRefreshAtMS: &nextRefreshAt,
		LastActiveAtMS:       p.LastActiveAt,
	}
}
// PluginOAuthAuthDAO provides read/write access to plugin OAuth authorization
// records (the PluginOauthAuth table). Construct it with NewPluginOAuthAuthDAO.
type PluginOAuthAuthDAO struct {
// idGen supplies the ID assigned to a record when Upsert has to create it.
idGen idgen.IDGenerator
// query is the generated query entry point obtained from query.Use(db).
query *query.Query
}
// Get looks up the authorization record identified by meta (user ID, plugin
// ID and draft flag).
//
// It returns (info, true, nil) when a record is found, (nil, false, nil) when
// no record matches, and (nil, false, err) for any other query error.
func (p *PluginOAuthAuthDAO) Get(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error) {
	tbl := p.query.PluginOauthAuth

	record, err := tbl.WithContext(ctx).
		Where(
			tbl.UserID.Eq(meta.UserID),
			tbl.PluginID.Eq(meta.PluginID),
			tbl.IsDraft.Is(meta.IsDraft),
		).
		First()

	switch {
	case err == nil:
		return pluginOAuthAuthPO(*record).ToDO(), true, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing record is reported through exist, not as an error.
		return nil, false, nil
	default:
		return nil, false, err
	}
}
// Upsert creates the authorization record described by info, or updates the
// existing one that matches info.Meta (user ID, plugin ID and draft flag).
//
// info.Meta must be non-nil with a non-empty UserID and a positive PluginID;
// otherwise an error is returned. Non-empty tokens are AES-encrypted before
// being stored, using the secret from the utils.OAuthTokenSecretEnv
// environment variable or utils.DefaultOAuthTokenSecret when it is empty.
//
// On the update path only the supplied values are written: non-empty tokens,
// a non-nil NextTokenRefreshAtMS, positive TokenExpiredAtMS / LastActiveAtMS
// and a non-nil Config. Everything else is left as stored.
func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) {
	meta := info.Meta
	if meta == nil || meta.UserID == "" || meta.PluginID <= 0 {
		return fmt.Errorf("meta info is required")
	}

	secret := utils.DefaultOAuthTokenSecret
	if fromEnv := os.Getenv(utils.OAuthTokenSecretEnv); fromEnv != "" {
		secret = fromEnv
	}

	// Empty tokens are never encrypted; they stay "" and are skipped on update.
	var encAccess, encRefresh string
	if plain := info.AccessToken; plain != "" {
		if encAccess, err = utils.EncryptByAES([]byte(plain), secret); err != nil {
			return err
		}
	}
	if plain := info.RefreshToken; plain != "" {
		if encRefresh, err = utils.EncryptByAES([]byte(plain), secret); err != nil {
			return err
		}
	}

	tbl := p.query.PluginOauthAuth

	// Probe for an existing row; only the ID column is selected because the
	// result is used purely as an existence check.
	_, lookupErr := tbl.WithContext(ctx).
		Select(tbl.ID).
		Where(
			tbl.UserID.Eq(meta.UserID),
			tbl.PluginID.Eq(meta.PluginID),
			tbl.IsDraft.Is(meta.IsDraft),
		).
		First()

	switch {
	case lookupErr == nil:
		// Row exists: continue to the partial update below.
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		newID, genErr := p.idGen.GenID(ctx)
		if genErr != nil {
			return genErr
		}
		return tbl.WithContext(ctx).Create(&model.PluginOauthAuth{
			ID:                 newID,
			UserID:             meta.UserID,
			PluginID:           meta.PluginID,
			IsDraft:            meta.IsDraft,
			AccessToken:        encAccess,
			RefreshToken:       encRefresh,
			TokenExpiredAt:     info.TokenExpiredAtMS,
			NextTokenRefreshAt: info.GetNextTokenRefreshAtMS(),
			OauthConfig:        info.Config,
			LastActiveAt:       info.LastActiveAtMS,
		})
	default:
		return lookupErr
	}

	// Collect only the columns the caller actually supplied.
	changes := map[string]any{}
	if encAccess != "" {
		changes[tbl.AccessToken.ColumnName().String()] = encAccess
	}
	if encRefresh != "" {
		changes[tbl.RefreshToken.ColumnName().String()] = encRefresh
	}
	if next := info.NextTokenRefreshAtMS; next != nil {
		changes[tbl.NextTokenRefreshAt.ColumnName().String()] = *next
	}
	if info.TokenExpiredAtMS > 0 {
		changes[tbl.TokenExpiredAt.ColumnName().String()] = info.TokenExpiredAtMS
	}
	if info.LastActiveAtMS > 0 {
		changes[tbl.LastActiveAt.ColumnName().String()] = info.LastActiveAtMS
	}
	if cfg := info.Config; cfg != nil {
		// The map-based update bypasses the model, so the config is
		// serialized to JSON here before being written.
		raw, marshalErr := json.Marshal(cfg)
		if marshalErr != nil {
			return marshalErr
		}
		changes[tbl.OauthConfig.ColumnName().String()] = raw
	}

	_, err = tbl.WithContext(ctx).
		Where(
			tbl.UserID.Eq(meta.UserID),
			tbl.PluginID.Eq(meta.PluginID),
			tbl.IsDraft.Is(meta.IsDraft),
		).
		Updates(changes)
	return err
}
// UpdateLastActiveAt sets LastActiveAt to lastActiveAtMs on the record
// identified by meta (user ID, plugin ID and draft flag).
//
// NOTE(review): the value is passed to Updates as a struct; GORM documents
// that struct updates skip zero-value fields, so a lastActiveAtMs of 0 is
// presumably a no-op — confirm if resetting to 0 is ever required.
func (p *PluginOAuthAuthDAO) UpdateLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) {
	tbl := p.query.PluginOauthAuth

	target := tbl.WithContext(ctx).Where(
		tbl.UserID.Eq(meta.UserID),
		tbl.PluginID.Eq(meta.PluginID),
		tbl.IsDraft.Is(meta.IsDraft),
	)

	_, err = target.Updates(&model.PluginOauthAuth{LastActiveAt: lastActiveAtMs})
	return err
}
// GetRefreshTokenList returns up to limit authorization records whose
// NextTokenRefreshAt is set (> 0) and earlier than nextRefreshAt, ordered by
// NextTokenRefreshAt ascending.
//
// Records are fetched in pages of at most 50 rows and accumulated across
// pages. A non-positive limit yields an empty, non-nil slice. On a query
// error the partial result is discarded and (nil, err) is returned.
func (p *PluginOAuthAuthDAO) GetRefreshTokenList(ctx context.Context, nextRefreshAt int64, limit int) (infos []*entity.AuthorizationCodeInfo, err error) {
	const size = 50

	if limit <= 0 {
		// Also guards the make call below, which would panic on a negative cap.
		return []*entity.AuthorizationCodeInfo{}, nil
	}

	table := p.query.PluginOauthAuth

	infos = make([]*entity.AuthorizationCodeInfo, 0, limit)

	for remaining := limit; remaining > 0; {
		// Never ask for more rows than the caller still wants.
		batch := size
		if remaining < batch {
			batch = remaining
		}

		res, findErr := table.WithContext(ctx).
			Where(
				table.NextTokenRefreshAt.Gt(0),
				table.NextTokenRefreshAt.Lt(nextRefreshAt),
			).
			// ID is a tie-breaker so that offset paging is stable when several
			// rows share the same NextTokenRefreshAt.
			Order(table.NextTokenRefreshAt.Asc(), table.ID.Asc()).
			Limit(batch).
			// Skip the rows already collected so each round-trip returns the
			// next page instead of the first one again.
			Offset(len(infos)).
			Find()
		if findErr != nil {
			return nil, findErr
		}

		for _, v := range res {
			infos = append(infos, pluginOAuthAuthPO(*v).ToDO())
		}

		remaining -= len(res)
		if len(res) < batch {
			// Short page: no more matching rows.
			break
		}
	}

	return infos, nil
}
// BatchDeleteByIDs deletes the records with the given IDs, issuing one DELETE
// per group of at most 20 IDs. It stops at, and returns, the first error;
// groups deleted before that error are not rolled back by this method.
func (p *PluginOAuthAuthDAO) BatchDeleteByIDs(ctx context.Context, ids []int64) (err error) {
	const batchSize = 20

	tbl := p.query.PluginOauthAuth

	for _, batch := range slices.Chunks(ids, batchSize) {
		if _, err = tbl.WithContext(ctx).Where(tbl.ID.In(batch...)).Delete(); err != nil {
			return err
		}
	}

	return nil
}
// Delete removes the authorization record identified by meta (user ID,
// plugin ID and draft flag) and returns any error from the delete statement.
func (p *PluginOAuthAuthDAO) Delete(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) {
	tbl := p.query.PluginOauthAuth

	target := tbl.WithContext(ctx).Where(
		tbl.UserID.Eq(meta.UserID),
		tbl.PluginID.Eq(meta.PluginID),
		tbl.IsDraft.Is(meta.IsDraft),
	)

	_, err = target.Delete()
	return err
}
// DeleteExpiredTokens deletes records whose TokenExpiredAt is set (> 0) and
// earlier than expireAt. At most limit rows are targeted, in batches of up to
// 50 rows per DELETE statement. A non-positive limit deletes nothing.
//
// It stops early once a batch affects fewer rows than requested (no more
// matching rows) and returns the first error encountered.
//
// NOTE(review): this relies on the database dialect accepting LIMIT on
// DELETE statements — confirm for the configured driver.
func (p *PluginOAuthAuthDAO) DeleteExpiredTokens(ctx context.Context, expireAt int64, limit int) (err error) {
	const size = 50

	table := p.query.PluginOauthAuth

	for limit > 0 {
		// Cap the batch at the remaining budget so the total never exceeds
		// the caller's limit.
		batch := size
		if limit < batch {
			batch = limit
		}

		res, delErr := table.WithContext(ctx).
			Where(
				table.TokenExpiredAt.Gt(0),
				table.TokenExpiredAt.Lt(expireAt),
			).
			Limit(batch).
			Delete()
		if delErr != nil {
			return delErr
		}

		limit -= batch
		if res.RowsAffected < int64(batch) {
			break
		}
	}

	return nil
}
// DeleteInactiveTokens deletes records whose LastActiveAt is set (> 0) and
// earlier than lastActiveAt. At most limit rows are targeted, in batches of
// up to 50 rows per DELETE statement. A non-positive limit deletes nothing.
//
// It stops early once a batch affects fewer rows than requested (no more
// matching rows) and returns the first error encountered.
//
// NOTE(review): this relies on the database dialect accepting LIMIT on
// DELETE statements — confirm for the configured driver.
func (p *PluginOAuthAuthDAO) DeleteInactiveTokens(ctx context.Context, lastActiveAt int64, limit int) (err error) {
	const size = 50

	table := p.query.PluginOauthAuth

	for limit > 0 {
		// Cap the batch at the remaining budget so the total never exceeds
		// the caller's limit.
		batch := size
		if limit < batch {
			batch = limit
		}

		res, delErr := table.WithContext(ctx).
			Where(
				table.LastActiveAt.Gt(0),
				table.LastActiveAt.Lt(lastActiveAt),
			).
			Limit(batch).
			Delete()
		if delErr != nil {
			return delErr
		}

		limit -= batch
		if res.RowsAffected < int64(batch) {
			break
		}
	}

	return nil
}