fix(plugin): authorization code redirect to static url (#191)

This commit is contained in:
mrh997 2025-08-07 12:24:18 +08:00 committed by GitHub
parent efbc82e8b3
commit e2b1f6e381
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 79 additions and 63 deletions

View File

@ -20,12 +20,14 @@ package coze
import ( import (
"context" "context"
"fmt"
"github.com/coze-dev/coze-studio/backend/application/plugin"
"github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/coze-dev/coze-studio/backend/application/plugin"
bot_open_api "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_open_api" bot_open_api "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_open_api"
) )
@ -41,7 +43,7 @@ func OauthAuthorizationCode(ctx context.Context, c *app.RequestContext) {
} }
if req.Code == "" { if req.Code == "" {
invalidParamRequestResponse(c, "code is required") invalidParamRequestResponse(c, "authorization failed, code is required")
return return
} }
if req.State == "" { if req.State == "" {
@ -49,11 +51,15 @@ func OauthAuthorizationCode(ctx context.Context, c *app.RequestContext) {
return return
} }
resp, err := plugin.PluginApplicationSVC.OauthAuthorizationCode(ctx, &req) _, err = plugin.PluginApplicationSVC.OauthAuthorizationCode(ctx, &req)
if err != nil { if err != nil {
internalServerErrorResponse(ctx, c, err) internalServerErrorResponse(ctx, c, err)
return return
} }
c.JSON(consts.StatusOK, resp) redirectURL := fmt.Sprintf("%s/information/auth/success", conf.GetServerHost())
c.Redirect(consts.StatusFound, []byte(redirectURL))
c.Abort()
return
} }

View File

@ -61,9 +61,13 @@ func isStaticFile(ctx *app.RequestContext) bool {
return true return true
} }
if strings.HasPrefix(string(path), "/static/") || if strings.HasPrefix(path, "/static/") ||
strings.HasPrefix(string(path), "/explore/") || strings.HasPrefix(path, "/explore/") ||
strings.HasPrefix(string(path), "/space/") { strings.HasPrefix(path, "/space/") {
return true
}
if path == "/information/auth/success" {
return true return true
} }

View File

@ -23,7 +23,7 @@ import (
"strings" "strings"
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils" "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
@ -75,12 +75,12 @@ func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
return mf_, nil return mf_, nil
} }
secret := os.Getenv(utils.AuthSecretEnv) secret := os.Getenv(encrypt.AuthSecretEnv)
if secret == "" { if secret == "" {
secret = utils.DefaultAuthSecret secret = encrypt.DefaultAuthSecret
} }
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), secret) payload_, err := encrypt.EncryptByAES([]byte(mf_.Auth.Payload), secret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -363,12 +363,12 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error {
} }
if auth.Payload != "" { if auth.Payload != "" {
secret := os.Getenv(utils.AuthSecretEnv) secret := os.Getenv(encrypt.AuthSecretEnv)
if secret == "" { if secret == "" {
secret = utils.DefaultAuthSecret secret = encrypt.DefaultAuthSecret
} }
payload_, err := utils.DecryptByAES(auth.Payload, secret) payload_, err := encrypt.DecryptByAES(auth.Payload, secret)
if err == nil { if err == nil {
auth.Payload = string(payload_) auth.Payload = string(payload_)
} }

View File

@ -45,10 +45,10 @@ import (
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil" "github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository" "github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service" "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/service" search "github.com/coze-dev/coze-studio/backend/domain/search/service"
user "github.com/coze-dev/coze-studio/backend/domain/user/service" user "github.com/coze-dev/coze-studio/backend/domain/user/service"
@ -1704,12 +1704,12 @@ func (p *PluginApplicationService) OauthAuthorizationCode(ctx context.Context, r
return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state"))
} }
secret := os.Getenv(utils.StateSecretEnv) secret := os.Getenv(encrypt.StateSecretEnv)
if secret == "" { if secret == "" {
secret = utils.DefaultStateSecret secret = encrypt.DefaultStateSecret
} }
stateBytes, err := utils.DecryptByAES(stateStr, secret) stateBytes, err := encrypt.DecryptByAES(stateStr, secret)
if err != nil { if err != nil {
return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state"))
} }

View File

@ -0,0 +1,30 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conf
import (
"os"
"strings"
)
func GetServerHost() string {
host := os.Getenv("SERVER_HOST")
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
return host
}
return "https://" + host
}

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package utils package encrypt
import ( import (
"bytes" "bytes"

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package utils package encrypt
import ( import (
"testing" "testing"

View File

@ -16,25 +16,6 @@
package entity package entity
import (
"strings"
)
const (
larkPluginOAuthHostName = "open.larkoffice.com"
larkOAuthHostName = "open.feishu.cn"
)
func GetOAuthProvider(tokenURL string) OAuthProvider {
if strings.Contains(tokenURL, larkPluginOAuthHostName) {
return OAuthProviderOfLarkPlugin
}
if strings.Contains(tokenURL, larkOAuthHostName) {
return OAuthProviderOfLark
}
return OAuthProviderOfStandard
}
type SortField string type SortField string
const ( const (
@ -43,9 +24,3 @@ const (
) )
type OAuthProvider string type OAuthProvider string
const (
OAuthProviderOfLarkPlugin OAuthProvider = "lark_plugin"
OAuthProviderOfLark OAuthProvider = "lark"
OAuthProviderOfStandard OAuthProvider = "standard"
)

View File

@ -25,10 +25,10 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
) )
@ -43,19 +43,19 @@ func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAut
type pluginOAuthAuthPO model.PluginOauthAuth type pluginOAuthAuthPO model.PluginOauthAuth
func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo { func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo {
secret := os.Getenv(utils.OAuthTokenSecretEnv) secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
if secret == "" { if secret == "" {
secret = utils.DefaultOAuthTokenSecret secret = encrypt.DefaultOAuthTokenSecret
} }
if p.RefreshToken != "" { if p.RefreshToken != "" {
refreshToken, err := utils.DecryptByAES(p.RefreshToken, secret) refreshToken, err := encrypt.DecryptByAES(p.RefreshToken, secret)
if err == nil { if err == nil {
p.RefreshToken = string(refreshToken) p.RefreshToken = string(refreshToken)
} }
} }
if p.AccessToken != "" { if p.AccessToken != "" {
accessToken, err := utils.DecryptByAES(p.AccessToken, secret) accessToken, err := encrypt.DecryptByAES(p.AccessToken, secret)
if err == nil { if err == nil {
p.AccessToken = string(accessToken) p.AccessToken = string(accessToken)
} }
@ -109,20 +109,20 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat
} }
meta := info.Meta meta := info.Meta
secret := os.Getenv(utils.OAuthTokenSecretEnv) secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
if secret == "" { if secret == "" {
secret = utils.DefaultOAuthTokenSecret secret = encrypt.DefaultOAuthTokenSecret
} }
var accessToken, refreshToken string var accessToken, refreshToken string
if info.AccessToken != "" { if info.AccessToken != "" {
accessToken, err = utils.EncryptByAES([]byte(info.AccessToken), secret) accessToken, err = encrypt.EncryptByAES([]byte(info.AccessToken), secret)
if err != nil { if err != nil {
return err return err
} }
} }
if info.RefreshToken != "" { if info.RefreshToken != "" {
refreshToken, err = utils.EncryptByAES([]byte(info.RefreshToken), secret) refreshToken, err = encrypt.EncryptByAES([]byte(info.RefreshToken), secret)
if err != nil { if err != nil {
return err return err
} }

View File

@ -29,8 +29,9 @@ import (
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
@ -438,12 +439,12 @@ func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) {
return "", fmt.Errorf("marshal state failed, err=%v", err) return "", fmt.Errorf("marshal state failed, err=%v", err)
} }
secret := os.Getenv(utils.StateSecretEnv) secret := os.Getenv(encrypt.StateSecretEnv)
if secret == "" { if secret == "" {
secret = utils.DefaultStateSecret secret = encrypt.DefaultStateSecret
} }
encryptState, err := utils.EncryptByAES(stateStr, secret) encryptState, err := encrypt.EncryptByAES(stateStr, secret)
if err != nil { if err != nil {
return "", fmt.Errorf("encrypt state failed, err=%v", err) return "", fmt.Errorf("encrypt state failed, err=%v", err)
} }
@ -464,7 +465,7 @@ func getStanderOAuthConfig(config *model.OAuthAuthorizationCodeConfig) *oauth2.C
TokenURL: config.AuthorizationURL, TokenURL: config.AuthorizationURL,
AuthURL: config.ClientURL, AuthURL: config.ClientURL,
}, },
RedirectURL: fmt.Sprintf("https://%s/api/oauth/authorization_code", os.Getenv("SERVER_HOST")), RedirectURL: fmt.Sprintf("%s/api/oauth/authorization_code", conf.GetServerHost()),
Scopes: strings.Split(config.Scope, " "), Scopes: strings.Split(config.Scope, " "),
} }
} }

View File

@ -5,7 +5,7 @@ export RUN_MODE="debug" # Currently supports debug mode. When set to debug, it h
export LISTEN_ADDR=":8888" export LISTEN_ADDR=":8888"
export LOG_LEVEL="debug" export LOG_LEVEL="debug"
export MAX_REQUEST_BODY_SIZE=1073741824 export MAX_REQUEST_BODY_SIZE=1073741824
export SERVER_HOST="localhost${LISTEN_ADDR}" export SERVER_HOST="http://localhost${LISTEN_ADDR}"
export MINIO_PROXY_ENDPOINT="" export MINIO_PROXY_ENDPOINT=""
export USE_SSL="0" export USE_SSL="0"
export SSL_CERT_FILE="" export SSL_CERT_FILE=""

View File

@ -2,8 +2,8 @@
export LISTEN_ADDR=":8888" export LISTEN_ADDR=":8888"
export LOG_LEVEL="debug" export LOG_LEVEL="debug"
export MAX_REQUEST_BODY_SIZE=1073741824 export MAX_REQUEST_BODY_SIZE=1073741824
export SERVER_HOST="localhost${LISTEN_ADDR}" export SERVER_HOST="http://localhost${LISTEN_ADDR}"
export MINIO_PROXY_ENDPOINT=":8889" export MINIO_PROXY_ENDPOINT=""
export USE_SSL="0" export USE_SSL="0"
export SSL_CERT_FILE="" export SSL_CERT_FILE=""
export SSL_KEY_FILE="" export SSL_KEY_FILE=""