coze-studio/backend/domain/plugin/service/plugin_oauth.go

519 lines
14 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"sync"
"time"
"golang.org/x/oauth2"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
var (
initOnce = sync.Once{}
lastActiveInterval = 15 * 24 * time.Hour
)
func (p *pluginServiceImpl) processOAuthAccessToken(ctx context.Context) {
const (
deleteLimit = 100
refreshLimit = 50
)
for {
now := time.Now()
lastActiveAt := now.Add(-lastActiveInterval)
err := p.oauthRepo.DeleteInactiveAuthorizationCodeTokens(ctx, lastActiveAt.UnixMilli(), deleteLimit)
if err != nil {
logs.CtxWarnf(ctx, "DeleteInactiveAuthorizationCodeTokens failed, err=%v", err)
}
err = p.oauthRepo.DeleteExpiredAuthorizationCodeTokens(ctx, now.UnixMilli(), deleteLimit)
if err != nil {
logs.CtxWarnf(ctx, "DeleteExpiredAuthorizationCodeTokens failed, err=%v", err)
}
refreshTokenList, err := p.oauthRepo.GetAuthorizationCodeRefreshTokens(ctx, now.UnixMilli(), refreshLimit)
if err != nil {
logs.CtxErrorf(ctx, "GetAuthorizationCodeRefreshTokens failed, err=%v", err)
<-time.After(time.Second)
continue
}
taskGroups := taskgroup.NewTaskGroup(ctx, 3)
expired := make([]int64, 0, len(refreshTokenList))
for _, info := range refreshTokenList {
if info.GetNextTokenRefreshAtMS() == 0 || info.TokenExpiredAtMS == 0 {
continue
}
if info.GetNextTokenRefreshAtMS() > now.UnixMilli() ||
info.LastActiveAtMS <= lastActiveAt.UnixMilli() {
expired = append(expired, info.RecordID)
continue
}
taskGroups.Go(func() error {
p.refreshToken(ctx, info)
return nil
})
}
_ = taskGroups.Wait()
if len(expired) > 0 {
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, expired)
if err != nil {
logs.CtxWarnf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, err=%v", err)
}
}
<-time.After(5 * time.Second)
}
}
func (p *pluginServiceImpl) refreshToken(ctx context.Context, info *entity.AuthorizationCodeInfo) {
config := oauth2.Config{
ClientID: info.Config.ClientID,
ClientSecret: info.Config.ClientSecret,
Endpoint: oauth2.Endpoint{
TokenURL: info.Config.AuthorizationURL,
},
Scopes: strings.Split(info.Config.Scope, " "),
}
token := &oauth2.Token{
AccessToken: info.AccessToken,
RefreshToken: info.RefreshToken,
Expiry: time.UnixMilli(info.TokenExpiredAtMS),
}
source := config.TokenSource(ctx, token)
var (
err error
newToken *oauth2.Token
)
for i := 0; i < 3; i++ {
newToken, err = source.Token()
if err == nil {
token = newToken
break
}
<-time.After(time.Second)
}
if err != nil {
logs.CtxInfof(ctx, "refreshToken failed, recordID=%d, err=%v", info.RecordID, err)
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{info.RecordID})
if err != nil {
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", info.RecordID, err)
}
return
}
for i := 0; i < 3; i++ {
var expiredAtMS int64
if !token.Expiry.IsZero() && token.Expiry.After(time.Now()) {
expiredAtMS = token.Expiry.UnixMilli()
}
err = p.oauthRepo.UpsertAuthorizationCode(ctx, &entity.AuthorizationCodeInfo{
Meta: &entity.AuthorizationCodeMeta{
UserID: info.Meta.UserID,
PluginID: info.Meta.PluginID,
IsDraft: info.Meta.IsDraft,
},
Config: info.Config,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenExpiredAtMS: expiredAtMS,
NextTokenRefreshAtMS: ptr.Of(getNextTokenRefreshAtMS(expiredAtMS)),
})
if err == nil {
break
}
<-time.After(time.Second)
}
if err != nil {
logs.CtxInfof(ctx, "UpsertAuthorizationCode failed, recordID=%d, err=%v", info.RecordID, err)
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{info.RecordID})
if err != nil {
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", info.RecordID, err)
}
}
}
func (p *pluginServiceImpl) GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (accessToken string, err error) {
switch oa.OAuthMode {
case model.AuthzSubTypeOfOAuthAuthorizationCode:
accessToken, err = p.getAccessTokenByAuthorizationCode(ctx, oa.AuthorizationCode)
default:
return "", fmt.Errorf("invalid oauth mode '%s'", oa.OAuthMode)
}
if err != nil {
return "", err
}
return accessToken, nil
}
func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Context, ci *entity.AuthorizationCodeInfo) (accessToken string, err error) {
meta := ci.Meta
info, exist, err := p.oauthRepo.GetAuthorizationCode(ctx, ci.Meta)
if err != nil {
return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%t",
meta.UserID, meta.PluginID, meta.IsDraft)
}
if !exist {
return "", nil
}
if !isValidAuthCodeConfig(info.Config, ci.Config, info.TokenExpiredAtMS, info.LastActiveAtMS) {
return "", nil
}
now := time.Now().UnixMilli()
if now-info.LastActiveAtMS > time.Minute.Milliseconds() { // don't update too frequently
err = p.oauthRepo.UpdateAuthorizationCodeLastActiveAt(ctx, meta, now)
if err != nil {
logs.CtxWarnf(ctx, "UpdateAuthorizationCodeLastActiveAt failed, userID=%s, pluginID=%d, isDraft=%t, err=%v",
meta.UserID, meta.PluginID, meta.IsDraft, err)
}
}
return info.AccessToken, nil
}
func isValidAuthCodeConfig(o, n *model.OAuthAuthorizationCodeConfig, expireAt, lastActiveAt int64) bool {
now := time.Now()
if expireAt > 0 && expireAt <= now.UnixMilli() {
return false
}
if lastActiveAt > 0 && lastActiveAt <= now.Add(-lastActiveInterval).UnixMilli() {
return false
}
if o.ClientID != n.ClientID {
return false
}
if o.ClientSecret != n.ClientSecret {
return false
}
if o.ClientURL != n.ClientURL {
return false
}
if o.AuthorizationURL != n.AuthorizationURL {
return false
}
if o.AuthorizationContentType != n.AuthorizationContentType {
return false
}
oldScope := strings.Split(o.Scope, " ")
newScope := strings.Split(n.Scope, " ")
if len(oldScope) != len(newScope) {
return false
}
m := make(map[string]bool, len(oldScope))
for _, v := range oldScope {
m[v] = false
}
for _, v := range newScope {
if _, ok := m[v]; !ok {
return false
}
}
return true
}
func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *entity.OAuthState) (err error) {
var plugin *entity.PluginInfo
if state.IsDraft {
plugin, err = p.GetDraftPlugin(ctx, state.PluginID)
} else {
plugin, err = p.GetOnlinePlugin(ctx, state.PluginID)
}
if err != nil {
return errorx.Wrapf(err, "GetPlugin failed, pluginID=%d", state.PluginID)
}
authInfo := plugin.GetAuthInfo()
if authInfo.SubType != model.AuthzSubTypeOfOAuthAuthorizationCode {
return errorx.New(errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "plugin auth type is not oauth authorization code"))
}
if authInfo.AuthOfOAuthAuthorizationCode == nil {
return errorx.New(errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "plugin auth info is nil"))
}
config := getStanderOAuthConfig(authInfo.AuthOfOAuthAuthorizationCode)
token, err := config.Exchange(ctx, code)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "exchange token failed"))
}
meta := &entity.AuthorizationCodeMeta{
UserID: state.UserID,
PluginID: state.PluginID,
IsDraft: state.IsDraft,
}
var expiredAtMS int64
if !token.Expiry.IsZero() && token.Expiry.After(time.Now()) {
expiredAtMS = token.Expiry.UnixMilli()
}
err = p.saveAccessToken(ctx, &entity.OAuthInfo{
OAuthMode: model.AuthzSubTypeOfOAuthAuthorizationCode,
AuthorizationCode: &entity.AuthorizationCodeInfo{
Meta: meta,
Config: authInfo.AuthOfOAuthAuthorizationCode,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenExpiredAtMS: expiredAtMS,
NextTokenRefreshAtMS: ptr.Of(getNextTokenRefreshAtMS(expiredAtMS)),
LastActiveAtMS: time.Now().UnixMilli(),
},
})
if err != nil {
return errorx.Wrapf(err, "SaveAccessToken failed, pluginID=%d", state.PluginID)
}
return nil
}
func (p *pluginServiceImpl) saveAccessToken(ctx context.Context, oa *entity.OAuthInfo) (err error) {
switch oa.OAuthMode {
case model.AuthzSubTypeOfOAuthAuthorizationCode:
err = p.saveAuthCodeAccessToken(ctx, oa.AuthorizationCode)
default:
return fmt.Errorf("[standardOAuth] invalid oauth mode '%s'", oa.OAuthMode)
}
return err
}
func (p *pluginServiceImpl) saveAuthCodeAccessToken(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) {
meta := info.Meta
err = p.oauthRepo.UpsertAuthorizationCode(ctx, info)
if err != nil {
return errorx.Wrapf(err, "SaveAuthorizationCodeInfo failed, userID=%s, pluginID=%d, isDraft=%t",
meta.UserID, meta.PluginID, meta.IsDraft)
}
return nil
}
func getNextTokenRefreshAtMS(expiredAtMS int64) int64 {
if expiredAtMS == 0 {
return 0
}
return time.Now().Add(time.Duration((expiredAtMS-time.Now().UnixMilli())/2) * time.Millisecond).UnixMilli()
}
func (p *pluginServiceImpl) RevokeAccessToken(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) {
return p.oauthRepo.DeleteAuthorizationCode(ctx, meta)
}
func (p *pluginServiceImpl) GetOAuthStatus(ctx context.Context, userID, pluginID int64) (resp *GetOAuthStatusResponse, err error) {
pl, exist, err := p.pluginRepo.GetDraftPlugin(ctx, pluginID)
if err != nil {
return nil, err
}
if !exist {
return nil, fmt.Errorf("draft plugin '%d' not found", pluginID)
}
authInfo := pl.GetAuthInfo()
if authInfo.Type == model.AuthzTypeOfNone || authInfo.Type == model.AuthzTypeOfService {
return &GetOAuthStatusResponse{
IsOauth: false,
}, nil
}
needAuth, authURL, err := p.getPluginOAuthStatus(ctx, userID, pl, true)
if err != nil {
return nil, err
}
status := common.OAuthStatus_Authorized
if needAuth {
status = common.OAuthStatus_Unauthorized
}
resp = &GetOAuthStatusResponse{
IsOauth: true,
Status: status,
OAuthURL: authURL,
}
return resp, nil
}
func (p *pluginServiceImpl) getPluginOAuthStatus(ctx context.Context, userID int64, plugin *entity.PluginInfo, isDraft bool) (needAuth bool, authURL string, err error) {
authInfo := plugin.GetAuthInfo()
if authInfo.Type != model.AuthzTypeOfOAuth {
return false, "", fmt.Errorf("invalid auth type '%v'", authInfo.Type)
}
if authInfo.SubType != model.AuthzSubTypeOfOAuthAuthorizationCode {
return false, "", fmt.Errorf("invalid auth sub type '%v'", authInfo.SubType)
}
authCode := &entity.AuthorizationCodeInfo{
Meta: &entity.AuthorizationCodeMeta{
UserID: conv.Int64ToStr(userID),
PluginID: plugin.ID,
IsDraft: isDraft,
},
Config: plugin.Manifest.Auth.AuthOfOAuthAuthorizationCode,
}
accessToken, err := p.GetAccessToken(ctx, &entity.OAuthInfo{
OAuthMode: model.AuthzSubTypeOfOAuthAuthorizationCode,
AuthorizationCode: authCode,
})
if err != nil {
return false, "", err
}
needAuth = accessToken == ""
authURL, err = genAuthURL(authCode)
if err != nil {
return false, "", err
}
return needAuth, authURL, nil
}
func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) {
config := getStanderOAuthConfig(info.Config)
state := &entity.OAuthState{
ClientName: "",
UserID: info.Meta.UserID,
PluginID: info.Meta.PluginID,
IsDraft: info.Meta.IsDraft,
}
stateStr, err := json.Marshal(state)
if err != nil {
return "", fmt.Errorf("marshal state failed, err=%v", err)
}
secret := os.Getenv(utils.StateSecretEnv)
if secret == "" {
secret = utils.DefaultStateSecret
}
encryptState, err := utils.EncryptByAES(stateStr, secret)
if err != nil {
return "", fmt.Errorf("encrypt state failed, err=%v", err)
}
authURL := config.AuthCodeURL(encryptState)
return authURL, nil
}
func getStanderOAuthConfig(config *model.OAuthAuthorizationCodeConfig) *oauth2.Config {
if config == nil {
return nil
}
return &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: oauth2.Endpoint{
TokenURL: config.AuthorizationURL,
AuthURL: config.ClientURL,
},
RedirectURL: fmt.Sprintf("https://%s/api/oauth/authorization_code", os.Getenv("SERVER_HOST")),
Scopes: strings.Split(config.Scope, " "),
}
}
func (p *pluginServiceImpl) GetAgentPluginsOAuthStatus(ctx context.Context, userID, agentID int64) (status []*AgentPluginOAuthStatus, err error) {
pluginIDs, err := p.toolRepo.GetAgentPluginIDs(ctx, agentID)
if err != nil {
return nil, errorx.Wrapf(err, "GetAgentPluginIDs failed, agentID=%d", agentID)
}
if len(pluginIDs) == 0 {
return nil, nil
}
plugins, err := p.pluginRepo.MGetOnlinePlugins(ctx, pluginIDs)
if err != nil {
return nil, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs)
}
for _, plugin := range plugins {
authInfo := plugin.GetAuthInfo()
if authInfo.Type == model.AuthzTypeOfNone || authInfo.Type == model.AuthzTypeOfService {
continue
}
needAuth, _, err := p.getPluginOAuthStatus(ctx, userID, plugin, false)
if err != nil {
logs.CtxErrorf(ctx, "getPluginOAuthStatus failed, pluginID=%d, err=%v", plugin.ID, err)
continue
}
iconURL := ""
if plugin.GetIconURI() != "" {
iconURL, _ = p.oss.GetObjectUrl(ctx, plugin.GetIconURI())
}
authStatus := common.OAuthStatus_Authorized
if needAuth {
authStatus = common.OAuthStatus_Unauthorized
}
status = append(status, &AgentPluginOAuthStatus{
PluginID: plugin.ID,
PluginName: plugin.GetName(),
PluginIconURL: iconURL,
Status: authStatus,
})
}
return status, nil
}