diff --git a/backend/api/handler/coze/bot_open_api_service.go b/backend/api/handler/coze/bot_open_api_service.go index f470f9f7..a7cf32aa 100644 --- a/backend/api/handler/coze/bot_open_api_service.go +++ b/backend/api/handler/coze/bot_open_api_service.go @@ -20,12 +20,14 @@ package coze import ( "context" + "fmt" + + "github.com/coze-dev/coze-studio/backend/application/plugin" + "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" - "github.com/coze-dev/coze-studio/backend/application/plugin" - bot_open_api "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_open_api" ) @@ -41,7 +43,7 @@ func OauthAuthorizationCode(ctx context.Context, c *app.RequestContext) { } if req.Code == "" { - invalidParamRequestResponse(c, "code is required") + invalidParamRequestResponse(c, "authorization failed, code is required") return } if req.State == "" { @@ -49,11 +51,15 @@ func OauthAuthorizationCode(ctx context.Context, c *app.RequestContext) { return } - resp, err := plugin.PluginApplicationSVC.OauthAuthorizationCode(ctx, &req) + _, err = plugin.PluginApplicationSVC.OauthAuthorizationCode(ctx, &req) if err != nil { internalServerErrorResponse(ctx, c, err) return } - c.JSON(consts.StatusOK, resp) + redirectURL := fmt.Sprintf("%s/information/auth/success", conf.GetServerHost()) + c.Redirect(consts.StatusFound, []byte(redirectURL)) + c.Abort() + + return } diff --git a/backend/api/middleware/request_inspector.go b/backend/api/middleware/request_inspector.go index 0a670729..65955296 100644 --- a/backend/api/middleware/request_inspector.go +++ b/backend/api/middleware/request_inspector.go @@ -61,9 +61,13 @@ func isStaticFile(ctx *app.RequestContext) bool { return true } - if strings.HasPrefix(string(path), "/static/") || - strings.HasPrefix(string(path), "/explore/") || - strings.HasPrefix(string(path), "/space/") { + if strings.HasPrefix(path, "/static/") || + strings.HasPrefix(path, "/explore/") || + strings.HasPrefix(path, "/space/") { + return true + } + + if path == "/information/auth/success" { return true } diff --git a/backend/api/model/crossdomain/plugin/plugin_manifest.go b/backend/api/model/crossdomain/plugin/plugin_manifest.go index 057fa53e..616faeac 100644 --- a/backend/api/model/crossdomain/plugin/plugin_manifest.go +++ b/backend/api/model/crossdomain/plugin/plugin_manifest.go @@ -23,7 +23,7 @@ import ( "strings" api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" - "github.com/coze-dev/coze-studio/backend/domain/plugin/utils" + "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/types/errno" @@ -75,12 +75,12 @@ func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) { return mf_, nil } - secret := os.Getenv(utils.AuthSecretEnv) + secret := os.Getenv(encrypt.AuthSecretEnv) if secret == "" { - secret = utils.DefaultAuthSecret + secret = encrypt.DefaultAuthSecret } - payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), secret) + payload_, err := encrypt.EncryptByAES([]byte(mf_.Auth.Payload), secret) if err != nil { return nil, err } @@ -363,12 +363,12 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error { } if auth.Payload != "" { - secret := os.Getenv(utils.AuthSecretEnv) + secret := os.Getenv(encrypt.AuthSecretEnv) if secret == "" { - secret = utils.DefaultAuthSecret + secret = encrypt.DefaultAuthSecret } - payload_, err := utils.DecryptByAES(auth.Payload, secret) + payload_, err := encrypt.DecryptByAES(auth.Payload, secret) if err == nil { auth.Payload = string(payload_) } diff --git a/backend/application/plugin/plugin.go b/backend/application/plugin/plugin.go index 3ea77217..da8e3120 100644 --- a/backend/application/plugin/plugin.go +++ b/backend/application/plugin/plugin.go @@ -45,10 +45,10 @@ import ( "github.com/coze-dev/coze-studio/backend/application/base/pluginutil" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch" pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/repository" "github.com/coze-dev/coze-studio/backend/domain/plugin/service" - "github.com/coze-dev/coze-studio/backend/domain/plugin/utils" searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity" search "github.com/coze-dev/coze-studio/backend/domain/search/service" user "github.com/coze-dev/coze-studio/backend/domain/user/service" @@ -1704,12 +1704,12 @@ func (p *PluginApplicationService) OauthAuthorizationCode(ctx context.Context, r return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) } - secret := os.Getenv(utils.StateSecretEnv) + secret := os.Getenv(encrypt.StateSecretEnv) if secret == "" { - secret = utils.DefaultStateSecret + secret = encrypt.DefaultStateSecret } - stateBytes, err := utils.DecryptByAES(stateStr, secret) + stateBytes, err := encrypt.DecryptByAES(stateStr, secret) if err != nil { return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state")) } diff --git a/backend/domain/plugin/conf/host.go b/backend/domain/plugin/conf/host.go new file mode 100644 index 00000000..add4dc2c --- /dev/null +++ b/backend/domain/plugin/conf/host.go @@ -0,0 +1,30 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package conf + +import ( + "os" + "strings" +) + +func GetServerHost() string { + host := os.Getenv("SERVER_HOST") + if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") { + return host + } + return "https://" + host +} diff --git a/backend/domain/plugin/utils/aes.go b/backend/domain/plugin/encrypt/aes.go similarity index 99% rename from backend/domain/plugin/utils/aes.go rename to backend/domain/plugin/encrypt/aes.go index 1e7cd898..ddb7122c 100644 --- a/backend/domain/plugin/utils/aes.go +++ b/backend/domain/plugin/encrypt/aes.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package utils +package encrypt import ( "bytes" diff --git a/backend/domain/plugin/utils/aes_test.go b/backend/domain/plugin/encrypt/aes_test.go similarity index 98% rename from backend/domain/plugin/utils/aes_test.go rename to backend/domain/plugin/encrypt/aes_test.go index bc2c2395..abcb5add 100644 --- a/backend/domain/plugin/utils/aes_test.go +++ b/backend/domain/plugin/encrypt/aes_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package utils +package encrypt import ( "testing" diff --git a/backend/domain/plugin/entity/consts.go b/backend/domain/plugin/entity/consts.go index ee523ebb..c7bb9215 100644 --- a/backend/domain/plugin/entity/consts.go +++ b/backend/domain/plugin/entity/consts.go @@ -16,25 +16,6 @@ package entity -import ( - "strings" -) - -const ( - larkPluginOAuthHostName = "open.larkoffice.com" - larkOAuthHostName = "open.feishu.cn" -) - -func GetOAuthProvider(tokenURL string) OAuthProvider { - if strings.Contains(tokenURL, larkPluginOAuthHostName) { - return OAuthProviderOfLarkPlugin - } - if strings.Contains(tokenURL, larkOAuthHostName) { - return OAuthProviderOfLark - } - return OAuthProviderOfStandard -} - type SortField string const ( @@ -43,9 +24,3 @@ const ( ) type OAuthProvider string - -const ( - OAuthProviderOfLarkPlugin OAuthProvider = "lark_plugin" - OAuthProviderOfLark OAuthProvider = "lark" - OAuthProviderOfStandard OAuthProvider = "standard" -) diff --git a/backend/domain/plugin/internal/dal/plugin_oauth_auth.go b/backend/domain/plugin/internal/dal/plugin_oauth_auth.go index 4b3ca003..72cc78cc 100644 --- a/backend/domain/plugin/internal/dal/plugin_oauth_auth.go +++ b/backend/domain/plugin/internal/dal/plugin_oauth_auth.go @@ -25,10 +25,10 @@ import ( "gorm.io/gorm" + "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query" - "github.com/coze-dev/coze-studio/backend/domain/plugin/utils" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" ) @@ -43,19 +43,19 @@ func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAut type pluginOAuthAuthPO model.PluginOauthAuth func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo { - secret := os.Getenv(utils.OAuthTokenSecretEnv) + secret := os.Getenv(encrypt.OAuthTokenSecretEnv) if secret == "" { - secret = utils.DefaultOAuthTokenSecret + secret = encrypt.DefaultOAuthTokenSecret } if p.RefreshToken != "" { - refreshToken, err := utils.DecryptByAES(p.RefreshToken, secret) + refreshToken, err := encrypt.DecryptByAES(p.RefreshToken, secret) if err == nil { p.RefreshToken = string(refreshToken) } } if p.AccessToken != "" { - accessToken, err := utils.DecryptByAES(p.AccessToken, secret) + accessToken, err := encrypt.DecryptByAES(p.AccessToken, secret) if err == nil { p.AccessToken = string(accessToken) } @@ -109,20 +109,20 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat } meta := info.Meta - secret := os.Getenv(utils.OAuthTokenSecretEnv) + secret := os.Getenv(encrypt.OAuthTokenSecretEnv) if secret == "" { - secret = utils.DefaultOAuthTokenSecret + secret = encrypt.DefaultOAuthTokenSecret } var accessToken, refreshToken string if info.AccessToken != "" { - accessToken, err = utils.EncryptByAES([]byte(info.AccessToken), secret) + accessToken, err = encrypt.EncryptByAES([]byte(info.AccessToken), secret) if err != nil { return err } } if info.RefreshToken != "" { - refreshToken, err = utils.EncryptByAES([]byte(info.RefreshToken), secret) + refreshToken, err = encrypt.EncryptByAES([]byte(info.RefreshToken), secret) if err != nil { return err } diff --git a/backend/domain/plugin/service/plugin_oauth.go b/backend/domain/plugin/service/plugin_oauth.go index 0be856bc..6e756a0a 100644 --- a/backend/domain/plugin/service/plugin_oauth.go +++ b/backend/domain/plugin/service/plugin_oauth.go @@ -29,8 +29,9 @@ import ( model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" + "github.com/coze-dev/coze-studio/backend/domain/plugin/conf" + "github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" - "github.com/coze-dev/coze-studio/backend/domain/plugin/utils" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -438,12 +439,12 @@ func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) { return "", fmt.Errorf("marshal state failed, err=%v", err) } - secret := os.Getenv(utils.StateSecretEnv) + secret := os.Getenv(encrypt.StateSecretEnv) if secret == "" { - secret = utils.DefaultStateSecret + secret = encrypt.DefaultStateSecret } - encryptState, err := utils.EncryptByAES(stateStr, secret) + encryptState, err := encrypt.EncryptByAES(stateStr, secret) if err != nil { return "", fmt.Errorf("encrypt state failed, err=%v", err) } @@ -464,7 +465,7 @@ func getStanderOAuthConfig(config *model.OAuthAuthorizationCodeConfig) *oauth2.C TokenURL: config.AuthorizationURL, AuthURL: config.ClientURL, }, - RedirectURL: fmt.Sprintf("https://%s/api/oauth/authorization_code", os.Getenv("SERVER_HOST")), + RedirectURL: fmt.Sprintf("%s/api/oauth/authorization_code", conf.GetServerHost()), Scopes: strings.Split(config.Scope, " "), } } diff --git a/docker/.env.debug.example b/docker/.env.debug.example index cfe0af9f..aff24f6b 100644 --- a/docker/.env.debug.example +++ b/docker/.env.debug.example @@ -5,7 +5,7 @@ export RUN_MODE="debug" # Currently supports debug mode. When set to debug, it h export LISTEN_ADDR=":8888" export LOG_LEVEL="debug" export MAX_REQUEST_BODY_SIZE=1073741824 -export SERVER_HOST="localhost${LISTEN_ADDR}" +export SERVER_HOST="http://localhost${LISTEN_ADDR}" export MINIO_PROXY_ENDPOINT="" export USE_SSL="0" export SSL_CERT_FILE="" diff --git a/docker/.env.example b/docker/.env.example index 32b66637..d3c23206 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -2,8 +2,8 @@ export LISTEN_ADDR=":8888" export LOG_LEVEL="debug" export MAX_REQUEST_BODY_SIZE=1073741824 -export SERVER_HOST="localhost${LISTEN_ADDR}" -export MINIO_PROXY_ENDPOINT=":8889" +export SERVER_HOST="http://localhost${LISTEN_ADDR}" +export MINIO_PROXY_ENDPOINT="" export USE_SSL="0" export SSL_CERT_FILE="" export SSL_KEY_FILE=""