feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,110 @@
package plugin
import "github.com/getkin/kin-openapi/openapi3"
type PluginType string
const (
PluginTypeOfCloud PluginType = "openapi"
)
type AuthzType string
const (
AuthzTypeOfNone AuthzType = "none"
AuthzTypeOfService AuthzType = "service_http"
AuthzTypeOfOAuth AuthzType = "oauth"
)
type AuthzSubType string
const (
AuthzSubTypeOfServiceAPIToken AuthzSubType = "token/api_key"
AuthzSubTypeOfOAuthAuthorizationCode AuthzSubType = "authorization_code"
AuthzSubTypeOfOAuthClientCredentials AuthzSubType = "client_credentials"
)
type HTTPParamLocation string
const (
ParamInHeader HTTPParamLocation = openapi3.ParameterInHeader
ParamInPath HTTPParamLocation = openapi3.ParameterInPath
ParamInQuery HTTPParamLocation = openapi3.ParameterInQuery
ParamInBody HTTPParamLocation = "body"
)
type ActivatedStatus int32
const (
ActivateTool ActivatedStatus = 0
DeactivateTool ActivatedStatus = 1
)
type ProjectType int8
const (
ProjectTypeOfAgent ProjectType = 1
ProjectTypeOfAPP ProjectType = 2
)
type ExecuteScene string
const (
ExecSceneOfOnlineAgent ExecuteScene = "online_agent"
ExecSceneOfDraftAgent ExecuteScene = "draft_agent"
ExecSceneOfWorkflow ExecuteScene = "workflow"
ExecSceneOfToolDebug ExecuteScene = "tool_debug"
)
type InvalidResponseProcessStrategy int8
const (
InvalidResponseProcessStrategyOfReturnRaw InvalidResponseProcessStrategy = 0 // If the value of a field is invalid, the raw response value of the field is returned.
InvalidResponseProcessStrategyOfReturnDefault InvalidResponseProcessStrategy = 1 // If the value of a field is invalid, the default value of the field is returned.
)
const (
APISchemaExtendAssistType = "x-assist-type"
APISchemaExtendGlobalDisable = "x-global-disable"
APISchemaExtendLocalDisable = "x-local-disable"
APISchemaExtendVariableRef = "x-variable-ref"
APISchemaExtendAuthMode = "x-auth-mode"
)
type ToolAuthMode string
const (
ToolAuthModeOfRequired ToolAuthMode = "required"
ToolAuthModeOfSupported ToolAuthMode = "supported"
ToolAuthModeOfDisabled ToolAuthMode = "disabled"
)
type APIFileAssistType string
const (
AssistTypeFile APIFileAssistType = "file"
AssistTypeImage APIFileAssistType = "image"
AssistTypeDoc APIFileAssistType = "doc"
AssistTypePPT APIFileAssistType = "ppt"
AssistTypeCode APIFileAssistType = "code"
AssistTypeExcel APIFileAssistType = "excel"
AssistTypeZIP APIFileAssistType = "zip"
AssistTypeVideo APIFileAssistType = "video"
AssistTypeAudio APIFileAssistType = "audio"
AssistTypeTXT APIFileAssistType = "txt"
)
type CopyScene string
const (
CopySceneOfToAPP CopyScene = "to_app"
CopySceneOfToLibrary CopyScene = "to_library"
CopySceneOfDuplicate CopyScene = "duplicate"
CopySceneOfAPPDuplicate CopyScene = "app_duplicate"
)
type InterruptEventType string
const (
InterruptEventTypeOfToolNeedOAuth InterruptEventType = "tool_need_oauth"
)

View File

@@ -0,0 +1,270 @@
package plugin
import (
"net/http"
"github.com/getkin/kin-openapi/openapi3"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
)
var httpParamLocations = map[common.ParameterLocation]HTTPParamLocation{
common.ParameterLocation_Path: ParamInPath,
common.ParameterLocation_Query: ParamInQuery,
common.ParameterLocation_Body: ParamInBody,
common.ParameterLocation_Header: ParamInHeader,
}
func ToHTTPParamLocation(loc common.ParameterLocation) (HTTPParamLocation, bool) {
_loc, ok := httpParamLocations[loc]
return _loc, ok
}
var thriftHTTPParamLocations = func() map[HTTPParamLocation]common.ParameterLocation {
locations := make(map[HTTPParamLocation]common.ParameterLocation, len(httpParamLocations))
for k, v := range httpParamLocations {
locations[v] = k
}
return locations
}()
func ToThriftHTTPParamLocation(loc HTTPParamLocation) (common.ParameterLocation, bool) {
_loc, ok := thriftHTTPParamLocations[loc]
return _loc, ok
}
var openapiTypes = map[common.ParameterType]string{
common.ParameterType_String: openapi3.TypeString,
common.ParameterType_Integer: openapi3.TypeInteger,
common.ParameterType_Number: openapi3.TypeNumber,
common.ParameterType_Object: openapi3.TypeObject,
common.ParameterType_Array: openapi3.TypeArray,
common.ParameterType_Bool: openapi3.TypeBoolean,
}
func ToOpenapiParamType(typ common.ParameterType) (string, bool) {
_typ, ok := openapiTypes[typ]
return _typ, ok
}
var thriftParameterTypes = func() map[string]common.ParameterType {
types := make(map[string]common.ParameterType, len(openapiTypes))
for k, v := range openapiTypes {
types[v] = k
}
return types
}()
func ToThriftParamType(typ string) (common.ParameterType, bool) {
_typ, ok := thriftParameterTypes[typ]
return _typ, ok
}
var apiAssistTypes = map[common.AssistParameterType]APIFileAssistType{
common.AssistParameterType_DEFAULT: AssistTypeFile,
common.AssistParameterType_IMAGE: AssistTypeImage,
common.AssistParameterType_DOC: AssistTypeDoc,
common.AssistParameterType_PPT: AssistTypePPT,
common.AssistParameterType_CODE: AssistTypeCode,
common.AssistParameterType_EXCEL: AssistTypeExcel,
common.AssistParameterType_ZIP: AssistTypeZIP,
common.AssistParameterType_VIDEO: AssistTypeVideo,
common.AssistParameterType_AUDIO: AssistTypeAudio,
common.AssistParameterType_TXT: AssistTypeTXT,
}
func ToAPIAssistType(typ common.AssistParameterType) (APIFileAssistType, bool) {
_typ, ok := apiAssistTypes[typ]
return _typ, ok
}
var thriftAPIAssistTypes = func() map[APIFileAssistType]common.AssistParameterType {
types := make(map[APIFileAssistType]common.AssistParameterType, len(apiAssistTypes))
for k, v := range apiAssistTypes {
types[v] = k
}
return types
}()
func ToThriftAPIAssistType(typ APIFileAssistType) (common.AssistParameterType, bool) {
_typ, ok := thriftAPIAssistTypes[typ]
return _typ, ok
}
func IsValidAPIAssistType(typ APIFileAssistType) bool {
_, ok := thriftAPIAssistTypes[typ]
return ok
}
var httpMethods = map[common.APIMethod]string{
common.APIMethod_GET: http.MethodGet,
common.APIMethod_POST: http.MethodPost,
common.APIMethod_PUT: http.MethodPut,
common.APIMethod_DELETE: http.MethodDelete,
common.APIMethod_PATCH: http.MethodPatch,
}
var thriftAPIMethods = func() map[string]common.APIMethod {
methods := make(map[string]common.APIMethod, len(httpMethods))
for k, v := range httpMethods {
methods[v] = k
}
return methods
}()
func ToThriftAPIMethod(method string) (common.APIMethod, bool) {
_method, ok := thriftAPIMethods[method]
return _method, ok
}
func ToHTTPMethod(method common.APIMethod) (string, bool) {
_method, ok := httpMethods[method]
return _method, ok
}
var assistTypeToFormat = map[APIFileAssistType]string{
AssistTypeFile: "file_url",
AssistTypeImage: "image_url",
AssistTypeDoc: "doc_url",
AssistTypePPT: "ppt_url",
AssistTypeCode: "code_url",
AssistTypeExcel: "excel_url",
AssistTypeZIP: "zip_url",
AssistTypeVideo: "video_url",
AssistTypeAudio: "audio_url",
AssistTypeTXT: "txt_url",
}
func AssistTypeToFormat(typ APIFileAssistType) (string, bool) {
format, ok := assistTypeToFormat[typ]
return format, ok
}
var formatToAssistType = func() map[string]APIFileAssistType {
types := make(map[string]APIFileAssistType, len(assistTypeToFormat))
for k, v := range assistTypeToFormat {
types[v] = k
}
return types
}()
func FormatToAssistType(format string) (APIFileAssistType, bool) {
typ, ok := formatToAssistType[format]
return typ, ok
}
var assistTypeToThriftFormat = map[APIFileAssistType]common.PluginParamTypeFormat{
AssistTypeFile: common.PluginParamTypeFormat_FileUrl,
AssistTypeImage: common.PluginParamTypeFormat_ImageUrl,
AssistTypeDoc: common.PluginParamTypeFormat_DocUrl,
AssistTypePPT: common.PluginParamTypeFormat_PptUrl,
AssistTypeCode: common.PluginParamTypeFormat_CodeUrl,
AssistTypeExcel: common.PluginParamTypeFormat_ExcelUrl,
AssistTypeZIP: common.PluginParamTypeFormat_ZipUrl,
AssistTypeVideo: common.PluginParamTypeFormat_VideoUrl,
AssistTypeAudio: common.PluginParamTypeFormat_AudioUrl,
AssistTypeTXT: common.PluginParamTypeFormat_TxtUrl,
}
func AssistTypeToThriftFormat(typ APIFileAssistType) (common.PluginParamTypeFormat, bool) {
format, ok := assistTypeToThriftFormat[typ]
return format, ok
}
var authTypes = map[common.AuthorizationType]AuthzType{
common.AuthorizationType_None: AuthzTypeOfNone,
common.AuthorizationType_Service: AuthzTypeOfService,
common.AuthorizationType_OAuth: AuthzTypeOfOAuth,
common.AuthorizationType_Standard: AuthzTypeOfOAuth, // deprecated, the same as OAuth
}
func ToAuthType(typ common.AuthorizationType) (AuthzType, bool) {
_type, ok := authTypes[typ]
return _type, ok
}
var thriftAuthTypes = func() map[AuthzType]common.AuthorizationType {
types := make(map[AuthzType]common.AuthorizationType, len(authTypes))
for k, v := range authTypes {
if v == AuthzTypeOfOAuth {
types[v] = common.AuthorizationType_OAuth
} else {
types[v] = k
}
}
return types
}()
func ToThriftAuthType(typ AuthzType) (common.AuthorizationType, bool) {
_type, ok := thriftAuthTypes[typ]
return _type, ok
}
var subAuthTypes = map[int32]AuthzSubType{
int32(common.ServiceAuthSubType_ApiKey): AuthzSubTypeOfServiceAPIToken,
int32(common.ServiceAuthSubType_OAuthAuthorizationCode): AuthzSubTypeOfOAuthAuthorizationCode,
}
func ToAuthSubType(typ int32) (AuthzSubType, bool) {
_type, ok := subAuthTypes[typ]
return _type, ok
}
var thriftSubAuthTypes = func() map[AuthzSubType]int32 {
types := make(map[AuthzSubType]int32, len(subAuthTypes))
for k, v := range subAuthTypes {
types[v] = int32(k)
}
return types
}()
func ToThriftAuthSubType(typ AuthzSubType) (int32, bool) {
_type, ok := thriftSubAuthTypes[typ]
return _type, ok
}
var pluginTypes = map[common.PluginType]PluginType{
common.PluginType_PLUGIN: PluginTypeOfCloud,
}
func ToPluginType(typ common.PluginType) (PluginType, bool) {
_type, ok := pluginTypes[typ]
return _type, ok
}
var thriftPluginTypes = func() map[PluginType]common.PluginType {
types := make(map[PluginType]common.PluginType, len(pluginTypes))
for k, v := range pluginTypes {
types[v] = k
}
return types
}()
func ToThriftPluginType(typ PluginType) (common.PluginType, bool) {
_type, ok := thriftPluginTypes[typ]
return _type, ok
}
var apiAuthModes = map[common.PluginToolAuthType]ToolAuthMode{
common.PluginToolAuthType_Required: ToolAuthModeOfRequired,
common.PluginToolAuthType_Supported: ToolAuthModeOfSupported,
common.PluginToolAuthType_Disable: ToolAuthModeOfDisabled,
}
func ToAPIAuthMode(mode common.PluginToolAuthType) (ToolAuthMode, bool) {
_mode, ok := apiAuthModes[mode]
return _mode, ok
}
var thriftAPIAuthModes = func() map[ToolAuthMode]common.PluginToolAuthType {
modes := make(map[ToolAuthMode]common.PluginToolAuthType, len(apiAuthModes))
for k, v := range apiAuthModes {
modes[v] = k
}
return modes
}()
func ToThriftAPIAuthMode(mode ToolAuthMode) (common.PluginToolAuthType, bool) {
_mode, ok := thriftAPIAuthModes[mode]
return _mode, ok
}

View File

@@ -0,0 +1,429 @@
package plugin
import (
"context"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/getkin/kin-openapi/openapi3"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/cloudwego/eino/schema"
)
type Openapi3T openapi3.T
func (ot Openapi3T) Validate(ctx context.Context) (err error) {
err = ptr.Of(openapi3.T(ot)).Validate(ctx)
if err != nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error()))
}
if ot.Info == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"info is required"))
}
if ot.Info.Title == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"the title of info is required"))
}
if ot.Info.Description == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"the description of info is required"))
}
if len(ot.Servers) != 1 {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"server is required and only one server is allowed"))
}
serverURL := ot.Servers[0].URL
urlSchema, err := url.Parse(serverURL)
if err != nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"invalid server url '%s'", serverURL))
}
if urlSchema.Scheme != "https" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"server url must start with 'https://'"))
}
if urlSchema.Host == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"invalid server url '%s'", serverURL))
}
if len(serverURL) > 512 {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"server url '%s' is too long", serverURL))
}
for _, pathItem := range ot.Paths {
for _, op := range pathItem.Operations() {
err = NewOpenapi3Operation(op).Validate(ctx)
if err != nil {
return err
}
}
}
return nil
}
func NewOpenapi3Operation(op *openapi3.Operation) *Openapi3Operation {
return &Openapi3Operation{
Operation: op,
}
}
type Openapi3Operation struct {
*openapi3.Operation
}
func (op *Openapi3Operation) MarshalJSON() ([]byte, error) {
return op.Operation.MarshalJSON()
}
func (op *Openapi3Operation) UnmarshalJSON(data []byte) error {
op.Operation = &openapi3.Operation{}
return op.Operation.UnmarshalJSON(data)
}
func (op *Openapi3Operation) Validate(ctx context.Context) (err error) {
err = op.Operation.Validate(ctx)
if err != nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey, "operation is invalid, err=%s", err))
}
if op.OperationID == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "operationID is required"))
}
if op.Summary == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "summary is required"))
}
err = validateOpenapi3RequestBody(op.RequestBody)
if err != nil {
return err
}
err = validateOpenapi3Parameters(op.Parameters)
if err != nil {
return err
}
err = validateOpenapi3Responses(op.Responses)
if err != nil {
return err
}
return nil
}
func (op *Openapi3Operation) ToEinoSchemaParameterInfo(ctx context.Context) (map[string]*schema.ParameterInfo, error) {
convertType := func(openapiType string) schema.DataType {
switch openapiType {
case openapi3.TypeString:
return schema.String
case openapi3.TypeInteger:
return schema.Integer
case openapi3.TypeObject:
return schema.Object
case openapi3.TypeArray:
return schema.Array
case openapi3.TypeBoolean:
return schema.Boolean
case openapi3.TypeNumber:
return schema.Number
default:
return schema.Null
}
}
var convertReqBody func(sc *openapi3.Schema, isRequired bool) (*schema.ParameterInfo, error)
convertReqBody = func(sc *openapi3.Schema, isRequired bool) (*schema.ParameterInfo, error) {
if disabledParam(sc) {
return nil, nil
}
paramInfo := &schema.ParameterInfo{
Type: convertType(sc.Type),
Desc: sc.Description,
Required: isRequired,
}
switch sc.Type {
case openapi3.TypeObject:
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
return e, true
})
subParams := make(map[string]*schema.ParameterInfo, len(sc.Properties))
for paramName, prop := range sc.Properties {
subParam, err := convertReqBody(prop.Value, required[paramName])
if err != nil {
return nil, err
}
subParams[paramName] = subParam
}
paramInfo.SubParams = subParams
case openapi3.TypeArray:
ele, err := convertReqBody(sc.Items.Value, isRequired)
if err != nil {
return nil, err
}
paramInfo.ElemInfo = ele
case openapi3.TypeString, openapi3.TypeInteger, openapi3.TypeBoolean, openapi3.TypeNumber:
return paramInfo, nil
default:
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KVf(errno.PluginMsgKey,
"unsupported json type '%s'", sc.Type))
}
return paramInfo, nil
}
result := make(map[string]*schema.ParameterInfo)
for _, prop := range op.Parameters {
paramVal := prop.Value
schemaVal := paramVal.Schema.Value
if schemaVal.Type == openapi3.TypeObject || schemaVal.Type == openapi3.TypeArray {
continue
}
if disabledParam(prop.Value.Schema.Value) {
continue
}
paramInfo := &schema.ParameterInfo{
Type: convertType(schemaVal.Type),
Desc: paramVal.Description,
Required: paramVal.Required,
}
if _, ok := result[paramVal.Name]; ok {
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramVal.Name)
continue
}
result[paramVal.Name] = paramInfo
}
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
return result, nil
}
for _, mType := range op.RequestBody.Value.Content {
schemaVal := mType.Schema.Value
if len(schemaVal.Properties) == 0 {
continue
}
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
return e, true
})
for paramName, prop := range schemaVal.Properties {
paramInfo, err := convertReqBody(prop.Value, required[paramName])
if err != nil {
return nil, err
}
if _, ok := result[paramName]; ok {
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramName)
continue
}
result[paramName] = paramInfo
}
break // 只取一种 MIME
}
return result, nil
}
func validateOpenapi3RequestBody(bodyRef *openapi3.RequestBodyRef) (err error) {
if bodyRef == nil {
return nil
}
if bodyRef.Value == nil || len(bodyRef.Value.Content) == 0 {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"request body is required"))
}
body := bodyRef.Value
if len(body.Content) != 1 {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"request body only supports one media type"))
}
var mType *openapi3.MediaType
for _, ct := range mediaTypeArray {
var ok bool
mType, ok = body.Content[ct]
if ok {
break
}
}
if mType == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"invalid media type, request body only the following types: [%s]", strings.Join(mediaTypeArray, ", ")))
}
if mType.Schema == nil || mType.Schema.Value == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"request body schema is required"))
}
sc := mType.Schema.Value
if sc.Type == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"request body only supports 'object' type"))
}
if sc.Type != openapi3.TypeObject {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"request body only supports 'object' type"))
}
return nil
}
func validateOpenapi3Parameters(params openapi3.Parameters) (err error) {
if len(params) == 0 {
return nil
}
for _, param := range params {
if param == nil || param.Value == nil || param.Value.Schema == nil || param.Value.Schema.Value == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"parameter schema is required"))
}
paramVal := param.Value
if paramVal.In == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"parameter location is required"))
}
if paramVal.In == string(ParamInBody) {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"the location of parameter '%s' cannot be 'body'", paramVal.Name))
}
paramSchema := paramVal.Schema.Value
if paramSchema.Type == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"the type of parameter '%s' is required", paramVal.Name))
}
if paramSchema.Type == openapi3.TypeObject {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"the type of parameter '%s' cannot be 'object'", paramVal.Name))
}
if paramVal.In == openapi3.ParameterInPath && paramSchema.Type == openapi3.TypeArray {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
"the type of parameter '%s' cannot be 'array'", paramVal.Name))
}
}
return nil
}
// MIME Type
const (
MediaTypeJson = "application/json"
MediaTypeProblemJson = "application/problem+json"
MediaTypeFormURLEncoded = "application/x-www-form-urlencoded"
MediaTypeXYaml = "application/x-yaml"
MediaTypeYaml = "application/yaml"
)
var mediaTypeArray = []string{
MediaTypeJson,
MediaTypeProblemJson,
MediaTypeFormURLEncoded,
MediaTypeXYaml,
MediaTypeYaml,
}
func validateOpenapi3Responses(responses openapi3.Responses) (err error) {
if len(responses) == 0 {
return nil
}
// default status 不处理
// 只处理 '200' status
if len(responses) != 1 {
if len(responses) != 2 {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response only supports '200' status"))
} else if _, ok := responses["default"]; !ok {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response only supports '200' status"))
}
}
resp, ok := responses[strconv.Itoa(http.StatusOK)]
if !ok || resp == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response only supports '200' status"))
}
if resp.Value == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response schema is required"))
}
if len(resp.Value.Content) != 1 {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response only supports 'application/json' media type"))
}
mType, ok := resp.Value.Content[MediaTypeJson]
if !ok || mType == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response only supports 'application/json' media type"))
}
if mType.Schema == nil || mType.Schema.Value == nil {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"the media type schema of response is required"))
}
sc := mType.Schema.Value
if sc.Type == "" {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response body only supports 'object' type"))
}
if sc.Type != openapi3.TypeObject {
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
"response body only supports 'object' type"))
}
return nil
}
func disabledParam(schemaVal *openapi3.Schema) bool {
if len(schemaVal.Extensions) == 0 {
return false
}
globalDisable, localDisable := false, false
if v, ok := schemaVal.Extensions[APISchemaExtendLocalDisable]; ok {
localDisable = v.(bool)
}
if v, ok := schemaVal.Extensions[APISchemaExtendGlobalDisable]; ok {
globalDisable = v.(bool)
}
return globalDisable || localDisable
}

View File

@@ -0,0 +1,51 @@
package plugin
type ExecuteToolOption struct {
ProjectInfo *ProjectInfo
AutoGenRespSchema bool
ToolVersion string
Operation *Openapi3Operation
InvalidRespProcessStrategy InvalidResponseProcessStrategy
}
type ExecuteToolOpt func(o *ExecuteToolOption)
type ProjectInfo struct {
ProjectID int64 // agentID or appID
ProjectVersion *string // if version si nil, use latest version
ProjectType ProjectType // agent or app
ConnectorID int64
}
func WithProjectInfo(info *ProjectInfo) ExecuteToolOpt {
return func(o *ExecuteToolOption) {
o.ProjectInfo = info
}
}
func WithToolVersion(version string) ExecuteToolOpt {
return func(o *ExecuteToolOption) {
o.ToolVersion = version
}
}
func WithOpenapiOperation(op *Openapi3Operation) ExecuteToolOpt {
return func(o *ExecuteToolOption) {
o.Operation = op
}
}
func WithInvalidRespProcessStrategy(strategy InvalidResponseProcessStrategy) ExecuteToolOpt {
return func(o *ExecuteToolOption) {
o.InvalidRespProcessStrategy = strategy
}
}
func WithAutoGenRespSchema() ExecuteToolOpt {
return func(o *ExecuteToolOption) {
o.AutoGenRespSchema = true
}
}

View File

@@ -0,0 +1,153 @@
package plugin
import (
"github.com/getkin/kin-openapi/openapi3"
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
)
type VersionPlugin struct {
PluginID int64
Version string
}
type VersionTool struct {
ToolID int64
Version string
}
type MGetPluginLatestVersionResponse struct {
Versions map[int64]string // pluginID vs version
}
type PluginInfo struct {
ID int64
PluginType api.PluginType
SpaceID int64
DeveloperID int64
APPID *int64
RefProductID *int64 // for product plugin
IconURI *string
ServerURL *string
Version *string
VersionDesc *string
CreatedAt int64
UpdatedAt int64
Manifest *PluginManifest
OpenapiDoc *Openapi3T
}
func (p PluginInfo) SetName(name string) {
if p.Manifest == nil || p.OpenapiDoc == nil {
return
}
p.Manifest.NameForModel = name
p.Manifest.NameForHuman = name
p.OpenapiDoc.Info.Title = name
}
func (p PluginInfo) GetName() string {
if p.Manifest == nil {
return ""
}
return p.Manifest.NameForHuman
}
func (p PluginInfo) GetDesc() string {
if p.Manifest == nil {
return ""
}
return p.Manifest.DescriptionForHuman
}
func (p PluginInfo) GetAuthInfo() *AuthV2 {
if p.Manifest == nil {
return nil
}
return p.Manifest.Auth
}
func (p PluginInfo) IsOfficial() bool {
return p.RefProductID != nil
}
func (p PluginInfo) GetIconURI() string {
if p.IconURI == nil {
return ""
}
return *p.IconURI
}
func (p PluginInfo) Published() bool {
return p.Version != nil
}
type VersionAgentTool struct {
ToolName *string
ToolID int64
AgentVersion *string
}
type MGetAgentToolsRequest struct {
AgentID int64
SpaceID int64
IsDraft bool
VersionAgentTools []VersionAgentTool
}
type ExecuteToolRequest struct {
UserID string
PluginID int64
ToolID int64
ExecDraftTool bool // if true, execute draft tool
ExecScene ExecuteScene
ArgumentsInJson string
}
type ExecuteToolResponse struct {
Tool *ToolInfo
Request string
TrimmedResp string
RawResp string
RespSchema openapi3.Responses
}
type PublishPluginRequest struct {
PluginID int64
Version string
VersionDesc string
}
type PublishAPPPluginsRequest struct {
APPID int64
Version string
}
type PublishAPPPluginsResponse struct {
FailedPlugins []*PluginInfo
AllDraftPlugins []*PluginInfo
}
type CheckCanPublishPluginsRequest struct {
PluginIDs []int64
Version string
}
type CheckCanPublishPluginsResponse struct {
InvalidPlugins []*PluginInfo
}
type ToolInterruptEvent struct {
Event InterruptEventType
ToolNeedOAuth *ToolNeedOAuthInterruptEvent
}
type ToolNeedOAuthInterruptEvent struct {
Message string
}

View File

@@ -0,0 +1,497 @@
package plugin
import (
"encoding/json"
"net/url"
"strings"
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/bytedance/sonic"
)
type PluginManifest struct {
SchemaVersion string `json:"schema_version" yaml:"schema_version"`
NameForModel string `json:"name_for_model" yaml:"name_for_model"`
NameForHuman string `json:"name_for_human" yaml:"name_for_human"`
DescriptionForModel string `json:"description_for_model" yaml:"description_for_model"`
DescriptionForHuman string `json:"description_for_human" yaml:"description_for_human"`
Auth *AuthV2 `json:"auth" yaml:"auth"`
LogoURL string `json:"logo_url" yaml:"logo_url"`
API APIDesc `json:"api" yaml:"api"`
CommonParams map[HTTPParamLocation][]*api.CommonParamSchema `json:"common_params" yaml:"common_params"`
}
func (mf *PluginManifest) Copy() (*PluginManifest, error) {
if mf == nil {
return mf, nil
}
b, err := json.Marshal(mf)
if err != nil {
return nil, err
}
mf_ := &PluginManifest{}
err = json.Unmarshal(b, mf_)
if err != nil {
return nil, err
}
return mf_, err
}
func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
if mf == nil || mf.Auth == nil {
return mf, nil
}
mf_, err := mf.Copy()
if err != nil {
return nil, err
}
if mf_.Auth.Payload == "" {
return mf_, nil
}
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), utils.AuthSecretKey)
if err != nil {
return nil, err
}
mf_.Auth.Payload = payload_
return mf_, nil
}
func (mf *PluginManifest) Validate(skipAuthPayload bool) (err error) {
if mf.SchemaVersion != "v1" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid schema version '%s'", mf.SchemaVersion))
}
if mf.NameForModel == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"name for model is required"))
}
if mf.NameForHuman == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"name for human is required"))
}
if mf.DescriptionForModel == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"description for model is required"))
}
if mf.DescriptionForHuman == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"description for human is required"))
}
if mf.API.Type != PluginTypeOfCloud {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid api type '%s'", mf.API.Type))
}
err = mf.validateAuthInfo(skipAuthPayload)
if err != nil {
return err
}
for loc := range mf.CommonParams {
if loc != ParamInBody &&
loc != ParamInHeader &&
loc != ParamInQuery &&
loc != ParamInPath {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid location '%s' in common params", loc))
}
}
return nil
}
func (mf *PluginManifest) validateAuthInfo(skipAuthPayload bool) (err error) {
if mf.Auth == nil {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"auth is required"))
}
if mf.Auth.Payload != "" {
js := json.RawMessage{}
err = sonic.UnmarshalString(mf.Auth.Payload, &js)
if err != nil {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload"))
}
}
if mf.Auth.Type == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"auth type is required"))
}
if mf.Auth.Type != AuthzTypeOfNone &&
mf.Auth.Type != AuthzTypeOfOAuth &&
mf.Auth.Type != AuthzTypeOfService {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid auth type '%s'", mf.Auth.Type))
}
if mf.Auth.Type == AuthzTypeOfNone {
return nil
}
if mf.Auth.SubType == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"sub-auth type is required"))
}
switch mf.Auth.SubType {
case AuthzSubTypeOfServiceAPIToken:
err = mf.validateServiceToken(skipAuthPayload)
//case AuthzSubTypeOfOAuthClientCredentials:
// err = mf.validateClientCredentials()
case AuthzSubTypeOfOAuthAuthorizationCode:
err = mf.validateAuthCode(skipAuthPayload)
default:
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid sub-auth type '%s'", mf.Auth.SubType))
}
if err != nil {
return err
}
return nil
}
func (mf *PluginManifest) validateServiceToken(skipAuthPayload bool) (err error) {
if mf.Auth.AuthOfAPIToken == nil {
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfAPIToken)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload"))
}
}
if skipAuthPayload {
return nil
}
apiToken := mf.Auth.AuthOfAPIToken
if apiToken.ServiceToken == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"service token is required"))
}
if apiToken.Key == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"key is required"))
}
loc := HTTPParamLocation(strings.ToLower(string(apiToken.Location)))
if loc != ParamInHeader && loc != ParamInQuery {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid location '%s'", apiToken.Location))
}
return nil
}
func (mf *PluginManifest) validateClientCredentials() (err error) {
if mf.Auth.AuthOfOAuthClientCredentials == nil {
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthClientCredentials)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload"))
}
}
clientCredentials := mf.Auth.AuthOfOAuthClientCredentials
if clientCredentials.ClientID == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"client id is required"))
}
if clientCredentials.ClientSecret == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"client secret is required"))
}
if clientCredentials.TokenURL == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"token url is required"))
}
urlParse, err := url.Parse(clientCredentials.TokenURL)
if err != nil || urlParse.Hostname() == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid token url"))
}
if urlParse.Scheme != "https" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"token url scheme must be 'https'"))
}
return nil
}
func (mf *PluginManifest) validateAuthCode(skipAuthPayload bool) (err error) {
if mf.Auth.AuthOfOAuthAuthorizationCode == nil {
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthAuthorizationCode)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload"))
}
}
if skipAuthPayload {
return nil
}
authCode := mf.Auth.AuthOfOAuthAuthorizationCode
if authCode.ClientID == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"client id is required"))
}
if authCode.ClientSecret == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"client secret is required"))
}
if authCode.AuthorizationContentType != MediaTypeJson {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"authorization content type must be 'application/json'"))
}
if authCode.AuthorizationURL == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"token url is required"))
}
if authCode.ClientURL == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"client url is required"))
}
urlParse, err := url.Parse(authCode.AuthorizationURL)
if err != nil || urlParse.Hostname() == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid authorization url"))
}
if urlParse.Scheme != "https" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"authorization url scheme must be 'https'"))
}
urlParse, err = url.Parse(authCode.ClientURL)
if err != nil || urlParse.Hostname() == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid client url"))
}
if urlParse.Scheme != "https" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"client url scheme must be 'https'"))
}
return nil
}
type Auth struct {
Type string `json:"type" validate:"required"`
AuthorizationType string `json:"authorization_type,omitempty"`
ClientURL string `json:"client_url,omitempty"`
Scope string `json:"scope,omitempty"`
AuthorizationURL string `json:"authorization_url,omitempty"`
AuthorizationContentType string `json:"authorization_content_type,omitempty"`
Platform string `json:"platform,omitempty"`
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
Location string `json:"location,omitempty"`
Key string `json:"key,omitempty"`
ServiceToken string `json:"service_token,omitempty"`
SubType string `json:"sub_type"`
Payload string `json:"payload"`
}
type AuthV2 struct {
Type AuthzType `json:"type" yaml:"type"`
SubType AuthzSubType `json:"sub_type" yaml:"sub_type"`
Payload string `json:"payload" yaml:"payload"`
// service
AuthOfAPIToken *AuthOfAPIToken `json:"-"`
// oauth
AuthOfOAuthAuthorizationCode *OAuthAuthorizationCodeConfig `json:"-"`
AuthOfOAuthClientCredentials *OAuthClientCredentialsConfig `json:"-"`
}
func (au *AuthV2) UnmarshalJSON(data []byte) error {
auth := &Auth{} // 兼容老数据
err := json.Unmarshal(data, auth)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid plugin manifest json"))
}
au.Type = AuthzType(auth.Type)
au.SubType = AuthzSubType(auth.SubType)
if au.Type == "" {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"plugin auth type is required"))
}
if auth.Payload != "" {
payload_, err := utils.DecryptByAES(auth.Payload, utils.AuthSecretKey)
if err == nil {
auth.Payload = string(payload_)
}
}
switch au.Type {
case AuthzTypeOfNone:
case AuthzTypeOfOAuth:
err = au.unmarshalOAuth(auth)
case AuthzTypeOfService:
err = au.unmarshalService(auth)
default:
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid plugin auth type '%s'", au.Type))
}
if err != nil {
return err
}
return nil
}
func (au *AuthV2) unmarshalService(auth *Auth) (err error) {
if au.SubType == "" && au.Payload == "" { // 兼容老数据
au.SubType = AuthzSubTypeOfServiceAPIToken
}
var payload []byte
if au.SubType == AuthzSubTypeOfServiceAPIToken {
if len(auth.ServiceToken) > 0 {
au.AuthOfAPIToken = &AuthOfAPIToken{
Location: HTTPParamLocation(strings.ToLower(auth.Location)),
Key: auth.Key,
ServiceToken: auth.ServiceToken,
}
} else {
token := &AuthOfAPIToken{}
err = json.Unmarshal([]byte(auth.Payload), token)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload json"))
}
au.AuthOfAPIToken = token
}
payload, err = json.Marshal(au.AuthOfAPIToken)
if err != nil {
return err
}
}
if len(payload) == 0 {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid plugin sub-auth type '%s'", au.SubType))
}
au.Payload = string(payload)
return nil
}
func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) {
if au.SubType == "" { // 兼容老数据
au.SubType = AuthzSubTypeOfOAuthAuthorizationCode
}
var payload []byte
if au.SubType == AuthzSubTypeOfOAuthAuthorizationCode {
if len(auth.ClientSecret) > 0 {
au.AuthOfOAuthAuthorizationCode = &OAuthAuthorizationCodeConfig{
ClientID: auth.ClientID,
ClientSecret: auth.ClientSecret,
ClientURL: auth.ClientURL,
Scope: auth.Scope,
AuthorizationURL: auth.AuthorizationURL,
AuthorizationContentType: auth.AuthorizationContentType,
}
} else {
oauth := &OAuthAuthorizationCodeConfig{}
err = json.Unmarshal([]byte(auth.Payload), oauth)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload json"))
}
au.AuthOfOAuthAuthorizationCode = oauth
}
payload, err = json.Marshal(au.AuthOfOAuthAuthorizationCode)
if err != nil {
return err
}
}
if au.SubType == AuthzSubTypeOfOAuthClientCredentials {
oauth := &OAuthClientCredentialsConfig{}
err = json.Unmarshal([]byte(auth.Payload), oauth)
if err != nil {
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
"invalid auth payload json"))
}
au.AuthOfOAuthClientCredentials = oauth
payload, err = json.Marshal(au.AuthOfOAuthClientCredentials)
if err != nil {
return err
}
}
if len(payload) == 0 {
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
"invalid plugin sub-auth type '%s'", au.SubType))
}
au.Payload = string(payload)
return nil
}
type AuthOfAPIToken struct {
// Location is the location of the parameter.
// It can be "header" or "query".
Location HTTPParamLocation `json:"location"`
// Key is the name of the parameter.
Key string `json:"key"`
// ServiceToken is the simple authorization information for the service.
ServiceToken string `json:"service_token"`
}
type OAuthAuthorizationCodeConfig struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
// ClientURL is the URL of authorization endpoint.
ClientURL string `json:"client_url"`
// Scope is the scope of the authorization request.
// If multiple scopes are requested, they must be separated by a space.
Scope string `json:"scope,omitempty"`
// AuthorizationURL is the URL of token exchange endpoint.
AuthorizationURL string `json:"authorization_url"`
// AuthorizationContentType is the content type of the authorization request, and it must be "application/json".
AuthorizationContentType string `json:"authorization_content_type"`
}
type OAuthClientCredentialsConfig struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
TokenURL string `json:"token_url"`
}
type APIDesc struct {
Type PluginType `json:"type" validate:"required"`
}

View File

@@ -0,0 +1,566 @@
package plugin
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/getkin/kin-openapi/openapi3"
gonanoid "github.com/matoous/go-nanoid"
productAPI "github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
type ToolInfo struct {
ID int64
PluginID int64
CreatedAt int64
UpdatedAt int64
Version *string
ActivatedStatus *ActivatedStatus
DebugStatus *plugin_develop_common.APIDebugStatus
Method *string
SubURL *string
Operation *Openapi3Operation
}
func (t ToolInfo) GetName() string {
if t.Operation == nil {
return ""
}
return t.Operation.OperationID
}
func (t ToolInfo) GetDesc() string {
if t.Operation == nil {
return ""
}
return t.Operation.Summary
}
func (t ToolInfo) GetVersion() string {
return ptr.FromOrDefault(t.Version, "")
}
func (t ToolInfo) GetActivatedStatus() ActivatedStatus {
return ptr.FromOrDefault(t.ActivatedStatus, ActivateTool)
}
func (t ToolInfo) GetSubURL() string {
return ptr.FromOrDefault(t.SubURL, "")
}
func (t ToolInfo) GetMethod() string {
return strings.ToUpper(ptr.FromOrDefault(t.Method, ""))
}
func (t ToolInfo) GetDebugStatus() common.APIDebugStatus {
return ptr.FromOrDefault(t.DebugStatus, common.APIDebugStatus_DebugWaiting)
}
func (t ToolInfo) GetResponseOpenapiSchema() (*openapi3.Schema, error) {
op := t.Operation
if op == nil {
return nil, fmt.Errorf("operation is required")
}
resp, ok := op.Responses[strconv.Itoa(http.StatusOK)]
if !ok || resp == nil || resp.Value == nil || len(resp.Value.Content) == 0 {
return nil, fmt.Errorf("response status '200' not found")
}
mType, ok := resp.Value.Content[MediaTypeJson] // only support application/json
if !ok || mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
return nil, fmt.Errorf("media type '%s' not found in response", MediaTypeJson)
}
return mType.Schema.Value, nil
}
type paramMetaInfo struct {
name string
desc string
required bool
location string
}
func (t ToolInfo) ToRespAPIParameter() ([]*common.APIParameter, error) {
op := t.Operation
if op == nil {
return nil, fmt.Errorf("operation is required")
}
respSchema, err := t.GetResponseOpenapiSchema()
if err != nil {
return nil, err
}
params := make([]*common.APIParameter, 0, len(op.Parameters))
if len(respSchema.Properties) == 0 {
return params, nil
}
required := slices.ToMap(respSchema.Required, func(e string) (string, bool) {
return e, true
})
for subParamName, prop := range respSchema.Properties {
if prop == nil || prop.Value == nil {
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
}
paramMeta := paramMetaInfo{
name: subParamName,
desc: prop.Value.Description,
location: string(ParamInBody),
required: required[subParamName],
}
apiParam, err := toAPIParameter(paramMeta, prop.Value)
if err != nil {
return nil, err
}
params = append(params, apiParam)
}
return params, nil
}
func (t ToolInfo) ToReqAPIParameter() ([]*common.APIParameter, error) {
op := t.Operation
if op == nil {
return nil, fmt.Errorf("operation is required")
}
params := make([]*common.APIParameter, 0, len(op.Parameters))
for _, param := range op.Parameters {
if param == nil || param.Value == nil || param.Value.Schema == nil || param.Value.Schema.Value == nil {
return nil, fmt.Errorf("parameter schema is required")
}
paramVal := param.Value
schemaVal := paramVal.Schema.Value
if schemaVal.Type == openapi3.TypeObject {
return nil, fmt.Errorf("the type of parameter '%s' cannot be 'object'", paramVal.Name)
}
if schemaVal.Type == openapi3.TypeArray {
if paramVal.In == openapi3.ParameterInPath {
return nil, fmt.Errorf("the type of field '%s' cannot be 'array'", paramVal.Name)
}
if schemaVal.Items == nil || schemaVal.Items.Value == nil {
return nil, fmt.Errorf("the item schema of field '%s' is required", paramVal.Name)
}
item := schemaVal.Items.Value
if item.Type == openapi3.TypeObject || item.Type == openapi3.TypeArray {
return nil, fmt.Errorf("the item type of parameter '%s' cannot be 'object' or 'array'", paramVal.Name)
}
}
paramMeta := paramMetaInfo{
name: paramVal.Name,
desc: paramVal.Description,
location: paramVal.In,
required: paramVal.Required,
}
apiParam, err := toAPIParameter(paramMeta, schemaVal)
if err != nil {
return nil, err
}
params = append(params, apiParam)
}
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
return params, nil
}
for _, mType := range op.RequestBody.Value.Content {
if mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
return nil, fmt.Errorf("request body schema is required")
}
schemaVal := mType.Schema.Value
if len(schemaVal.Properties) == 0 {
continue
}
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
return e, true
})
for subParamName, prop := range schemaVal.Properties {
if prop == nil || prop.Value == nil {
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
}
paramMeta := paramMetaInfo{
name: subParamName,
desc: prop.Value.Description,
location: string(ParamInBody),
required: required[subParamName],
}
apiParam, err := toAPIParameter(paramMeta, prop.Value)
if err != nil {
return nil, err
}
params = append(params, apiParam)
}
break // 只取一种 MIME
}
return params, nil
}
func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIParameter, error) {
if sc == nil {
return nil, fmt.Errorf("schema is requred")
}
apiType, ok := ToThriftParamType(strings.ToLower(sc.Type))
if !ok {
return nil, fmt.Errorf("the type '%s' of filed '%s' is invalid", sc.Type, paramMeta.name)
}
location, ok := ToThriftHTTPParamLocation(HTTPParamLocation(paramMeta.location))
if !ok {
return nil, fmt.Errorf("the location '%s' of field '%s' is invalid", paramMeta.location, paramMeta.name)
}
apiParam := &common.APIParameter{
ID: gonanoid.MustID(10),
Name: paramMeta.name,
Desc: paramMeta.desc,
Type: apiType,
Location: location, // 使用父节点的值
IsRequired: paramMeta.required,
SubParameters: []*common.APIParameter{},
}
if sc.Default != nil {
apiParam.LocalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
}
if sc.Format != "" {
aType, ok := FormatToAssistType(sc.Format)
if !ok {
return nil, fmt.Errorf("the format '%s' of field '%s' is invalid", sc.Format, paramMeta.name)
}
_aType, ok := ToThriftAPIAssistType(aType)
if !ok {
return nil, fmt.Errorf("assist type '%s' of field '%s' is invalid", aType, paramMeta.name)
}
apiParam.AssistType = ptr.Of(_aType)
}
if v, ok := sc.Extensions[APISchemaExtendGlobalDisable]; ok {
if disable, ok := v.(bool); ok {
apiParam.GlobalDisable = disable
}
}
if v, ok := sc.Extensions[APISchemaExtendLocalDisable]; ok {
if disable, ok := v.(bool); ok {
apiParam.LocalDisable = disable
}
}
if v, ok := sc.Extensions[APISchemaExtendVariableRef]; ok {
if ref, ok := v.(string); ok {
apiParam.VariableRef = ptr.Of(ref)
apiParam.DefaultParamSource = ptr.Of(common.DefaultParamSource_Variable)
}
}
switch sc.Type {
case openapi3.TypeObject:
if len(sc.Properties) == 0 {
return apiParam, nil
}
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
return e, true
})
for subParamName, prop := range sc.Properties {
if prop == nil || prop.Value == nil {
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
}
subMeta := paramMetaInfo{
name: subParamName,
desc: prop.Value.Description,
required: required[subParamName],
location: paramMeta.location,
}
subParam, err := toAPIParameter(subMeta, prop.Value)
if err != nil {
return nil, err
}
apiParam.SubParameters = append(apiParam.SubParameters, subParam)
}
return apiParam, nil
case openapi3.TypeArray:
if sc.Items == nil || sc.Items.Value == nil {
return nil, fmt.Errorf("the item schema of field '%s' is required", paramMeta.name)
}
item := sc.Items.Value
subMeta := paramMetaInfo{
name: "[Array Item]",
desc: item.Description,
location: paramMeta.location,
required: paramMeta.required,
}
subParam, err := toAPIParameter(subMeta, item)
if err != nil {
return nil, err
}
apiParam.SubParameters = append(apiParam.SubParameters, subParam)
return apiParam, nil
}
return apiParam, nil
}
func (t ToolInfo) ToPluginParameters() ([]*common.PluginParameter, error) {
op := t.Operation
if op == nil {
return nil, fmt.Errorf("operation is required")
}
var params []*common.PluginParameter
for _, prop := range op.Parameters {
if prop == nil || prop.Value == nil || prop.Value.Schema == nil || prop.Value.Schema.Value == nil {
return nil, fmt.Errorf("parameter schema is required")
}
paramVal := prop.Value
schemaVal := paramVal.Schema.Value
if schemaVal.Type == openapi3.TypeObject {
return nil, fmt.Errorf("the type of parameter '%s' cannot be 'object'", paramVal.Name)
}
var arrayItemType string
if schemaVal.Type == openapi3.TypeArray {
if paramVal.In == openapi3.ParameterInPath {
return nil, fmt.Errorf("the type of field '%s' cannot be 'array'", paramVal.Name)
}
if schemaVal.Items == nil || schemaVal.Items.Value == nil {
return nil, fmt.Errorf("the item schema of field '%s' is required", paramVal.Name)
}
item := schemaVal.Items.Value
if item.Type == openapi3.TypeObject || item.Type == openapi3.TypeArray {
return nil, fmt.Errorf("the item type of parameter '%s' cannot be 'object' or 'array'", paramVal.Name)
}
arrayItemType = item.Type
}
if disabledParam(schemaVal) {
continue
}
var assistType *common.PluginParamTypeFormat
if v, ok := schemaVal.Extensions[APISchemaExtendAssistType]; ok {
_v, ok := v.(string)
if !ok {
continue
}
f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v))
if ok {
return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramVal.Name)
}
assistType = ptr.Of(f)
}
params = append(params, &common.PluginParameter{
Name: paramVal.Name,
Desc: paramVal.Description,
Required: paramVal.Required,
Type: schemaVal.Type,
SubType: arrayItemType,
Format: assistType,
SubParameters: []*common.PluginParameter{},
})
}
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
return params, nil
}
for _, mType := range op.RequestBody.Value.Content {
if mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
return nil, fmt.Errorf("request body schema is required")
}
schemaVal := mType.Schema.Value
if len(schemaVal.Properties) == 0 {
continue
}
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
return e, true
})
for subParamName, prop := range schemaVal.Properties {
if prop == nil || prop.Value == nil {
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
}
paramMeta := paramMetaInfo{
name: subParamName,
desc: prop.Value.Description,
required: required[subParamName],
}
paramInfo, err := toPluginParameter(paramMeta, prop.Value)
if err != nil {
return nil, err
}
if paramInfo != nil {
params = append(params, paramInfo)
}
}
break // 只取一种 MIME
}
return params, nil
}
func toPluginParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.PluginParameter, error) {
if sc == nil {
return nil, fmt.Errorf("schema is required")
}
if disabledParam(sc) {
return nil, nil
}
var assistType *common.PluginParamTypeFormat
if v, ok := sc.Extensions[APISchemaExtendAssistType]; ok {
if _v, ok := v.(string); ok {
f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v))
if !ok {
return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramMeta.name)
}
assistType = ptr.Of(f)
}
}
pluginParam := &common.PluginParameter{
Name: paramMeta.name,
Type: sc.Type,
Desc: paramMeta.desc,
Required: paramMeta.required,
Format: assistType,
SubParameters: []*common.PluginParameter{},
}
switch sc.Type {
case openapi3.TypeObject:
if len(sc.Properties) == 0 {
return pluginParam, nil
}
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
return e, true
})
for subParamName, prop := range sc.Properties {
if prop == nil || prop.Value == nil {
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
}
subMeta := paramMetaInfo{
name: subParamName,
desc: prop.Value.Description,
required: required[subParamName],
}
subParam, err := toPluginParameter(subMeta, prop.Value)
if err != nil {
return nil, err
}
pluginParam.SubParameters = append(pluginParam.SubParameters, subParam)
}
return pluginParam, nil
case openapi3.TypeArray:
if sc.Items == nil || sc.Items.Value == nil {
return nil, fmt.Errorf("the item schema of field '%s' is required", paramMeta.name)
}
item := sc.Items.Value
pluginParam.SubType = item.Type
if item.Type != openapi3.TypeObject {
return pluginParam, nil
}
subMeta := paramMetaInfo{
desc: item.Description,
required: paramMeta.required,
}
subParam, err := toPluginParameter(subMeta, item)
if err != nil {
return nil, err
}
pluginParam.SubParameters = append(pluginParam.SubParameters, subParam.SubParameters...)
return pluginParam, nil
}
return pluginParam, nil
}
func (t ToolInfo) ToToolParameters() ([]*productAPI.ToolParameter, error) {
apiParams, err := t.ToReqAPIParameter()
if err != nil {
return nil, err
}
var toToolParams func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error)
toToolParams = func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error) {
params := make([]*productAPI.ToolParameter, 0, len(apiParams))
for _, apiParam := range apiParams {
typ, _ := ToOpenapiParamType(apiParam.Type)
toolParam := &productAPI.ToolParameter{
Name: apiParam.Name,
Description: apiParam.Desc,
Type: typ,
IsRequired: apiParam.IsRequired,
SubParameter: []*productAPI.ToolParameter{},
}
if len(apiParam.SubParameters) > 0 {
subParams, err := toToolParams(apiParam.SubParameters)
if err != nil {
return nil, err
}
toolParam.SubParameter = append(toolParam.SubParameter, subParams...)
}
params = append(params, toolParam)
}
return params, nil
}
return toToolParams(apiParams)
}