feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
110
backend/api/model/crossdomain/plugin/consts.go
Normal file
110
backend/api/model/crossdomain/plugin/consts.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package plugin
|
||||
|
||||
import "github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
type PluginType string
|
||||
|
||||
const (
|
||||
PluginTypeOfCloud PluginType = "openapi"
|
||||
)
|
||||
|
||||
type AuthzType string
|
||||
|
||||
const (
|
||||
AuthzTypeOfNone AuthzType = "none"
|
||||
AuthzTypeOfService AuthzType = "service_http"
|
||||
AuthzTypeOfOAuth AuthzType = "oauth"
|
||||
)
|
||||
|
||||
type AuthzSubType string
|
||||
|
||||
const (
|
||||
AuthzSubTypeOfServiceAPIToken AuthzSubType = "token/api_key"
|
||||
AuthzSubTypeOfOAuthAuthorizationCode AuthzSubType = "authorization_code"
|
||||
AuthzSubTypeOfOAuthClientCredentials AuthzSubType = "client_credentials"
|
||||
)
|
||||
|
||||
type HTTPParamLocation string
|
||||
|
||||
const (
|
||||
ParamInHeader HTTPParamLocation = openapi3.ParameterInHeader
|
||||
ParamInPath HTTPParamLocation = openapi3.ParameterInPath
|
||||
ParamInQuery HTTPParamLocation = openapi3.ParameterInQuery
|
||||
ParamInBody HTTPParamLocation = "body"
|
||||
)
|
||||
|
||||
type ActivatedStatus int32
|
||||
|
||||
const (
|
||||
ActivateTool ActivatedStatus = 0
|
||||
DeactivateTool ActivatedStatus = 1
|
||||
)
|
||||
|
||||
type ProjectType int8
|
||||
|
||||
const (
|
||||
ProjectTypeOfAgent ProjectType = 1
|
||||
ProjectTypeOfAPP ProjectType = 2
|
||||
)
|
||||
|
||||
type ExecuteScene string
|
||||
|
||||
const (
|
||||
ExecSceneOfOnlineAgent ExecuteScene = "online_agent"
|
||||
ExecSceneOfDraftAgent ExecuteScene = "draft_agent"
|
||||
ExecSceneOfWorkflow ExecuteScene = "workflow"
|
||||
ExecSceneOfToolDebug ExecuteScene = "tool_debug"
|
||||
)
|
||||
|
||||
type InvalidResponseProcessStrategy int8
|
||||
|
||||
const (
|
||||
InvalidResponseProcessStrategyOfReturnRaw InvalidResponseProcessStrategy = 0 // If the value of a field is invalid, the raw response value of the field is returned.
|
||||
InvalidResponseProcessStrategyOfReturnDefault InvalidResponseProcessStrategy = 1 // If the value of a field is invalid, the default value of the field is returned.
|
||||
)
|
||||
|
||||
const (
|
||||
APISchemaExtendAssistType = "x-assist-type"
|
||||
APISchemaExtendGlobalDisable = "x-global-disable"
|
||||
APISchemaExtendLocalDisable = "x-local-disable"
|
||||
APISchemaExtendVariableRef = "x-variable-ref"
|
||||
APISchemaExtendAuthMode = "x-auth-mode"
|
||||
)
|
||||
|
||||
type ToolAuthMode string
|
||||
|
||||
const (
|
||||
ToolAuthModeOfRequired ToolAuthMode = "required"
|
||||
ToolAuthModeOfSupported ToolAuthMode = "supported"
|
||||
ToolAuthModeOfDisabled ToolAuthMode = "disabled"
|
||||
)
|
||||
|
||||
type APIFileAssistType string
|
||||
|
||||
const (
|
||||
AssistTypeFile APIFileAssistType = "file"
|
||||
AssistTypeImage APIFileAssistType = "image"
|
||||
AssistTypeDoc APIFileAssistType = "doc"
|
||||
AssistTypePPT APIFileAssistType = "ppt"
|
||||
AssistTypeCode APIFileAssistType = "code"
|
||||
AssistTypeExcel APIFileAssistType = "excel"
|
||||
AssistTypeZIP APIFileAssistType = "zip"
|
||||
AssistTypeVideo APIFileAssistType = "video"
|
||||
AssistTypeAudio APIFileAssistType = "audio"
|
||||
AssistTypeTXT APIFileAssistType = "txt"
|
||||
)
|
||||
|
||||
type CopyScene string
|
||||
|
||||
const (
|
||||
CopySceneOfToAPP CopyScene = "to_app"
|
||||
CopySceneOfToLibrary CopyScene = "to_library"
|
||||
CopySceneOfDuplicate CopyScene = "duplicate"
|
||||
CopySceneOfAPPDuplicate CopyScene = "app_duplicate"
|
||||
)
|
||||
|
||||
type InterruptEventType string
|
||||
|
||||
const (
|
||||
InterruptEventTypeOfToolNeedOAuth InterruptEventType = "tool_need_oauth"
|
||||
)
|
||||
270
backend/api/model/crossdomain/plugin/convert.go
Normal file
270
backend/api/model/crossdomain/plugin/convert.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
)
|
||||
|
||||
var httpParamLocations = map[common.ParameterLocation]HTTPParamLocation{
|
||||
common.ParameterLocation_Path: ParamInPath,
|
||||
common.ParameterLocation_Query: ParamInQuery,
|
||||
common.ParameterLocation_Body: ParamInBody,
|
||||
common.ParameterLocation_Header: ParamInHeader,
|
||||
}
|
||||
|
||||
func ToHTTPParamLocation(loc common.ParameterLocation) (HTTPParamLocation, bool) {
|
||||
_loc, ok := httpParamLocations[loc]
|
||||
return _loc, ok
|
||||
}
|
||||
|
||||
var thriftHTTPParamLocations = func() map[HTTPParamLocation]common.ParameterLocation {
|
||||
locations := make(map[HTTPParamLocation]common.ParameterLocation, len(httpParamLocations))
|
||||
for k, v := range httpParamLocations {
|
||||
locations[v] = k
|
||||
}
|
||||
return locations
|
||||
}()
|
||||
|
||||
func ToThriftHTTPParamLocation(loc HTTPParamLocation) (common.ParameterLocation, bool) {
|
||||
_loc, ok := thriftHTTPParamLocations[loc]
|
||||
return _loc, ok
|
||||
}
|
||||
|
||||
var openapiTypes = map[common.ParameterType]string{
|
||||
common.ParameterType_String: openapi3.TypeString,
|
||||
common.ParameterType_Integer: openapi3.TypeInteger,
|
||||
common.ParameterType_Number: openapi3.TypeNumber,
|
||||
common.ParameterType_Object: openapi3.TypeObject,
|
||||
common.ParameterType_Array: openapi3.TypeArray,
|
||||
common.ParameterType_Bool: openapi3.TypeBoolean,
|
||||
}
|
||||
|
||||
func ToOpenapiParamType(typ common.ParameterType) (string, bool) {
|
||||
_typ, ok := openapiTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
var thriftParameterTypes = func() map[string]common.ParameterType {
|
||||
types := make(map[string]common.ParameterType, len(openapiTypes))
|
||||
for k, v := range openapiTypes {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftParamType(typ string) (common.ParameterType, bool) {
|
||||
_typ, ok := thriftParameterTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
var apiAssistTypes = map[common.AssistParameterType]APIFileAssistType{
|
||||
common.AssistParameterType_DEFAULT: AssistTypeFile,
|
||||
common.AssistParameterType_IMAGE: AssistTypeImage,
|
||||
common.AssistParameterType_DOC: AssistTypeDoc,
|
||||
common.AssistParameterType_PPT: AssistTypePPT,
|
||||
common.AssistParameterType_CODE: AssistTypeCode,
|
||||
common.AssistParameterType_EXCEL: AssistTypeExcel,
|
||||
common.AssistParameterType_ZIP: AssistTypeZIP,
|
||||
common.AssistParameterType_VIDEO: AssistTypeVideo,
|
||||
common.AssistParameterType_AUDIO: AssistTypeAudio,
|
||||
common.AssistParameterType_TXT: AssistTypeTXT,
|
||||
}
|
||||
|
||||
func ToAPIAssistType(typ common.AssistParameterType) (APIFileAssistType, bool) {
|
||||
_typ, ok := apiAssistTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
var thriftAPIAssistTypes = func() map[APIFileAssistType]common.AssistParameterType {
|
||||
types := make(map[APIFileAssistType]common.AssistParameterType, len(apiAssistTypes))
|
||||
for k, v := range apiAssistTypes {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftAPIAssistType(typ APIFileAssistType) (common.AssistParameterType, bool) {
|
||||
_typ, ok := thriftAPIAssistTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
func IsValidAPIAssistType(typ APIFileAssistType) bool {
|
||||
_, ok := thriftAPIAssistTypes[typ]
|
||||
return ok
|
||||
}
|
||||
|
||||
var httpMethods = map[common.APIMethod]string{
|
||||
common.APIMethod_GET: http.MethodGet,
|
||||
common.APIMethod_POST: http.MethodPost,
|
||||
common.APIMethod_PUT: http.MethodPut,
|
||||
common.APIMethod_DELETE: http.MethodDelete,
|
||||
common.APIMethod_PATCH: http.MethodPatch,
|
||||
}
|
||||
|
||||
var thriftAPIMethods = func() map[string]common.APIMethod {
|
||||
methods := make(map[string]common.APIMethod, len(httpMethods))
|
||||
for k, v := range httpMethods {
|
||||
methods[v] = k
|
||||
}
|
||||
return methods
|
||||
}()
|
||||
|
||||
func ToThriftAPIMethod(method string) (common.APIMethod, bool) {
|
||||
_method, ok := thriftAPIMethods[method]
|
||||
return _method, ok
|
||||
}
|
||||
|
||||
func ToHTTPMethod(method common.APIMethod) (string, bool) {
|
||||
_method, ok := httpMethods[method]
|
||||
return _method, ok
|
||||
}
|
||||
|
||||
var assistTypeToFormat = map[APIFileAssistType]string{
|
||||
AssistTypeFile: "file_url",
|
||||
AssistTypeImage: "image_url",
|
||||
AssistTypeDoc: "doc_url",
|
||||
AssistTypePPT: "ppt_url",
|
||||
AssistTypeCode: "code_url",
|
||||
AssistTypeExcel: "excel_url",
|
||||
AssistTypeZIP: "zip_url",
|
||||
AssistTypeVideo: "video_url",
|
||||
AssistTypeAudio: "audio_url",
|
||||
AssistTypeTXT: "txt_url",
|
||||
}
|
||||
|
||||
func AssistTypeToFormat(typ APIFileAssistType) (string, bool) {
|
||||
format, ok := assistTypeToFormat[typ]
|
||||
return format, ok
|
||||
}
|
||||
|
||||
var formatToAssistType = func() map[string]APIFileAssistType {
|
||||
types := make(map[string]APIFileAssistType, len(assistTypeToFormat))
|
||||
for k, v := range assistTypeToFormat {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func FormatToAssistType(format string) (APIFileAssistType, bool) {
|
||||
typ, ok := formatToAssistType[format]
|
||||
return typ, ok
|
||||
}
|
||||
|
||||
var assistTypeToThriftFormat = map[APIFileAssistType]common.PluginParamTypeFormat{
|
||||
AssistTypeFile: common.PluginParamTypeFormat_FileUrl,
|
||||
AssistTypeImage: common.PluginParamTypeFormat_ImageUrl,
|
||||
AssistTypeDoc: common.PluginParamTypeFormat_DocUrl,
|
||||
AssistTypePPT: common.PluginParamTypeFormat_PptUrl,
|
||||
AssistTypeCode: common.PluginParamTypeFormat_CodeUrl,
|
||||
AssistTypeExcel: common.PluginParamTypeFormat_ExcelUrl,
|
||||
AssistTypeZIP: common.PluginParamTypeFormat_ZipUrl,
|
||||
AssistTypeVideo: common.PluginParamTypeFormat_VideoUrl,
|
||||
AssistTypeAudio: common.PluginParamTypeFormat_AudioUrl,
|
||||
AssistTypeTXT: common.PluginParamTypeFormat_TxtUrl,
|
||||
}
|
||||
|
||||
func AssistTypeToThriftFormat(typ APIFileAssistType) (common.PluginParamTypeFormat, bool) {
|
||||
format, ok := assistTypeToThriftFormat[typ]
|
||||
return format, ok
|
||||
}
|
||||
|
||||
var authTypes = map[common.AuthorizationType]AuthzType{
|
||||
common.AuthorizationType_None: AuthzTypeOfNone,
|
||||
common.AuthorizationType_Service: AuthzTypeOfService,
|
||||
common.AuthorizationType_OAuth: AuthzTypeOfOAuth,
|
||||
common.AuthorizationType_Standard: AuthzTypeOfOAuth, // deprecated, the same as OAuth
|
||||
}
|
||||
|
||||
func ToAuthType(typ common.AuthorizationType) (AuthzType, bool) {
|
||||
_type, ok := authTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var thriftAuthTypes = func() map[AuthzType]common.AuthorizationType {
|
||||
types := make(map[AuthzType]common.AuthorizationType, len(authTypes))
|
||||
for k, v := range authTypes {
|
||||
if v == AuthzTypeOfOAuth {
|
||||
types[v] = common.AuthorizationType_OAuth
|
||||
} else {
|
||||
types[v] = k
|
||||
}
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftAuthType(typ AuthzType) (common.AuthorizationType, bool) {
|
||||
_type, ok := thriftAuthTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var subAuthTypes = map[int32]AuthzSubType{
|
||||
int32(common.ServiceAuthSubType_ApiKey): AuthzSubTypeOfServiceAPIToken,
|
||||
int32(common.ServiceAuthSubType_OAuthAuthorizationCode): AuthzSubTypeOfOAuthAuthorizationCode,
|
||||
}
|
||||
|
||||
func ToAuthSubType(typ int32) (AuthzSubType, bool) {
|
||||
_type, ok := subAuthTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var thriftSubAuthTypes = func() map[AuthzSubType]int32 {
|
||||
types := make(map[AuthzSubType]int32, len(subAuthTypes))
|
||||
for k, v := range subAuthTypes {
|
||||
types[v] = int32(k)
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftAuthSubType(typ AuthzSubType) (int32, bool) {
|
||||
_type, ok := thriftSubAuthTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var pluginTypes = map[common.PluginType]PluginType{
|
||||
common.PluginType_PLUGIN: PluginTypeOfCloud,
|
||||
}
|
||||
|
||||
func ToPluginType(typ common.PluginType) (PluginType, bool) {
|
||||
_type, ok := pluginTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var thriftPluginTypes = func() map[PluginType]common.PluginType {
|
||||
types := make(map[PluginType]common.PluginType, len(pluginTypes))
|
||||
for k, v := range pluginTypes {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftPluginType(typ PluginType) (common.PluginType, bool) {
|
||||
_type, ok := thriftPluginTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var apiAuthModes = map[common.PluginToolAuthType]ToolAuthMode{
|
||||
common.PluginToolAuthType_Required: ToolAuthModeOfRequired,
|
||||
common.PluginToolAuthType_Supported: ToolAuthModeOfSupported,
|
||||
common.PluginToolAuthType_Disable: ToolAuthModeOfDisabled,
|
||||
}
|
||||
|
||||
func ToAPIAuthMode(mode common.PluginToolAuthType) (ToolAuthMode, bool) {
|
||||
_mode, ok := apiAuthModes[mode]
|
||||
return _mode, ok
|
||||
}
|
||||
|
||||
var thriftAPIAuthModes = func() map[ToolAuthMode]common.PluginToolAuthType {
|
||||
modes := make(map[ToolAuthMode]common.PluginToolAuthType, len(apiAuthModes))
|
||||
for k, v := range apiAuthModes {
|
||||
modes[v] = k
|
||||
}
|
||||
return modes
|
||||
}()
|
||||
|
||||
func ToThriftAPIAuthMode(mode ToolAuthMode) (common.PluginToolAuthType, bool) {
|
||||
_mode, ok := thriftAPIAuthModes[mode]
|
||||
return _mode, ok
|
||||
}
|
||||
429
backend/api/model/crossdomain/plugin/openai.go
Normal file
429
backend/api/model/crossdomain/plugin/openai.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
type Openapi3T openapi3.T
|
||||
|
||||
func (ot Openapi3T) Validate(ctx context.Context) (err error) {
|
||||
err = ptr.Of(openapi3.T(ot)).Validate(ctx)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error()))
|
||||
}
|
||||
|
||||
if ot.Info == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"info is required"))
|
||||
}
|
||||
if ot.Info.Title == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"the title of info is required"))
|
||||
}
|
||||
if ot.Info.Description == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"the description of info is required"))
|
||||
}
|
||||
|
||||
if len(ot.Servers) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"server is required and only one server is allowed"))
|
||||
}
|
||||
|
||||
serverURL := ot.Servers[0].URL
|
||||
urlSchema, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid server url '%s'", serverURL))
|
||||
}
|
||||
if urlSchema.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"server url must start with 'https://'"))
|
||||
}
|
||||
if urlSchema.Host == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid server url '%s'", serverURL))
|
||||
}
|
||||
if len(serverURL) > 512 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"server url '%s' is too long", serverURL))
|
||||
}
|
||||
|
||||
for _, pathItem := range ot.Paths {
|
||||
for _, op := range pathItem.Operations() {
|
||||
err = NewOpenapi3Operation(op).Validate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewOpenapi3Operation(op *openapi3.Operation) *Openapi3Operation {
|
||||
return &Openapi3Operation{
|
||||
Operation: op,
|
||||
}
|
||||
}
|
||||
|
||||
type Openapi3Operation struct {
|
||||
*openapi3.Operation
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) MarshalJSON() ([]byte, error) {
|
||||
return op.Operation.MarshalJSON()
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) UnmarshalJSON(data []byte) error {
|
||||
op.Operation = &openapi3.Operation{}
|
||||
return op.Operation.UnmarshalJSON(data)
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) Validate(ctx context.Context) (err error) {
|
||||
err = op.Operation.Validate(ctx)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey, "operation is invalid, err=%s", err))
|
||||
}
|
||||
|
||||
if op.OperationID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "operationID is required"))
|
||||
}
|
||||
if op.Summary == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "summary is required"))
|
||||
}
|
||||
|
||||
err = validateOpenapi3RequestBody(op.RequestBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateOpenapi3Parameters(op.Parameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateOpenapi3Responses(op.Responses)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) ToEinoSchemaParameterInfo(ctx context.Context) (map[string]*schema.ParameterInfo, error) {
|
||||
convertType := func(openapiType string) schema.DataType {
|
||||
switch openapiType {
|
||||
case openapi3.TypeString:
|
||||
return schema.String
|
||||
case openapi3.TypeInteger:
|
||||
return schema.Integer
|
||||
case openapi3.TypeObject:
|
||||
return schema.Object
|
||||
case openapi3.TypeArray:
|
||||
return schema.Array
|
||||
case openapi3.TypeBoolean:
|
||||
return schema.Boolean
|
||||
case openapi3.TypeNumber:
|
||||
return schema.Number
|
||||
default:
|
||||
return schema.Null
|
||||
}
|
||||
}
|
||||
|
||||
var convertReqBody func(sc *openapi3.Schema, isRequired bool) (*schema.ParameterInfo, error)
|
||||
convertReqBody = func(sc *openapi3.Schema, isRequired bool) (*schema.ParameterInfo, error) {
|
||||
if disabledParam(sc) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
paramInfo := &schema.ParameterInfo{
|
||||
Type: convertType(sc.Type),
|
||||
Desc: sc.Description,
|
||||
Required: isRequired,
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
subParams := make(map[string]*schema.ParameterInfo, len(sc.Properties))
|
||||
for paramName, prop := range sc.Properties {
|
||||
subParam, err := convertReqBody(prop.Value, required[paramName])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subParams[paramName] = subParam
|
||||
}
|
||||
|
||||
paramInfo.SubParams = subParams
|
||||
|
||||
case openapi3.TypeArray:
|
||||
ele, err := convertReqBody(sc.Items.Value, isRequired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramInfo.ElemInfo = ele
|
||||
|
||||
case openapi3.TypeString, openapi3.TypeInteger, openapi3.TypeBoolean, openapi3.TypeNumber:
|
||||
return paramInfo, nil
|
||||
|
||||
default:
|
||||
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KVf(errno.PluginMsgKey,
|
||||
"unsupported json type '%s'", sc.Type))
|
||||
}
|
||||
|
||||
return paramInfo, nil
|
||||
}
|
||||
|
||||
result := make(map[string]*schema.ParameterInfo)
|
||||
|
||||
for _, prop := range op.Parameters {
|
||||
paramVal := prop.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
if schemaVal.Type == openapi3.TypeObject || schemaVal.Type == openapi3.TypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if disabledParam(prop.Value.Schema.Value) {
|
||||
continue
|
||||
}
|
||||
|
||||
paramInfo := &schema.ParameterInfo{
|
||||
Type: convertType(schemaVal.Type),
|
||||
Desc: paramVal.Description,
|
||||
Required: paramVal.Required,
|
||||
}
|
||||
|
||||
if _, ok := result[paramVal.Name]; ok {
|
||||
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramVal.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
result[paramVal.Name] = paramInfo
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for paramName, prop := range schemaVal.Properties {
|
||||
paramInfo, err := convertReqBody(prop.Value, required[paramName])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, ok := result[paramName]; ok {
|
||||
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramName)
|
||||
continue
|
||||
}
|
||||
|
||||
result[paramName] = paramInfo
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func validateOpenapi3RequestBody(bodyRef *openapi3.RequestBodyRef) (err error) {
|
||||
if bodyRef == nil {
|
||||
return nil
|
||||
}
|
||||
if bodyRef.Value == nil || len(bodyRef.Value.Content) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body is required"))
|
||||
}
|
||||
|
||||
body := bodyRef.Value
|
||||
if len(body.Content) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body only supports one media type"))
|
||||
}
|
||||
|
||||
var mType *openapi3.MediaType
|
||||
for _, ct := range mediaTypeArray {
|
||||
var ok bool
|
||||
mType, ok = body.Content[ct]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if mType == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid media type, request body only the following types: [%s]", strings.Join(mediaTypeArray, ", ")))
|
||||
}
|
||||
|
||||
if mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body schema is required"))
|
||||
}
|
||||
|
||||
sc := mType.Schema.Value
|
||||
if sc.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body only supports 'object' type"))
|
||||
}
|
||||
if sc.Type != openapi3.TypeObject {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body only supports 'object' type"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpenapi3Parameters(params openapi3.Parameters) (err error) {
|
||||
if len(params) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, param := range params {
|
||||
if param == nil || param.Value == nil || param.Value.Schema == nil || param.Value.Schema.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"parameter schema is required"))
|
||||
}
|
||||
|
||||
paramVal := param.Value
|
||||
|
||||
if paramVal.In == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"parameter location is required"))
|
||||
}
|
||||
if paramVal.In == string(ParamInBody) {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the location of parameter '%s' cannot be 'body'", paramVal.Name))
|
||||
}
|
||||
|
||||
paramSchema := paramVal.Schema.Value
|
||||
if paramSchema.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type of parameter '%s' is required", paramVal.Name))
|
||||
}
|
||||
if paramSchema.Type == openapi3.TypeObject {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type of parameter '%s' cannot be 'object'", paramVal.Name))
|
||||
}
|
||||
if paramVal.In == openapi3.ParameterInPath && paramSchema.Type == openapi3.TypeArray {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type of parameter '%s' cannot be 'array'", paramVal.Name))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MIME Type
|
||||
const (
|
||||
MediaTypeJson = "application/json"
|
||||
MediaTypeProblemJson = "application/problem+json"
|
||||
MediaTypeFormURLEncoded = "application/x-www-form-urlencoded"
|
||||
MediaTypeXYaml = "application/x-yaml"
|
||||
MediaTypeYaml = "application/yaml"
|
||||
)
|
||||
|
||||
var mediaTypeArray = []string{
|
||||
MediaTypeJson,
|
||||
MediaTypeProblemJson,
|
||||
MediaTypeFormURLEncoded,
|
||||
MediaTypeXYaml,
|
||||
MediaTypeYaml,
|
||||
}
|
||||
|
||||
func validateOpenapi3Responses(responses openapi3.Responses) (err error) {
|
||||
if len(responses) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// default status 不处理
|
||||
// 只处理 '200' status
|
||||
if len(responses) != 1 {
|
||||
if len(responses) != 2 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports '200' status"))
|
||||
} else if _, ok := responses["default"]; !ok {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports '200' status"))
|
||||
}
|
||||
}
|
||||
|
||||
resp, ok := responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok || resp == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports '200' status"))
|
||||
}
|
||||
if resp.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response schema is required"))
|
||||
}
|
||||
if len(resp.Value.Content) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports 'application/json' media type"))
|
||||
}
|
||||
mType, ok := resp.Value.Content[MediaTypeJson]
|
||||
if !ok || mType == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports 'application/json' media type"))
|
||||
|
||||
}
|
||||
if mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"the media type schema of response is required"))
|
||||
}
|
||||
|
||||
sc := mType.Schema.Value
|
||||
if sc.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response body only supports 'object' type"))
|
||||
}
|
||||
if sc.Type != openapi3.TypeObject {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response body only supports 'object' type"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func disabledParam(schemaVal *openapi3.Schema) bool {
|
||||
if len(schemaVal.Extensions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
globalDisable, localDisable := false, false
|
||||
if v, ok := schemaVal.Extensions[APISchemaExtendLocalDisable]; ok {
|
||||
localDisable = v.(bool)
|
||||
}
|
||||
|
||||
if v, ok := schemaVal.Extensions[APISchemaExtendGlobalDisable]; ok {
|
||||
globalDisable = v.(bool)
|
||||
}
|
||||
|
||||
return globalDisable || localDisable
|
||||
}
|
||||
51
backend/api/model/crossdomain/plugin/option.go
Normal file
51
backend/api/model/crossdomain/plugin/option.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package plugin
|
||||
|
||||
type ExecuteToolOption struct {
|
||||
ProjectInfo *ProjectInfo
|
||||
|
||||
AutoGenRespSchema bool
|
||||
|
||||
ToolVersion string
|
||||
Operation *Openapi3Operation
|
||||
InvalidRespProcessStrategy InvalidResponseProcessStrategy
|
||||
}
|
||||
|
||||
type ExecuteToolOpt func(o *ExecuteToolOption)
|
||||
|
||||
type ProjectInfo struct {
|
||||
ProjectID int64 // agentID or appID
|
||||
ProjectVersion *string // if version si nil, use latest version
|
||||
ProjectType ProjectType // agent or app
|
||||
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
func WithProjectInfo(info *ProjectInfo) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.ProjectInfo = info
|
||||
}
|
||||
}
|
||||
|
||||
func WithToolVersion(version string) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.ToolVersion = version
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenapiOperation(op *Openapi3Operation) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.Operation = op
|
||||
}
|
||||
}
|
||||
|
||||
func WithInvalidRespProcessStrategy(strategy InvalidResponseProcessStrategy) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.InvalidRespProcessStrategy = strategy
|
||||
}
|
||||
}
|
||||
|
||||
func WithAutoGenRespSchema() ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.AutoGenRespSchema = true
|
||||
}
|
||||
}
|
||||
153
backend/api/model/crossdomain/plugin/plugin.go
Normal file
153
backend/api/model/crossdomain/plugin/plugin.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
)
|
||||
|
||||
type VersionPlugin struct {
|
||||
PluginID int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type VersionTool struct {
|
||||
ToolID int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type MGetPluginLatestVersionResponse struct {
|
||||
Versions map[int64]string // pluginID vs version
|
||||
}
|
||||
|
||||
type PluginInfo struct {
|
||||
ID int64
|
||||
PluginType api.PluginType
|
||||
SpaceID int64
|
||||
DeveloperID int64
|
||||
APPID *int64
|
||||
RefProductID *int64 // for product plugin
|
||||
IconURI *string
|
||||
ServerURL *string
|
||||
Version *string
|
||||
VersionDesc *string
|
||||
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
|
||||
Manifest *PluginManifest
|
||||
OpenapiDoc *Openapi3T
|
||||
}
|
||||
|
||||
func (p PluginInfo) SetName(name string) {
|
||||
if p.Manifest == nil || p.OpenapiDoc == nil {
|
||||
return
|
||||
}
|
||||
p.Manifest.NameForModel = name
|
||||
p.Manifest.NameForHuman = name
|
||||
p.OpenapiDoc.Info.Title = name
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetName() string {
|
||||
if p.Manifest == nil {
|
||||
return ""
|
||||
}
|
||||
return p.Manifest.NameForHuman
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetDesc() string {
|
||||
if p.Manifest == nil {
|
||||
return ""
|
||||
}
|
||||
return p.Manifest.DescriptionForHuman
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetAuthInfo() *AuthV2 {
|
||||
if p.Manifest == nil {
|
||||
return nil
|
||||
}
|
||||
return p.Manifest.Auth
|
||||
}
|
||||
|
||||
func (p PluginInfo) IsOfficial() bool {
|
||||
return p.RefProductID != nil
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetIconURI() string {
|
||||
if p.IconURI == nil {
|
||||
return ""
|
||||
}
|
||||
return *p.IconURI
|
||||
}
|
||||
|
||||
func (p PluginInfo) Published() bool {
|
||||
return p.Version != nil
|
||||
}
|
||||
|
||||
type VersionAgentTool struct {
|
||||
ToolName *string
|
||||
ToolID int64
|
||||
|
||||
AgentVersion *string
|
||||
}
|
||||
|
||||
type MGetAgentToolsRequest struct {
|
||||
AgentID int64
|
||||
SpaceID int64
|
||||
IsDraft bool
|
||||
|
||||
VersionAgentTools []VersionAgentTool
|
||||
}
|
||||
|
||||
type ExecuteToolRequest struct {
|
||||
UserID string
|
||||
PluginID int64
|
||||
ToolID int64
|
||||
ExecDraftTool bool // if true, execute draft tool
|
||||
ExecScene ExecuteScene
|
||||
|
||||
ArgumentsInJson string
|
||||
}
|
||||
|
||||
type ExecuteToolResponse struct {
|
||||
Tool *ToolInfo
|
||||
Request string
|
||||
TrimmedResp string
|
||||
RawResp string
|
||||
|
||||
RespSchema openapi3.Responses
|
||||
}
|
||||
|
||||
type PublishPluginRequest struct {
|
||||
PluginID int64
|
||||
Version string
|
||||
VersionDesc string
|
||||
}
|
||||
|
||||
type PublishAPPPluginsRequest struct {
|
||||
APPID int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type PublishAPPPluginsResponse struct {
|
||||
FailedPlugins []*PluginInfo
|
||||
AllDraftPlugins []*PluginInfo
|
||||
}
|
||||
|
||||
type CheckCanPublishPluginsRequest struct {
|
||||
PluginIDs []int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type CheckCanPublishPluginsResponse struct {
|
||||
InvalidPlugins []*PluginInfo
|
||||
}
|
||||
|
||||
type ToolInterruptEvent struct {
|
||||
Event InterruptEventType
|
||||
ToolNeedOAuth *ToolNeedOAuthInterruptEvent
|
||||
}
|
||||
|
||||
type ToolNeedOAuthInterruptEvent struct {
|
||||
Message string
|
||||
}
|
||||
497
backend/api/model/crossdomain/plugin/plugin_manifest.go
Normal file
497
backend/api/model/crossdomain/plugin/plugin_manifest.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
type PluginManifest struct {
|
||||
SchemaVersion string `json:"schema_version" yaml:"schema_version"`
|
||||
NameForModel string `json:"name_for_model" yaml:"name_for_model"`
|
||||
NameForHuman string `json:"name_for_human" yaml:"name_for_human"`
|
||||
DescriptionForModel string `json:"description_for_model" yaml:"description_for_model"`
|
||||
DescriptionForHuman string `json:"description_for_human" yaml:"description_for_human"`
|
||||
Auth *AuthV2 `json:"auth" yaml:"auth"`
|
||||
LogoURL string `json:"logo_url" yaml:"logo_url"`
|
||||
API APIDesc `json:"api" yaml:"api"`
|
||||
CommonParams map[HTTPParamLocation][]*api.CommonParamSchema `json:"common_params" yaml:"common_params"`
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) Copy() (*PluginManifest, error) {
|
||||
if mf == nil {
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
b, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf_ := &PluginManifest{}
|
||||
err = json.Unmarshal(b, mf_)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mf_, err
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
|
||||
if mf == nil || mf.Auth == nil {
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
mf_, err := mf.Copy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mf_.Auth.Payload == "" {
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), utils.AuthSecretKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf_.Auth.Payload = payload_
|
||||
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) Validate(skipAuthPayload bool) (err error) {
|
||||
if mf.SchemaVersion != "v1" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid schema version '%s'", mf.SchemaVersion))
|
||||
}
|
||||
if mf.NameForModel == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"name for model is required"))
|
||||
}
|
||||
if mf.NameForHuman == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"name for human is required"))
|
||||
}
|
||||
if mf.DescriptionForModel == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"description for model is required"))
|
||||
}
|
||||
if mf.DescriptionForHuman == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"description for human is required"))
|
||||
}
|
||||
if mf.API.Type != PluginTypeOfCloud {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid api type '%s'", mf.API.Type))
|
||||
}
|
||||
|
||||
err = mf.validateAuthInfo(skipAuthPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for loc := range mf.CommonParams {
|
||||
if loc != ParamInBody &&
|
||||
loc != ParamInHeader &&
|
||||
loc != ParamInQuery &&
|
||||
loc != ParamInPath {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid location '%s' in common params", loc))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateAuthInfo(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"auth is required"))
|
||||
}
|
||||
|
||||
if mf.Auth.Payload != "" {
|
||||
js := json.RawMessage{}
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &js)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if mf.Auth.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"auth type is required"))
|
||||
}
|
||||
|
||||
if mf.Auth.Type != AuthzTypeOfNone &&
|
||||
mf.Auth.Type != AuthzTypeOfOAuth &&
|
||||
mf.Auth.Type != AuthzTypeOfService {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid auth type '%s'", mf.Auth.Type))
|
||||
}
|
||||
|
||||
if mf.Auth.Type == AuthzTypeOfNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
if mf.Auth.SubType == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"sub-auth type is required"))
|
||||
}
|
||||
|
||||
switch mf.Auth.SubType {
|
||||
case AuthzSubTypeOfServiceAPIToken:
|
||||
err = mf.validateServiceToken(skipAuthPayload)
|
||||
//case AuthzSubTypeOfOAuthClientCredentials:
|
||||
// err = mf.validateClientCredentials()
|
||||
case AuthzSubTypeOfOAuthAuthorizationCode:
|
||||
err = mf.validateAuthCode(skipAuthPayload)
|
||||
default:
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid sub-auth type '%s'", mf.Auth.SubType))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateServiceToken(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth.AuthOfAPIToken == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfAPIToken)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if skipAuthPayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiToken := mf.Auth.AuthOfAPIToken
|
||||
|
||||
if apiToken.ServiceToken == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"service token is required"))
|
||||
}
|
||||
if apiToken.Key == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"key is required"))
|
||||
}
|
||||
|
||||
loc := HTTPParamLocation(strings.ToLower(string(apiToken.Location)))
|
||||
if loc != ParamInHeader && loc != ParamInQuery {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid location '%s'", apiToken.Location))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateClientCredentials() (err error) {
|
||||
if mf.Auth.AuthOfOAuthClientCredentials == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthClientCredentials)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
clientCredentials := mf.Auth.AuthOfOAuthClientCredentials
|
||||
|
||||
if clientCredentials.ClientID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client id is required"))
|
||||
}
|
||||
if clientCredentials.ClientSecret == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client secret is required"))
|
||||
}
|
||||
if clientCredentials.TokenURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url is required"))
|
||||
}
|
||||
|
||||
urlParse, err := url.Parse(clientCredentials.TokenURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid token url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateAuthCode(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth.AuthOfOAuthAuthorizationCode == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthAuthorizationCode)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if skipAuthPayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
authCode := mf.Auth.AuthOfOAuthAuthorizationCode
|
||||
|
||||
if authCode.ClientID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client id is required"))
|
||||
}
|
||||
if authCode.ClientSecret == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client secret is required"))
|
||||
}
|
||||
if authCode.AuthorizationContentType != MediaTypeJson {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"authorization content type must be 'application/json'"))
|
||||
}
|
||||
if authCode.AuthorizationURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url is required"))
|
||||
}
|
||||
if authCode.ClientURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client url is required"))
|
||||
}
|
||||
|
||||
urlParse, err := url.Parse(authCode.AuthorizationURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid authorization url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"authorization url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
urlParse, err = url.Parse(authCode.ClientURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid client url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
Type string `json:"type" validate:"required"`
|
||||
AuthorizationType string `json:"authorization_type,omitempty"`
|
||||
ClientURL string `json:"client_url,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
AuthorizationURL string `json:"authorization_url,omitempty"`
|
||||
AuthorizationContentType string `json:"authorization_content_type,omitempty"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
Location string `json:"location,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
ServiceToken string `json:"service_token,omitempty"`
|
||||
SubType string `json:"sub_type"`
|
||||
Payload string `json:"payload"`
|
||||
}
|
||||
|
||||
type AuthV2 struct {
|
||||
Type AuthzType `json:"type" yaml:"type"`
|
||||
SubType AuthzSubType `json:"sub_type" yaml:"sub_type"`
|
||||
Payload string `json:"payload" yaml:"payload"`
|
||||
// service
|
||||
AuthOfAPIToken *AuthOfAPIToken `json:"-"`
|
||||
|
||||
// oauth
|
||||
AuthOfOAuthAuthorizationCode *OAuthAuthorizationCodeConfig `json:"-"`
|
||||
AuthOfOAuthClientCredentials *OAuthClientCredentialsConfig `json:"-"`
|
||||
}
|
||||
|
||||
func (au *AuthV2) UnmarshalJSON(data []byte) error {
|
||||
auth := &Auth{} // 兼容老数据
|
||||
err := json.Unmarshal(data, auth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid plugin manifest json"))
|
||||
}
|
||||
|
||||
au.Type = AuthzType(auth.Type)
|
||||
au.SubType = AuthzSubType(auth.SubType)
|
||||
|
||||
if au.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"plugin auth type is required"))
|
||||
}
|
||||
|
||||
if auth.Payload != "" {
|
||||
payload_, err := utils.DecryptByAES(auth.Payload, utils.AuthSecretKey)
|
||||
if err == nil {
|
||||
auth.Payload = string(payload_)
|
||||
}
|
||||
}
|
||||
|
||||
switch au.Type {
|
||||
case AuthzTypeOfNone:
|
||||
case AuthzTypeOfOAuth:
|
||||
err = au.unmarshalOAuth(auth)
|
||||
case AuthzTypeOfService:
|
||||
err = au.unmarshalService(auth)
|
||||
default:
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin auth type '%s'", au.Type))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (au *AuthV2) unmarshalService(auth *Auth) (err error) {
|
||||
if au.SubType == "" && au.Payload == "" { // 兼容老数据
|
||||
au.SubType = AuthzSubTypeOfServiceAPIToken
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
|
||||
if au.SubType == AuthzSubTypeOfServiceAPIToken {
|
||||
if len(auth.ServiceToken) > 0 {
|
||||
au.AuthOfAPIToken = &AuthOfAPIToken{
|
||||
Location: HTTPParamLocation(strings.ToLower(auth.Location)),
|
||||
Key: auth.Key,
|
||||
ServiceToken: auth.ServiceToken,
|
||||
}
|
||||
} else {
|
||||
token := &AuthOfAPIToken{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), token)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfAPIToken = token
|
||||
}
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfAPIToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin sub-auth type '%s'", au.SubType))
|
||||
}
|
||||
|
||||
au.Payload = string(payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) {
|
||||
if au.SubType == "" { // 兼容老数据
|
||||
au.SubType = AuthzSubTypeOfOAuthAuthorizationCode
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
|
||||
if au.SubType == AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
if len(auth.ClientSecret) > 0 {
|
||||
au.AuthOfOAuthAuthorizationCode = &OAuthAuthorizationCodeConfig{
|
||||
ClientID: auth.ClientID,
|
||||
ClientSecret: auth.ClientSecret,
|
||||
ClientURL: auth.ClientURL,
|
||||
Scope: auth.Scope,
|
||||
AuthorizationURL: auth.AuthorizationURL,
|
||||
AuthorizationContentType: auth.AuthorizationContentType,
|
||||
}
|
||||
} else {
|
||||
oauth := &OAuthAuthorizationCodeConfig{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), oauth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfOAuthAuthorizationCode = oauth
|
||||
}
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfOAuthAuthorizationCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if au.SubType == AuthzSubTypeOfOAuthClientCredentials {
|
||||
oauth := &OAuthClientCredentialsConfig{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), oauth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfOAuthClientCredentials = oauth
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfOAuthClientCredentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin sub-auth type '%s'", au.SubType))
|
||||
}
|
||||
|
||||
au.Payload = string(payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthOfAPIToken struct {
|
||||
// Location is the location of the parameter.
|
||||
// It can be "header" or "query".
|
||||
Location HTTPParamLocation `json:"location"`
|
||||
// Key is the name of the parameter.
|
||||
Key string `json:"key"`
|
||||
// ServiceToken is the simple authorization information for the service.
|
||||
ServiceToken string `json:"service_token"`
|
||||
}
|
||||
|
||||
type OAuthAuthorizationCodeConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
// ClientURL is the URL of authorization endpoint.
|
||||
ClientURL string `json:"client_url"`
|
||||
// Scope is the scope of the authorization request.
|
||||
// If multiple scopes are requested, they must be separated by a space.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// AuthorizationURL is the URL of token exchange endpoint.
|
||||
AuthorizationURL string `json:"authorization_url"`
|
||||
// AuthorizationContentType is the content type of the authorization request, and it must be "application/json".
|
||||
AuthorizationContentType string `json:"authorization_content_type"`
|
||||
}
|
||||
|
||||
type OAuthClientCredentialsConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
TokenURL string `json:"token_url"`
|
||||
}
|
||||
|
||||
type APIDesc struct {
|
||||
Type PluginType `json:"type" validate:"required"`
|
||||
}
|
||||
566
backend/api/model/crossdomain/plugin/toolinfo.go
Normal file
566
backend/api/model/crossdomain/plugin/toolinfo.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
gonanoid "github.com/matoous/go-nanoid"
|
||||
|
||||
productAPI "github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
type ToolInfo struct {
|
||||
ID int64
|
||||
PluginID int64
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
Version *string
|
||||
|
||||
ActivatedStatus *ActivatedStatus
|
||||
DebugStatus *plugin_develop_common.APIDebugStatus
|
||||
|
||||
Method *string
|
||||
SubURL *string
|
||||
Operation *Openapi3Operation
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetName() string {
|
||||
if t.Operation == nil {
|
||||
return ""
|
||||
}
|
||||
return t.Operation.OperationID
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetDesc() string {
|
||||
if t.Operation == nil {
|
||||
return ""
|
||||
}
|
||||
return t.Operation.Summary
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetVersion() string {
|
||||
return ptr.FromOrDefault(t.Version, "")
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetActivatedStatus() ActivatedStatus {
|
||||
return ptr.FromOrDefault(t.ActivatedStatus, ActivateTool)
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetSubURL() string {
|
||||
return ptr.FromOrDefault(t.SubURL, "")
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetMethod() string {
|
||||
return strings.ToUpper(ptr.FromOrDefault(t.Method, ""))
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetDebugStatus() common.APIDebugStatus {
|
||||
return ptr.FromOrDefault(t.DebugStatus, common.APIDebugStatus_DebugWaiting)
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetResponseOpenapiSchema() (*openapi3.Schema, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
resp, ok := op.Responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok || resp == nil || resp.Value == nil || len(resp.Value.Content) == 0 {
|
||||
return nil, fmt.Errorf("response status '200' not found")
|
||||
}
|
||||
|
||||
mType, ok := resp.Value.Content[MediaTypeJson] // only support application/json
|
||||
if !ok || mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("media type '%s' not found in response", MediaTypeJson)
|
||||
}
|
||||
|
||||
return mType.Schema.Value, nil
|
||||
}
|
||||
|
||||
type paramMetaInfo struct {
|
||||
name string
|
||||
desc string
|
||||
required bool
|
||||
location string
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToRespAPIParameter() ([]*common.APIParameter, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
respSchema, err := t.GetResponseOpenapiSchema()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := make([]*common.APIParameter, 0, len(op.Parameters))
|
||||
if len(respSchema.Properties) == 0 {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
required := slices.ToMap(respSchema.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for subParamName, prop := range respSchema.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
location: string(ParamInBody),
|
||||
required: required[subParamName],
|
||||
}
|
||||
apiParam, err := toAPIParameter(paramMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params = append(params, apiParam)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToReqAPIParameter() ([]*common.APIParameter, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
params := make([]*common.APIParameter, 0, len(op.Parameters))
|
||||
for _, param := range op.Parameters {
|
||||
if param == nil || param.Value == nil || param.Value.Schema == nil || param.Value.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("parameter schema is required")
|
||||
}
|
||||
|
||||
paramVal := param.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
|
||||
if schemaVal.Type == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of parameter '%s' cannot be 'object'", paramVal.Name)
|
||||
}
|
||||
|
||||
if schemaVal.Type == openapi3.TypeArray {
|
||||
if paramVal.In == openapi3.ParameterInPath {
|
||||
return nil, fmt.Errorf("the type of field '%s' cannot be 'array'", paramVal.Name)
|
||||
}
|
||||
if schemaVal.Items == nil || schemaVal.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramVal.Name)
|
||||
}
|
||||
item := schemaVal.Items.Value
|
||||
if item.Type == openapi3.TypeObject || item.Type == openapi3.TypeArray {
|
||||
return nil, fmt.Errorf("the item type of parameter '%s' cannot be 'object' or 'array'", paramVal.Name)
|
||||
}
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: paramVal.Name,
|
||||
desc: paramVal.Description,
|
||||
location: paramVal.In,
|
||||
required: paramVal.Required,
|
||||
}
|
||||
apiParam, err := toAPIParameter(paramMeta, schemaVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params = append(params, apiParam)
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
if mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("request body schema is required")
|
||||
}
|
||||
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for subParamName, prop := range schemaVal.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
location: string(ParamInBody),
|
||||
required: required[subParamName],
|
||||
}
|
||||
apiParam, err := toAPIParameter(paramMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params = append(params, apiParam)
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIParameter, error) {
|
||||
if sc == nil {
|
||||
return nil, fmt.Errorf("schema is requred")
|
||||
}
|
||||
|
||||
apiType, ok := ToThriftParamType(strings.ToLower(sc.Type))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the type '%s' of filed '%s' is invalid", sc.Type, paramMeta.name)
|
||||
}
|
||||
location, ok := ToThriftHTTPParamLocation(HTTPParamLocation(paramMeta.location))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the location '%s' of field '%s' is invalid", paramMeta.location, paramMeta.name)
|
||||
}
|
||||
|
||||
apiParam := &common.APIParameter{
|
||||
ID: gonanoid.MustID(10),
|
||||
Name: paramMeta.name,
|
||||
Desc: paramMeta.desc,
|
||||
Type: apiType,
|
||||
Location: location, // 使用父节点的值
|
||||
IsRequired: paramMeta.required,
|
||||
SubParameters: []*common.APIParameter{},
|
||||
}
|
||||
|
||||
if sc.Default != nil {
|
||||
apiParam.LocalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
|
||||
}
|
||||
|
||||
if sc.Format != "" {
|
||||
aType, ok := FormatToAssistType(sc.Format)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the format '%s' of field '%s' is invalid", sc.Format, paramMeta.name)
|
||||
}
|
||||
_aType, ok := ToThriftAPIAssistType(aType)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("assist type '%s' of field '%s' is invalid", aType, paramMeta.name)
|
||||
}
|
||||
apiParam.AssistType = ptr.Of(_aType)
|
||||
}
|
||||
|
||||
if v, ok := sc.Extensions[APISchemaExtendGlobalDisable]; ok {
|
||||
if disable, ok := v.(bool); ok {
|
||||
apiParam.GlobalDisable = disable
|
||||
}
|
||||
}
|
||||
if v, ok := sc.Extensions[APISchemaExtendLocalDisable]; ok {
|
||||
if disable, ok := v.(bool); ok {
|
||||
apiParam.LocalDisable = disable
|
||||
}
|
||||
}
|
||||
if v, ok := sc.Extensions[APISchemaExtendVariableRef]; ok {
|
||||
if ref, ok := v.(string); ok {
|
||||
apiParam.VariableRef = ptr.Of(ref)
|
||||
apiParam.DefaultParamSource = ptr.Of(common.DefaultParamSource_Variable)
|
||||
}
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
if len(sc.Properties) == 0 {
|
||||
return apiParam, nil
|
||||
}
|
||||
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
for subParamName, prop := range sc.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
required: required[subParamName],
|
||||
location: paramMeta.location,
|
||||
}
|
||||
subParam, err := toAPIParameter(subMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiParam.SubParameters = append(apiParam.SubParameters, subParam)
|
||||
}
|
||||
|
||||
return apiParam, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
if sc.Items == nil || sc.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramMeta.name)
|
||||
}
|
||||
|
||||
item := sc.Items.Value
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
name: "[Array Item]",
|
||||
desc: item.Description,
|
||||
location: paramMeta.location,
|
||||
required: paramMeta.required,
|
||||
}
|
||||
subParam, err := toAPIParameter(subMeta, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiParam.SubParameters = append(apiParam.SubParameters, subParam)
|
||||
|
||||
return apiParam, nil
|
||||
}
|
||||
|
||||
return apiParam, nil
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToPluginParameters() ([]*common.PluginParameter, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
var params []*common.PluginParameter
|
||||
|
||||
for _, prop := range op.Parameters {
|
||||
if prop == nil || prop.Value == nil || prop.Value.Schema == nil || prop.Value.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("parameter schema is required")
|
||||
}
|
||||
|
||||
paramVal := prop.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
|
||||
if schemaVal.Type == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of parameter '%s' cannot be 'object'", paramVal.Name)
|
||||
}
|
||||
|
||||
var arrayItemType string
|
||||
if schemaVal.Type == openapi3.TypeArray {
|
||||
if paramVal.In == openapi3.ParameterInPath {
|
||||
return nil, fmt.Errorf("the type of field '%s' cannot be 'array'", paramVal.Name)
|
||||
}
|
||||
if schemaVal.Items == nil || schemaVal.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramVal.Name)
|
||||
}
|
||||
item := schemaVal.Items.Value
|
||||
if item.Type == openapi3.TypeObject || item.Type == openapi3.TypeArray {
|
||||
return nil, fmt.Errorf("the item type of parameter '%s' cannot be 'object' or 'array'", paramVal.Name)
|
||||
}
|
||||
|
||||
arrayItemType = item.Type
|
||||
}
|
||||
|
||||
if disabledParam(schemaVal) {
|
||||
continue
|
||||
}
|
||||
|
||||
var assistType *common.PluginParamTypeFormat
|
||||
if v, ok := schemaVal.Extensions[APISchemaExtendAssistType]; ok {
|
||||
_v, ok := v.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v))
|
||||
if ok {
|
||||
return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramVal.Name)
|
||||
}
|
||||
assistType = ptr.Of(f)
|
||||
}
|
||||
|
||||
params = append(params, &common.PluginParameter{
|
||||
Name: paramVal.Name,
|
||||
Desc: paramVal.Description,
|
||||
Required: paramVal.Required,
|
||||
Type: schemaVal.Type,
|
||||
SubType: arrayItemType,
|
||||
Format: assistType,
|
||||
SubParameters: []*common.PluginParameter{},
|
||||
})
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
if mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("request body schema is required")
|
||||
}
|
||||
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for subParamName, prop := range schemaVal.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
required: required[subParamName],
|
||||
}
|
||||
paramInfo, err := toPluginParameter(paramMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paramInfo != nil {
|
||||
params = append(params, paramInfo)
|
||||
}
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func toPluginParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.PluginParameter, error) {
|
||||
if sc == nil {
|
||||
return nil, fmt.Errorf("schema is required")
|
||||
}
|
||||
|
||||
if disabledParam(sc) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var assistType *common.PluginParamTypeFormat
|
||||
if v, ok := sc.Extensions[APISchemaExtendAssistType]; ok {
|
||||
if _v, ok := v.(string); ok {
|
||||
f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramMeta.name)
|
||||
}
|
||||
assistType = ptr.Of(f)
|
||||
}
|
||||
}
|
||||
|
||||
pluginParam := &common.PluginParameter{
|
||||
Name: paramMeta.name,
|
||||
Type: sc.Type,
|
||||
Desc: paramMeta.desc,
|
||||
Required: paramMeta.required,
|
||||
Format: assistType,
|
||||
SubParameters: []*common.PluginParameter{},
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
if len(sc.Properties) == 0 {
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
for subParamName, prop := range sc.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
required: required[subParamName],
|
||||
}
|
||||
subParam, err := toPluginParameter(subMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pluginParam.SubParameters = append(pluginParam.SubParameters, subParam)
|
||||
}
|
||||
|
||||
return pluginParam, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
if sc.Items == nil || sc.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramMeta.name)
|
||||
}
|
||||
|
||||
item := sc.Items.Value
|
||||
pluginParam.SubType = item.Type
|
||||
|
||||
if item.Type != openapi3.TypeObject {
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
desc: item.Description,
|
||||
required: paramMeta.required,
|
||||
}
|
||||
subParam, err := toPluginParameter(subMeta, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pluginParam.SubParameters = append(pluginParam.SubParameters, subParam.SubParameters...)
|
||||
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToToolParameters() ([]*productAPI.ToolParameter, error) {
|
||||
apiParams, err := t.ToReqAPIParameter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toToolParams func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error)
|
||||
toToolParams = func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error) {
|
||||
params := make([]*productAPI.ToolParameter, 0, len(apiParams))
|
||||
for _, apiParam := range apiParams {
|
||||
typ, _ := ToOpenapiParamType(apiParam.Type)
|
||||
toolParam := &productAPI.ToolParameter{
|
||||
Name: apiParam.Name,
|
||||
Description: apiParam.Desc,
|
||||
Type: typ,
|
||||
IsRequired: apiParam.IsRequired,
|
||||
SubParameter: []*productAPI.ToolParameter{},
|
||||
}
|
||||
|
||||
if len(apiParam.SubParameters) > 0 {
|
||||
subParams, err := toToolParams(apiParam.SubParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolParam.SubParameter = append(toolParam.SubParameter, subParams...)
|
||||
}
|
||||
|
||||
params = append(params, toolParam)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
return toToolParams(apiParams)
|
||||
}
|
||||
Reference in New Issue
Block a user