fix(plugin): authorization code redirect to static url (#191)

This commit is contained in:
mrh997 2025-08-07 12:24:18 +08:00 committed by GitHub
parent efbc82e8b3
commit e2b1f6e381
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 79 additions and 63 deletions

View File

@ -20,12 +20,14 @@ package coze
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/application/plugin"
"github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/coze-dev/coze-studio/backend/application/plugin"
bot_open_api "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_open_api"
)
@ -41,7 +43,7 @@ func OauthAuthorizationCode(ctx context.Context, c *app.RequestContext) {
}
if req.Code == "" {
invalidParamRequestResponse(c, "code is required")
invalidParamRequestResponse(c, "authorization failed, code is required")
return
}
if req.State == "" {
@ -49,11 +51,15 @@ func OauthAuthorizationCode(ctx context.Context, c *app.RequestContext) {
return
}
resp, err := plugin.PluginApplicationSVC.OauthAuthorizationCode(ctx, &req)
_, err = plugin.PluginApplicationSVC.OauthAuthorizationCode(ctx, &req)
if err != nil {
internalServerErrorResponse(ctx, c, err)
return
}
c.JSON(consts.StatusOK, resp)
redirectURL := fmt.Sprintf("%s/information/auth/success", conf.GetServerHost())
c.Redirect(consts.StatusFound, []byte(redirectURL))
c.Abort()
return
}

View File

@ -61,9 +61,13 @@ func isStaticFile(ctx *app.RequestContext) bool {
return true
}
if strings.HasPrefix(string(path), "/static/") ||
strings.HasPrefix(string(path), "/explore/") ||
strings.HasPrefix(string(path), "/space/") {
if strings.HasPrefix(path, "/static/") ||
strings.HasPrefix(path, "/explore/") ||
strings.HasPrefix(path, "/space/") {
return true
}
if path == "/information/auth/success" {
return true
}

View File

@ -23,7 +23,7 @@ import (
"strings"
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
@ -75,12 +75,12 @@ func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
return mf_, nil
}
secret := os.Getenv(utils.AuthSecretEnv)
secret := os.Getenv(encrypt.AuthSecretEnv)
if secret == "" {
secret = utils.DefaultAuthSecret
secret = encrypt.DefaultAuthSecret
}
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), secret)
payload_, err := encrypt.EncryptByAES([]byte(mf_.Auth.Payload), secret)
if err != nil {
return nil, err
}
@ -363,12 +363,12 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error {
}
if auth.Payload != "" {
secret := os.Getenv(utils.AuthSecretEnv)
secret := os.Getenv(encrypt.AuthSecretEnv)
if secret == "" {
secret = utils.DefaultAuthSecret
secret = encrypt.DefaultAuthSecret
}
payload_, err := utils.DecryptByAES(auth.Payload, secret)
payload_, err := encrypt.DecryptByAES(auth.Payload, secret)
if err == nil {
auth.Payload = string(payload_)
}

View File

@ -45,10 +45,10 @@ import (
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
@ -1704,12 +1704,12 @@ func (p *PluginApplicationService) OauthAuthorizationCode(ctx context.Context, r
return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state"))
}
secret := os.Getenv(utils.StateSecretEnv)
secret := os.Getenv(encrypt.StateSecretEnv)
if secret == "" {
secret = utils.DefaultStateSecret
secret = encrypt.DefaultStateSecret
}
stateBytes, err := utils.DecryptByAES(stateStr, secret)
stateBytes, err := encrypt.DecryptByAES(stateStr, secret)
if err != nil {
return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state"))
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conf
import (
"os"
"strings"
)
func GetServerHost() string {
host := os.Getenv("SERVER_HOST")
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
return host
}
return "https://" + host
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package utils
package encrypt
import (
"bytes"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package utils
package encrypt
import (
"testing"

View File

@ -16,25 +16,6 @@
package entity
import (
"strings"
)
const (
larkPluginOAuthHostName = "open.larkoffice.com"
larkOAuthHostName = "open.feishu.cn"
)
func GetOAuthProvider(tokenURL string) OAuthProvider {
if strings.Contains(tokenURL, larkPluginOAuthHostName) {
return OAuthProviderOfLarkPlugin
}
if strings.Contains(tokenURL, larkOAuthHostName) {
return OAuthProviderOfLark
}
return OAuthProviderOfStandard
}
type SortField string
const (
@ -43,9 +24,3 @@ const (
)
type OAuthProvider string
const (
OAuthProviderOfLarkPlugin OAuthProvider = "lark_plugin"
OAuthProviderOfLark OAuthProvider = "lark"
OAuthProviderOfStandard OAuthProvider = "standard"
)

View File

@ -25,10 +25,10 @@ import (
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
@ -43,19 +43,19 @@ func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAut
type pluginOAuthAuthPO model.PluginOauthAuth
func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo {
secret := os.Getenv(utils.OAuthTokenSecretEnv)
secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
if secret == "" {
secret = utils.DefaultOAuthTokenSecret
secret = encrypt.DefaultOAuthTokenSecret
}
if p.RefreshToken != "" {
refreshToken, err := utils.DecryptByAES(p.RefreshToken, secret)
refreshToken, err := encrypt.DecryptByAES(p.RefreshToken, secret)
if err == nil {
p.RefreshToken = string(refreshToken)
}
}
if p.AccessToken != "" {
accessToken, err := utils.DecryptByAES(p.AccessToken, secret)
accessToken, err := encrypt.DecryptByAES(p.AccessToken, secret)
if err == nil {
p.AccessToken = string(accessToken)
}
@ -109,20 +109,20 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat
}
meta := info.Meta
secret := os.Getenv(utils.OAuthTokenSecretEnv)
secret := os.Getenv(encrypt.OAuthTokenSecretEnv)
if secret == "" {
secret = utils.DefaultOAuthTokenSecret
secret = encrypt.DefaultOAuthTokenSecret
}
var accessToken, refreshToken string
if info.AccessToken != "" {
accessToken, err = utils.EncryptByAES([]byte(info.AccessToken), secret)
accessToken, err = encrypt.EncryptByAES([]byte(info.AccessToken), secret)
if err != nil {
return err
}
}
if info.RefreshToken != "" {
refreshToken, err = utils.EncryptByAES([]byte(info.RefreshToken), secret)
refreshToken, err = encrypt.EncryptByAES([]byte(info.RefreshToken), secret)
if err != nil {
return err
}

View File

@ -29,8 +29,9 @@ import (
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
@ -438,12 +439,12 @@ func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) {
return "", fmt.Errorf("marshal state failed, err=%v", err)
}
secret := os.Getenv(utils.StateSecretEnv)
secret := os.Getenv(encrypt.StateSecretEnv)
if secret == "" {
secret = utils.DefaultStateSecret
secret = encrypt.DefaultStateSecret
}
encryptState, err := utils.EncryptByAES(stateStr, secret)
encryptState, err := encrypt.EncryptByAES(stateStr, secret)
if err != nil {
return "", fmt.Errorf("encrypt state failed, err=%v", err)
}
@ -464,7 +465,7 @@ func getStanderOAuthConfig(config *model.OAuthAuthorizationCodeConfig) *oauth2.C
TokenURL: config.AuthorizationURL,
AuthURL: config.ClientURL,
},
RedirectURL: fmt.Sprintf("https://%s/api/oauth/authorization_code", os.Getenv("SERVER_HOST")),
RedirectURL: fmt.Sprintf("%s/api/oauth/authorization_code", conf.GetServerHost()),
Scopes: strings.Split(config.Scope, " "),
}
}

View File

@ -5,7 +5,7 @@ export RUN_MODE="debug" # Currently supports debug mode. When set to debug, it h
export LISTEN_ADDR=":8888"
export LOG_LEVEL="debug"
export MAX_REQUEST_BODY_SIZE=1073741824
export SERVER_HOST="localhost${LISTEN_ADDR}"
export SERVER_HOST="http://localhost${LISTEN_ADDR}"
export MINIO_PROXY_ENDPOINT=""
export USE_SSL="0"
export SSL_CERT_FILE=""

View File

@ -2,8 +2,8 @@
export LISTEN_ADDR=":8888"
export LOG_LEVEL="debug"
export MAX_REQUEST_BODY_SIZE=1073741824
export SERVER_HOST="localhost${LISTEN_ADDR}"
export MINIO_PROXY_ENDPOINT=":8889"
export SERVER_HOST="http://localhost${LISTEN_ADDR}"
export MINIO_PROXY_ENDPOINT=""
export USE_SSL="0"
export SSL_CERT_FILE=""
export SSL_KEY_FILE=""