feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
497
backend/api/model/crossdomain/plugin/plugin_manifest.go
Normal file
497
backend/api/model/crossdomain/plugin/plugin_manifest.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
type PluginManifest struct {
|
||||
SchemaVersion string `json:"schema_version" yaml:"schema_version"`
|
||||
NameForModel string `json:"name_for_model" yaml:"name_for_model"`
|
||||
NameForHuman string `json:"name_for_human" yaml:"name_for_human"`
|
||||
DescriptionForModel string `json:"description_for_model" yaml:"description_for_model"`
|
||||
DescriptionForHuman string `json:"description_for_human" yaml:"description_for_human"`
|
||||
Auth *AuthV2 `json:"auth" yaml:"auth"`
|
||||
LogoURL string `json:"logo_url" yaml:"logo_url"`
|
||||
API APIDesc `json:"api" yaml:"api"`
|
||||
CommonParams map[HTTPParamLocation][]*api.CommonParamSchema `json:"common_params" yaml:"common_params"`
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) Copy() (*PluginManifest, error) {
|
||||
if mf == nil {
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
b, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf_ := &PluginManifest{}
|
||||
err = json.Unmarshal(b, mf_)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mf_, err
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
|
||||
if mf == nil || mf.Auth == nil {
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
mf_, err := mf.Copy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mf_.Auth.Payload == "" {
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), utils.AuthSecretKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf_.Auth.Payload = payload_
|
||||
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) Validate(skipAuthPayload bool) (err error) {
|
||||
if mf.SchemaVersion != "v1" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid schema version '%s'", mf.SchemaVersion))
|
||||
}
|
||||
if mf.NameForModel == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"name for model is required"))
|
||||
}
|
||||
if mf.NameForHuman == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"name for human is required"))
|
||||
}
|
||||
if mf.DescriptionForModel == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"description for model is required"))
|
||||
}
|
||||
if mf.DescriptionForHuman == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"description for human is required"))
|
||||
}
|
||||
if mf.API.Type != PluginTypeOfCloud {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid api type '%s'", mf.API.Type))
|
||||
}
|
||||
|
||||
err = mf.validateAuthInfo(skipAuthPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for loc := range mf.CommonParams {
|
||||
if loc != ParamInBody &&
|
||||
loc != ParamInHeader &&
|
||||
loc != ParamInQuery &&
|
||||
loc != ParamInPath {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid location '%s' in common params", loc))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateAuthInfo(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"auth is required"))
|
||||
}
|
||||
|
||||
if mf.Auth.Payload != "" {
|
||||
js := json.RawMessage{}
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &js)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if mf.Auth.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"auth type is required"))
|
||||
}
|
||||
|
||||
if mf.Auth.Type != AuthzTypeOfNone &&
|
||||
mf.Auth.Type != AuthzTypeOfOAuth &&
|
||||
mf.Auth.Type != AuthzTypeOfService {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid auth type '%s'", mf.Auth.Type))
|
||||
}
|
||||
|
||||
if mf.Auth.Type == AuthzTypeOfNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
if mf.Auth.SubType == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"sub-auth type is required"))
|
||||
}
|
||||
|
||||
switch mf.Auth.SubType {
|
||||
case AuthzSubTypeOfServiceAPIToken:
|
||||
err = mf.validateServiceToken(skipAuthPayload)
|
||||
//case AuthzSubTypeOfOAuthClientCredentials:
|
||||
// err = mf.validateClientCredentials()
|
||||
case AuthzSubTypeOfOAuthAuthorizationCode:
|
||||
err = mf.validateAuthCode(skipAuthPayload)
|
||||
default:
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid sub-auth type '%s'", mf.Auth.SubType))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateServiceToken(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth.AuthOfAPIToken == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfAPIToken)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if skipAuthPayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiToken := mf.Auth.AuthOfAPIToken
|
||||
|
||||
if apiToken.ServiceToken == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"service token is required"))
|
||||
}
|
||||
if apiToken.Key == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"key is required"))
|
||||
}
|
||||
|
||||
loc := HTTPParamLocation(strings.ToLower(string(apiToken.Location)))
|
||||
if loc != ParamInHeader && loc != ParamInQuery {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid location '%s'", apiToken.Location))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateClientCredentials() (err error) {
|
||||
if mf.Auth.AuthOfOAuthClientCredentials == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthClientCredentials)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
clientCredentials := mf.Auth.AuthOfOAuthClientCredentials
|
||||
|
||||
if clientCredentials.ClientID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client id is required"))
|
||||
}
|
||||
if clientCredentials.ClientSecret == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client secret is required"))
|
||||
}
|
||||
if clientCredentials.TokenURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url is required"))
|
||||
}
|
||||
|
||||
urlParse, err := url.Parse(clientCredentials.TokenURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid token url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateAuthCode(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth.AuthOfOAuthAuthorizationCode == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthAuthorizationCode)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if skipAuthPayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
authCode := mf.Auth.AuthOfOAuthAuthorizationCode
|
||||
|
||||
if authCode.ClientID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client id is required"))
|
||||
}
|
||||
if authCode.ClientSecret == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client secret is required"))
|
||||
}
|
||||
if authCode.AuthorizationContentType != MediaTypeJson {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"authorization content type must be 'application/json'"))
|
||||
}
|
||||
if authCode.AuthorizationURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url is required"))
|
||||
}
|
||||
if authCode.ClientURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client url is required"))
|
||||
}
|
||||
|
||||
urlParse, err := url.Parse(authCode.AuthorizationURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid authorization url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"authorization url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
urlParse, err = url.Parse(authCode.ClientURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid client url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
Type string `json:"type" validate:"required"`
|
||||
AuthorizationType string `json:"authorization_type,omitempty"`
|
||||
ClientURL string `json:"client_url,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
AuthorizationURL string `json:"authorization_url,omitempty"`
|
||||
AuthorizationContentType string `json:"authorization_content_type,omitempty"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
Location string `json:"location,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
ServiceToken string `json:"service_token,omitempty"`
|
||||
SubType string `json:"sub_type"`
|
||||
Payload string `json:"payload"`
|
||||
}
|
||||
|
||||
type AuthV2 struct {
|
||||
Type AuthzType `json:"type" yaml:"type"`
|
||||
SubType AuthzSubType `json:"sub_type" yaml:"sub_type"`
|
||||
Payload string `json:"payload" yaml:"payload"`
|
||||
// service
|
||||
AuthOfAPIToken *AuthOfAPIToken `json:"-"`
|
||||
|
||||
// oauth
|
||||
AuthOfOAuthAuthorizationCode *OAuthAuthorizationCodeConfig `json:"-"`
|
||||
AuthOfOAuthClientCredentials *OAuthClientCredentialsConfig `json:"-"`
|
||||
}
|
||||
|
||||
func (au *AuthV2) UnmarshalJSON(data []byte) error {
|
||||
auth := &Auth{} // 兼容老数据
|
||||
err := json.Unmarshal(data, auth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid plugin manifest json"))
|
||||
}
|
||||
|
||||
au.Type = AuthzType(auth.Type)
|
||||
au.SubType = AuthzSubType(auth.SubType)
|
||||
|
||||
if au.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"plugin auth type is required"))
|
||||
}
|
||||
|
||||
if auth.Payload != "" {
|
||||
payload_, err := utils.DecryptByAES(auth.Payload, utils.AuthSecretKey)
|
||||
if err == nil {
|
||||
auth.Payload = string(payload_)
|
||||
}
|
||||
}
|
||||
|
||||
switch au.Type {
|
||||
case AuthzTypeOfNone:
|
||||
case AuthzTypeOfOAuth:
|
||||
err = au.unmarshalOAuth(auth)
|
||||
case AuthzTypeOfService:
|
||||
err = au.unmarshalService(auth)
|
||||
default:
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin auth type '%s'", au.Type))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (au *AuthV2) unmarshalService(auth *Auth) (err error) {
|
||||
if au.SubType == "" && au.Payload == "" { // 兼容老数据
|
||||
au.SubType = AuthzSubTypeOfServiceAPIToken
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
|
||||
if au.SubType == AuthzSubTypeOfServiceAPIToken {
|
||||
if len(auth.ServiceToken) > 0 {
|
||||
au.AuthOfAPIToken = &AuthOfAPIToken{
|
||||
Location: HTTPParamLocation(strings.ToLower(auth.Location)),
|
||||
Key: auth.Key,
|
||||
ServiceToken: auth.ServiceToken,
|
||||
}
|
||||
} else {
|
||||
token := &AuthOfAPIToken{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), token)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfAPIToken = token
|
||||
}
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfAPIToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin sub-auth type '%s'", au.SubType))
|
||||
}
|
||||
|
||||
au.Payload = string(payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) {
|
||||
if au.SubType == "" { // 兼容老数据
|
||||
au.SubType = AuthzSubTypeOfOAuthAuthorizationCode
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
|
||||
if au.SubType == AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
if len(auth.ClientSecret) > 0 {
|
||||
au.AuthOfOAuthAuthorizationCode = &OAuthAuthorizationCodeConfig{
|
||||
ClientID: auth.ClientID,
|
||||
ClientSecret: auth.ClientSecret,
|
||||
ClientURL: auth.ClientURL,
|
||||
Scope: auth.Scope,
|
||||
AuthorizationURL: auth.AuthorizationURL,
|
||||
AuthorizationContentType: auth.AuthorizationContentType,
|
||||
}
|
||||
} else {
|
||||
oauth := &OAuthAuthorizationCodeConfig{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), oauth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfOAuthAuthorizationCode = oauth
|
||||
}
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfOAuthAuthorizationCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if au.SubType == AuthzSubTypeOfOAuthClientCredentials {
|
||||
oauth := &OAuthClientCredentialsConfig{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), oauth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfOAuthClientCredentials = oauth
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfOAuthClientCredentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin sub-auth type '%s'", au.SubType))
|
||||
}
|
||||
|
||||
au.Payload = string(payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthOfAPIToken struct {
|
||||
// Location is the location of the parameter.
|
||||
// It can be "header" or "query".
|
||||
Location HTTPParamLocation `json:"location"`
|
||||
// Key is the name of the parameter.
|
||||
Key string `json:"key"`
|
||||
// ServiceToken is the simple authorization information for the service.
|
||||
ServiceToken string `json:"service_token"`
|
||||
}
|
||||
|
||||
type OAuthAuthorizationCodeConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
// ClientURL is the URL of authorization endpoint.
|
||||
ClientURL string `json:"client_url"`
|
||||
// Scope is the scope of the authorization request.
|
||||
// If multiple scopes are requested, they must be separated by a space.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// AuthorizationURL is the URL of token exchange endpoint.
|
||||
AuthorizationURL string `json:"authorization_url"`
|
||||
// AuthorizationContentType is the content type of the authorization request, and it must be "application/json".
|
||||
AuthorizationContentType string `json:"authorization_content_type"`
|
||||
}
|
||||
|
||||
type OAuthClientCredentialsConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
TokenURL string `json:"token_url"`
|
||||
}
|
||||
|
||||
type APIDesc struct {
|
||||
Type PluginType `json:"type" validate:"required"`
|
||||
}
|
||||
Reference in New Issue
Block a user