fix(backend_plugin): plugin common header and parameter default value (#181)
This commit is contained in:
parent
8137b0aee5
commit
53345f58c2
|
|
@ -263,6 +263,7 @@ func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIPa
|
|||
}
|
||||
|
||||
if sc.Default != nil {
|
||||
apiParam.GlobalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
|
||||
apiParam.LocalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,9 +20,8 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
|
|
@ -33,14 +32,6 @@ import (
|
|||
func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (*openapi3.Operation, error) {
|
||||
op := &openapi3.Operation{}
|
||||
|
||||
if reqParams != nil && len(reqParams) == 0 {
|
||||
op.Parameters = []*openapi3.ParameterRef{}
|
||||
op.RequestBody = entity.DefaultOpenapi3RequestBody()
|
||||
}
|
||||
if respParams != nil && len(respParams) == 0 {
|
||||
op.Responses = entity.DefaultOpenapi3Responses()
|
||||
}
|
||||
|
||||
hasSetReqBody := false
|
||||
hasSetParams := false
|
||||
|
||||
|
|
@ -136,6 +127,16 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (
|
|||
}
|
||||
}
|
||||
|
||||
if op.Parameters == nil {
|
||||
op.Parameters = []*openapi3.ParameterRef{}
|
||||
}
|
||||
if op.RequestBody == nil {
|
||||
op.RequestBody = entity.DefaultOpenapi3RequestBody()
|
||||
}
|
||||
if op.Responses == nil {
|
||||
op.Responses = entity.DefaultOpenapi3Responses()
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -157,7 +157,6 @@ func NewDefaultPluginManifest() *PluginManifest {
|
|||
Value: "Coze/1.0",
|
||||
},
|
||||
},
|
||||
model.ParamInPath: {},
|
||||
model.ParamInQuery: {},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
|
|
@ -479,11 +480,6 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
|
|||
return nil, err
|
||||
}
|
||||
|
||||
requestStr, err := sonic.MarshalString(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := t.buildHTTPRequest(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -504,18 +500,29 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
|
|||
}
|
||||
|
||||
var reqBodyBytes []byte
|
||||
if httpReq.Body != nil {
|
||||
reqBodyBytes, err = io.ReadAll(httpReq.Body)
|
||||
if httpReq.GetBody != nil {
|
||||
reqBody, err := httpReq.GetBody()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reqBody.Close()
|
||||
|
||||
reqBodyBytes, err = io.ReadAll(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
requestStr, err := genRequestString(httpReq, reqBodyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restyReq := t.svc.httpCli.NewRequest()
|
||||
restyReq.Header = httpReq.Header
|
||||
restyReq.Method = httpReq.Method
|
||||
restyReq.URL = httpReq.URL.String()
|
||||
if len(reqBodyBytes) > 0 {
|
||||
if reqBodyBytes != nil {
|
||||
restyReq.SetBody(reqBodyBytes)
|
||||
}
|
||||
restyReq.SetContext(ctx)
|
||||
|
|
@ -559,6 +566,46 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
|
|||
}, nil
|
||||
}
|
||||
|
||||
func genRequestString(req *http.Request, body []byte) (string, error) {
|
||||
type Request struct {
|
||||
Path string `json:"path"`
|
||||
Header map[string]string `json:"header"`
|
||||
Query map[string]string `json:"query"`
|
||||
Body *[]byte `json:"body"`
|
||||
}
|
||||
|
||||
req_ := &Request{
|
||||
Path: req.URL.Path,
|
||||
Header: map[string]string{},
|
||||
Query: map[string]string{},
|
||||
}
|
||||
|
||||
if len(req.Header) > 0 {
|
||||
for k, v := range req.Header {
|
||||
req_.Header[k] = v[0]
|
||||
}
|
||||
}
|
||||
if len(req.URL.Query()) > 0 {
|
||||
for k, v := range req.URL.Query() {
|
||||
req_.Query[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
requestStr, err := sonic.MarshalString(req_)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err)
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return requestStr, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) preprocessArgumentsInJson(ctx context.Context, argumentsInJson string) (args map[string]any, err error) {
|
||||
args, err = t.prepareArguments(ctx, argumentsInJson)
|
||||
if err != nil {
|
||||
|
|
@ -653,23 +700,13 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
|
|||
return nil, err
|
||||
}
|
||||
|
||||
reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL)
|
||||
commonParams := t.plugin.Manifest.CommonParams
|
||||
|
||||
reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL, commonParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := locArgs.buildHTTPRequestHeader(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header = header
|
||||
|
||||
bodyArgs := map[string]any{}
|
||||
for k, v := range argMaps {
|
||||
if _, ok := locArgs.header[k]; ok {
|
||||
|
|
@ -684,13 +721,27 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
|
|||
bodyArgs[k] = v
|
||||
}
|
||||
|
||||
bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs)
|
||||
commonBody := commonParams[model.ParamInBody]
|
||||
bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs, commonBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commonHeader := commonParams[model.ParamInHeader]
|
||||
header, err := locArgs.buildHTTPRequestHeader(ctx, commonHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header = header
|
||||
|
||||
if len(bodyBytes) > 0 {
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
httpReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
return httpReq, nil
|
||||
|
|
@ -698,13 +749,6 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
|
|||
|
||||
func (t *toolExecutor) prepareArguments(_ context.Context, argumentsInJson string) (map[string]any, error) {
|
||||
args := map[string]any{}
|
||||
for loc, params := range t.plugin.Manifest.CommonParams {
|
||||
for _, p := range params {
|
||||
if loc != model.ParamInBody {
|
||||
args[p.Name] = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
decoder := sonic.ConfigDefault.NewDecoder(bytes.NewBufferString(argumentsInJson))
|
||||
decoder.UseNumber()
|
||||
|
|
@ -1175,7 +1219,9 @@ type valueWithSchema struct {
|
|||
paramSchema *openapi3.Parameter
|
||||
}
|
||||
|
||||
func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string) (reqURL *url.URL, err error) {
|
||||
func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string,
|
||||
commonParams map[model.HTTPParamLocation][]*common.CommonParamSchema) (reqURL *url.URL, err error) {
|
||||
|
||||
if len(l.path) > 0 {
|
||||
for k, v := range l.path {
|
||||
vStr, err := encoder.EncodeParameter(v.paramSchema, v.argValue)
|
||||
|
|
@ -1186,9 +1232,8 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string
|
|||
}
|
||||
}
|
||||
|
||||
encodeQuery := ""
|
||||
query := url.Values{}
|
||||
if len(l.query) > 0 {
|
||||
query := url.Values{}
|
||||
for k, val := range l.query {
|
||||
switch v := val.argValue.(type) {
|
||||
case []any:
|
||||
|
|
@ -1199,10 +1244,18 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string
|
|||
query.Add(k, encoder.MustString(v))
|
||||
}
|
||||
}
|
||||
|
||||
encodeQuery = query.Encode()
|
||||
}
|
||||
|
||||
commonQuery := commonParams[model.ParamInQuery]
|
||||
for _, v := range commonQuery {
|
||||
if _, ok := l.query[v.Name]; ok {
|
||||
continue
|
||||
}
|
||||
query.Add(v.Name, v.Value)
|
||||
}
|
||||
|
||||
encodeQuery := query.Encode()
|
||||
|
||||
reqURL, err = url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -1217,7 +1270,7 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string
|
|||
return reqURL, nil
|
||||
}
|
||||
|
||||
func (l *locationArguments) buildHTTPRequestHeader(_ context.Context) (http.Header, error) {
|
||||
func (l *locationArguments) buildHTTPRequestHeader(_ context.Context, commonHeaders []*common.CommonParamSchema) (http.Header, error) {
|
||||
header := http.Header{}
|
||||
if len(l.header) > 0 {
|
||||
for k, v := range l.header {
|
||||
|
|
@ -1232,44 +1285,64 @@ func (l *locationArguments) buildHTTPRequestHeader(_ context.Context) (http.Head
|
|||
}
|
||||
}
|
||||
|
||||
for _, h := range commonHeaders {
|
||||
if header.Get(h.Name) != "" {
|
||||
continue
|
||||
}
|
||||
header.Add(h.Name, h.Value)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any) (body []byte, contentType string, err error) {
|
||||
func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any,
|
||||
commonBody []*common.CommonParamSchema) (body []byte, contentType string, err error) {
|
||||
|
||||
var bodyMap map[string]any
|
||||
|
||||
contentType, bodySchema := t.getReqBodySchema(op)
|
||||
if bodySchema == nil || bodySchema.Value == nil {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
if len(bodySchema.Value.Properties) == 0 {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
bodyMap, err := t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
for paramName, prop := range bodySchema.Value.Properties {
|
||||
value, ok := bodyMap[paramName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_value, err := encoder.TryFixValueType(paramName, prop, value)
|
||||
if bodySchema != nil && len(bodySchema.Value.Properties) > 0 {
|
||||
bodyMap, err = t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
bodyMap[paramName] = _value
|
||||
for paramName, prop := range bodySchema.Value.Properties {
|
||||
value, ok := bodyMap[paramName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_value, err := encoder.TryFixValueType(paramName, prop, value)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
bodyMap[paramName] = _value
|
||||
}
|
||||
|
||||
body, err = encoder.EncodeBodyWithContentType(contentType, bodyMap)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
reqBodyStr, err := encoder.EncodeBodyWithContentType(contentType, bodyMap)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err)
|
||||
commonBody_ := make([]*common.CommonParamSchema, 0, len(commonBody))
|
||||
for _, v := range commonBody {
|
||||
if _, ok := bodyMap[v.Name]; ok {
|
||||
continue
|
||||
}
|
||||
commonBody_ = append(commonBody_, v)
|
||||
}
|
||||
|
||||
return reqBodyStr, contentType, nil
|
||||
for _, v := range commonBody_ {
|
||||
body, err = sjson.SetRawBytes(body, v.Name, []byte(v.Value))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("[buildRequestBody] SetRawBytes failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return body, contentType, nil
|
||||
}
|
||||
|
||||
func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *openapi3.Schema, vals map[string]any) (newVals map[string]any, err error) {
|
||||
|
|
@ -1327,7 +1400,7 @@ func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *op
|
|||
}
|
||||
|
||||
func (t *toolExecutor) getReqBodySchema(op *model.Openapi3Operation) (string, *openapi3.SchemaRef) {
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
if op.RequestBody == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
. "github.com/bytedance/mockey"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenRequestString(t *testing.T) {
|
||||
PatchConvey("", t, func() {
|
||||
requestStr, err := genRequestString(&http.Request{
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/test"},
|
||||
}, []byte(`{"a": 1}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`, requestStr)
|
||||
})
|
||||
|
||||
PatchConvey("", t, func() {
|
||||
var body []byte
|
||||
requestStr, err := genRequestString(&http.Request{
|
||||
URL: &url.URL{Path: "/test"},
|
||||
}, body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"header":null,"query":null,"path":"/test","body":null}`, requestStr)
|
||||
})
|
||||
}
|
||||
|
|
@ -46,6 +46,7 @@ import (
|
|||
|
||||
func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDraftPluginRequest) (pluginID int64, err error) {
|
||||
mf := entity.NewDefaultPluginManifest()
|
||||
mf.CommonParams = map[model.HTTPParamLocation][]*plugin_develop_common.CommonParamSchema{}
|
||||
mf.NameForHuman = req.Name
|
||||
mf.NameForModel = req.Name
|
||||
mf.DescriptionForHuman = req.Desc
|
||||
|
|
@ -65,11 +66,11 @@ func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDr
|
|||
return 0, fmt.Errorf("invalid location '%s'", loc.String())
|
||||
}
|
||||
for _, param := range params {
|
||||
mParams := mf.CommonParams[location]
|
||||
mParams = append(mParams, &plugin_develop_common.CommonParamSchema{
|
||||
Name: param.Name,
|
||||
Value: param.Value,
|
||||
})
|
||||
mf.CommonParams[location] = append(mf.CommonParams[location],
|
||||
&plugin_develop_common.CommonParamSchema{
|
||||
Name: param.Name,
|
||||
Value: param.Value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Contex
|
|||
meta := ci.Meta
|
||||
info, exist, err := p.oauthRepo.GetAuthorizationCode(ctx, ci.Meta)
|
||||
if err != nil {
|
||||
return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%p",
|
||||
return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%t",
|
||||
meta.UserID, meta.PluginID, meta.IsDraft)
|
||||
}
|
||||
if !exist {
|
||||
|
|
|
|||
Loading…
Reference in New Issue