331 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			331 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package dal
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"os"
 | 
						|
 | 
						|
	"gorm.io/gorm"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
 | 
						|
)
 | 
						|
 | 
						|
func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAuthDAO {
 | 
						|
	return &PluginOAuthAuthDAO{
 | 
						|
		idGen: idGen,
 | 
						|
		query: query.Use(db),
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
type pluginOAuthAuthPO model.PluginOauthAuth
 | 
						|
 | 
						|
func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo {
 | 
						|
	secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
 | 
						|
	if secret == "" {
 | 
						|
		secret = encrypt.DefaultOAuthTokenSecret
 | 
						|
	}
 | 
						|
 | 
						|
	if p.RefreshToken != "" {
 | 
						|
		refreshToken, err := encrypt.DecryptByAES(p.RefreshToken, secret)
 | 
						|
		if err == nil {
 | 
						|
			p.RefreshToken = string(refreshToken)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if p.AccessToken != "" {
 | 
						|
		accessToken, err := encrypt.DecryptByAES(p.AccessToken, secret)
 | 
						|
		if err == nil {
 | 
						|
			p.AccessToken = string(accessToken)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return &entity.AuthorizationCodeInfo{
 | 
						|
		RecordID: p.ID,
 | 
						|
		Meta: &entity.AuthorizationCodeMeta{
 | 
						|
			UserID:   p.UserID,
 | 
						|
			PluginID: p.PluginID,
 | 
						|
			IsDraft:  p.IsDraft,
 | 
						|
		},
 | 
						|
		Config:               p.OauthConfig,
 | 
						|
		AccessToken:          p.AccessToken,
 | 
						|
		RefreshToken:         p.RefreshToken,
 | 
						|
		TokenExpiredAtMS:     p.TokenExpiredAt,
 | 
						|
		NextTokenRefreshAtMS: &p.NextTokenRefreshAt,
 | 
						|
		LastActiveAtMS:       p.LastActiveAt,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
type PluginOAuthAuthDAO struct {
 | 
						|
	idGen idgen.IDGenerator
 | 
						|
	query *query.Query
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) Get(ctx context.Context, meta *entity.AuthorizationCodeMeta) (info *entity.AuthorizationCodeInfo, exist bool, err error) {
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
	res, err := table.WithContext(ctx).
 | 
						|
		Where(
 | 
						|
			table.UserID.Eq(meta.UserID),
 | 
						|
			table.PluginID.Eq(meta.PluginID),
 | 
						|
			table.IsDraft.Is(meta.IsDraft),
 | 
						|
		).
 | 
						|
		First()
 | 
						|
	if err != nil {
 | 
						|
		if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
						|
			return nil, false, nil
 | 
						|
		}
 | 
						|
		return nil, false, err
 | 
						|
	}
 | 
						|
 | 
						|
	info = pluginOAuthAuthPO(*res).ToDO()
 | 
						|
 | 
						|
	return info, true, nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) {
 | 
						|
	if info.Meta == nil || info.Meta.UserID == "" || info.Meta.PluginID <= 0 {
 | 
						|
		return fmt.Errorf("meta info is required")
 | 
						|
	}
 | 
						|
 | 
						|
	meta := info.Meta
 | 
						|
	secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
 | 
						|
	if secret == "" {
 | 
						|
		secret = encrypt.DefaultOAuthTokenSecret
 | 
						|
	}
 | 
						|
 | 
						|
	var accessToken, refreshToken string
 | 
						|
	if info.AccessToken != "" {
 | 
						|
		accessToken, err = encrypt.EncryptByAES([]byte(info.AccessToken), secret)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if info.RefreshToken != "" {
 | 
						|
		refreshToken, err = encrypt.EncryptByAES([]byte(info.RefreshToken), secret)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
	_, err = table.WithContext(ctx).
 | 
						|
		Select(table.ID).
 | 
						|
		Where(
 | 
						|
			table.UserID.Eq(meta.UserID),
 | 
						|
			table.PluginID.Eq(meta.PluginID),
 | 
						|
			table.IsDraft.Is(meta.IsDraft),
 | 
						|
		).First()
 | 
						|
	if err != nil {
 | 
						|
		if !errors.Is(err, gorm.ErrRecordNotFound) {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		id, err := p.idGen.GenID(ctx)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		po := &model.PluginOauthAuth{
 | 
						|
			ID:                 id,
 | 
						|
			UserID:             meta.UserID,
 | 
						|
			PluginID:           meta.PluginID,
 | 
						|
			IsDraft:            meta.IsDraft,
 | 
						|
			AccessToken:        accessToken,
 | 
						|
			RefreshToken:       refreshToken,
 | 
						|
			TokenExpiredAt:     info.TokenExpiredAtMS,
 | 
						|
			NextTokenRefreshAt: info.GetNextTokenRefreshAtMS(),
 | 
						|
			OauthConfig:        info.Config,
 | 
						|
			LastActiveAt:       info.LastActiveAtMS,
 | 
						|
		}
 | 
						|
 | 
						|
		return table.WithContext(ctx).Create(po)
 | 
						|
	}
 | 
						|
 | 
						|
	updateMap := map[string]any{}
 | 
						|
	if accessToken != "" {
 | 
						|
		updateMap[table.AccessToken.ColumnName().String()] = accessToken
 | 
						|
	}
 | 
						|
	if refreshToken != "" {
 | 
						|
		updateMap[table.RefreshToken.ColumnName().String()] = refreshToken
 | 
						|
	}
 | 
						|
	if info.NextTokenRefreshAtMS != nil {
 | 
						|
		updateMap[table.NextTokenRefreshAt.ColumnName().String()] = *info.NextTokenRefreshAtMS
 | 
						|
	}
 | 
						|
	if info.TokenExpiredAtMS > 0 {
 | 
						|
		updateMap[table.TokenExpiredAt.ColumnName().String()] = info.TokenExpiredAtMS
 | 
						|
	}
 | 
						|
	if info.LastActiveAtMS > 0 {
 | 
						|
		updateMap[table.LastActiveAt.ColumnName().String()] = info.LastActiveAtMS
 | 
						|
	}
 | 
						|
	if info.Config != nil {
 | 
						|
		b, err := json.Marshal(info.Config)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		updateMap[table.OauthConfig.ColumnName().String()] = b
 | 
						|
	}
 | 
						|
 | 
						|
	_, err = table.WithContext(ctx).
 | 
						|
		Where(
 | 
						|
			table.UserID.Eq(meta.UserID),
 | 
						|
			table.PluginID.Eq(meta.PluginID),
 | 
						|
			table.IsDraft.Is(meta.IsDraft),
 | 
						|
		).
 | 
						|
		Updates(updateMap)
 | 
						|
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) UpdateLastActiveAt(ctx context.Context, meta *entity.AuthorizationCodeMeta, lastActiveAtMs int64) (err error) {
 | 
						|
	po := &model.PluginOauthAuth{
 | 
						|
		LastActiveAt: lastActiveAtMs,
 | 
						|
	}
 | 
						|
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
	_, err = table.WithContext(ctx).
 | 
						|
		Where(
 | 
						|
			table.UserID.Eq(meta.UserID),
 | 
						|
			table.PluginID.Eq(meta.PluginID),
 | 
						|
			table.IsDraft.Is(meta.IsDraft),
 | 
						|
		).
 | 
						|
		Updates(po)
 | 
						|
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) GetRefreshTokenList(ctx context.Context, nextRefreshAt int64, limit int) (infos []*entity.AuthorizationCodeInfo, err error) {
 | 
						|
	const size = 50
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
 | 
						|
	infos = make([]*entity.AuthorizationCodeInfo, 0, limit)
 | 
						|
 | 
						|
	for limit > 0 {
 | 
						|
		res, err := table.WithContext(ctx).
 | 
						|
			Where(
 | 
						|
				table.NextTokenRefreshAt.Gt(0),
 | 
						|
				table.NextTokenRefreshAt.Lt(nextRefreshAt),
 | 
						|
			).
 | 
						|
			Order(table.NextTokenRefreshAt.Asc()).
 | 
						|
			Limit(size).
 | 
						|
			Find()
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		infos = make([]*entity.AuthorizationCodeInfo, 0, len(res))
 | 
						|
		for _, v := range res {
 | 
						|
			infos = append(infos, pluginOAuthAuthPO(*v).ToDO())
 | 
						|
		}
 | 
						|
 | 
						|
		limit -= size
 | 
						|
 | 
						|
		if len(res) < size {
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return infos, nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) BatchDeleteByIDs(ctx context.Context, ids []int64) (err error) {
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
 | 
						|
	chunks := slices.Chunks(ids, 20)
 | 
						|
 | 
						|
	for _, chunk := range chunks {
 | 
						|
		_, err = table.WithContext(ctx).
 | 
						|
			Where(table.ID.In(chunk...)).
 | 
						|
			Delete()
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) Delete(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) {
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
	_, err = table.WithContext(ctx).
 | 
						|
		Where(
 | 
						|
			table.UserID.Eq(meta.UserID),
 | 
						|
			table.PluginID.Eq(meta.PluginID),
 | 
						|
			table.IsDraft.Is(meta.IsDraft),
 | 
						|
		).
 | 
						|
		Delete()
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) DeleteExpiredTokens(ctx context.Context, expireAt int64, limit int) (err error) {
 | 
						|
	const size = 50
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
 | 
						|
	for limit > 0 {
 | 
						|
		res, err := table.WithContext(ctx).
 | 
						|
			Where(
 | 
						|
				table.TokenExpiredAt.Gt(0),
 | 
						|
				table.TokenExpiredAt.Lt(expireAt),
 | 
						|
			).
 | 
						|
			Limit(size).
 | 
						|
			Delete()
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		limit -= size
 | 
						|
 | 
						|
		if res.RowsAffected < size {
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *PluginOAuthAuthDAO) DeleteInactiveTokens(ctx context.Context, lastActiveAt int64, limit int) (err error) {
 | 
						|
	const size = 50
 | 
						|
	table := p.query.PluginOauthAuth
 | 
						|
 | 
						|
	for limit > 0 {
 | 
						|
		res, err := table.WithContext(ctx).
 | 
						|
			Where(
 | 
						|
				table.LastActiveAt.Gt(0),
 | 
						|
				table.LastActiveAt.Lt(lastActiveAt),
 | 
						|
			).
 | 
						|
			Limit(size).
 | 
						|
			Delete()
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		limit -= size
 | 
						|
 | 
						|
		if res.RowsAffected < size {
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 |