diff --git a/backend/api/model/crossdomain/plugin/toolinfo.go b/backend/api/model/crossdomain/plugin/toolinfo.go index 8cdc65e4..8377c6cd 100644 --- a/backend/api/model/crossdomain/plugin/toolinfo.go +++ b/backend/api/model/crossdomain/plugin/toolinfo.go @@ -263,6 +263,7 @@ func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIPa } if sc.Default != nil { + apiParam.GlobalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default)) apiParam.LocalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default)) } diff --git a/backend/application/base/pluginutil/api.go b/backend/application/base/pluginutil/api.go index 051d17a1..79b42e80 100644 --- a/backend/application/base/pluginutil/api.go +++ b/backend/application/base/pluginutil/api.go @@ -20,9 +20,8 @@ import ( "net/http" "strconv" - "github.com/getkin/kin-openapi/openapi3" - "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/getkin/kin-openapi/openapi3" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" @@ -33,14 +32,6 @@ import ( func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (*openapi3.Operation, error) { op := &openapi3.Operation{} - if reqParams != nil && len(reqParams) == 0 { - op.Parameters = []*openapi3.ParameterRef{} - op.RequestBody = entity.DefaultOpenapi3RequestBody() - } - if respParams != nil && len(respParams) == 0 { - op.Responses = entity.DefaultOpenapi3Responses() - } - hasSetReqBody := false hasSetParams := false @@ -136,6 +127,16 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) ( } } + if op.Parameters == nil { + op.Parameters = []*openapi3.ParameterRef{} + } + if op.RequestBody == nil { + op.RequestBody = entity.DefaultOpenapi3RequestBody() + } + if op.Responses == nil { + op.Responses = entity.DefaultOpenapi3Responses() + } + return op, nil } diff --git a/backend/domain/plugin/entity/plugin.go b/backend/domain/plugin/entity/plugin.go index 665b7ad4..c774104a 100644 --- a/backend/domain/plugin/entity/plugin.go +++ b/backend/domain/plugin/entity/plugin.go @@ -157,7 +157,6 @@ func NewDefaultPluginManifest() *PluginManifest { Value: "Coze/1.0", }, }, - model.ParamInPath: {}, model.ParamInQuery: {}, }, } diff --git a/backend/domain/plugin/service/exec_tool.go b/backend/domain/plugin/service/exec_tool.go index 8aed1f9b..75553e1e 100644 --- a/backend/domain/plugin/service/exec_tool.go +++ b/backend/domain/plugin/service/exec_tool.go @@ -29,6 +29,7 @@ import ( "github.com/bytedance/sonic" "github.com/getkin/kin-openapi/openapi3" + "github.com/tidwall/sjson" einoCompose "github.com/cloudwego/eino/compose" @@ -479,11 +480,6 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res return nil, err } - requestStr, err := sonic.MarshalString(args) - if err != nil { - return nil, err - } - httpReq, err := t.buildHTTPRequest(ctx, args) if err != nil { return nil, err @@ -504,18 +500,29 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res } var reqBodyBytes []byte - if httpReq.Body != nil { - reqBodyBytes, err = io.ReadAll(httpReq.Body) + if httpReq.GetBody != nil { + reqBody, err := httpReq.GetBody() if err != nil { return nil, err } + defer reqBody.Close() + + reqBodyBytes, err = io.ReadAll(reqBody) + if err != nil { + return nil, err + } + } + + requestStr, err := genRequestString(httpReq, reqBodyBytes) + if err != nil { + return nil, err } restyReq := t.svc.httpCli.NewRequest() restyReq.Header = httpReq.Header restyReq.Method = httpReq.Method restyReq.URL = httpReq.URL.String() - if len(reqBodyBytes) > 0 { + if reqBodyBytes != nil { restyReq.SetBody(reqBodyBytes) } restyReq.SetContext(ctx) @@ -559,6 +566,46 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res }, nil } +func genRequestString(req *http.Request, body []byte) (string, error) { + type Request struct { + Path string `json:"path"` + Header map[string]string `json:"header"` + Query map[string]string `json:"query"` + Body *[]byte `json:"body"` + } + + req_ := &Request{ + Path: req.URL.Path, + Header: map[string]string{}, + Query: map[string]string{}, + } + + if len(req.Header) > 0 { + for k, v := range req.Header { + req_.Header[k] = v[0] + } + } + if len(req.URL.Query()) > 0 { + for k, v := range req.URL.Query() { + req_.Query[k] = v[0] + } + } + + requestStr, err := sonic.MarshalString(req_) + if err != nil { + return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err) + } + + if body != nil { + requestStr, err = sjson.SetRaw(requestStr, "body", string(body)) + if err != nil { + return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err) + } + } + + return requestStr, nil +} + func (t *toolExecutor) preprocessArgumentsInJson(ctx context.Context, argumentsInJson string) (args map[string]any, err error) { args, err = t.prepareArguments(ctx, argumentsInJson) if err != nil { @@ -653,23 +700,13 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string] return nil, err } - reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL) + commonParams := t.plugin.Manifest.CommonParams + + reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL, commonParams) if err != nil { return nil, err } - httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), nil) - if err != nil { - return nil, err - } - - header, err := locArgs.buildHTTPRequestHeader(ctx) - if err != nil { - return nil, err - } - - httpReq.Header = header - bodyArgs := map[string]any{} for k, v := range argMaps { if _, ok := locArgs.header[k]; ok { @@ -684,13 +721,27 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string] bodyArgs[k] = v } - bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs) + commonBody := commonParams[model.ParamInBody] + bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs, commonBody) if err != nil { return nil, err } + + httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, err + } + + commonHeader := commonParams[model.ParamInHeader] + header, err := locArgs.buildHTTPRequestHeader(ctx, commonHeader) + if err != nil { + return nil, err + } + + httpReq.Header = header + if len(bodyBytes) > 0 { httpReq.Header.Set("Content-Type", contentType) - httpReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } return httpReq, nil @@ -698,13 +749,6 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string] func (t *toolExecutor) prepareArguments(_ context.Context, argumentsInJson string) (map[string]any, error) { args := map[string]any{} - for loc, params := range t.plugin.Manifest.CommonParams { - for _, p := range params { - if loc != model.ParamInBody { - args[p.Name] = p.Value - } - } - } decoder := sonic.ConfigDefault.NewDecoder(bytes.NewBufferString(argumentsInJson)) decoder.UseNumber() @@ -1175,7 +1219,9 @@ type valueWithSchema struct { paramSchema *openapi3.Parameter } -func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string) (reqURL *url.URL, err error) { +func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string, + commonParams map[model.HTTPParamLocation][]*common.CommonParamSchema) (reqURL *url.URL, err error) { + if len(l.path) > 0 { for k, v := range l.path { vStr, err := encoder.EncodeParameter(v.paramSchema, v.argValue) @@ -1186,9 +1232,8 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string } } - encodeQuery := "" + query := url.Values{} if len(l.query) > 0 { - query := url.Values{} for k, val := range l.query { switch v := val.argValue.(type) { case []any: @@ -1199,10 +1244,18 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string query.Add(k, encoder.MustString(v)) } } - - encodeQuery = query.Encode() } + commonQuery := commonParams[model.ParamInQuery] + for _, v := range commonQuery { + if _, ok := l.query[v.Name]; ok { + continue + } + query.Add(v.Name, v.Value) + } + + encodeQuery := query.Encode() + reqURL, err = url.Parse(rawURL) if err != nil { return nil, err @@ -1217,7 +1270,7 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string return reqURL, nil } -func (l *locationArguments) buildHTTPRequestHeader(_ context.Context) (http.Header, error) { +func (l *locationArguments) buildHTTPRequestHeader(_ context.Context, commonHeaders []*common.CommonParamSchema) (http.Header, error) { header := http.Header{} if len(l.header) > 0 { for k, v := range l.header { @@ -1232,44 +1285,64 @@ func (l *locationArguments) buildHTTPRequestHeader(_ context.Context) (http.Head } } + for _, h := range commonHeaders { + if header.Get(h.Name) != "" { + continue + } + header.Add(h.Name, h.Value) + } + return header, nil } -func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any) (body []byte, contentType string, err error) { +func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any, + commonBody []*common.CommonParamSchema) (body []byte, contentType string, err error) { + + var bodyMap map[string]any + contentType, bodySchema := t.getReqBodySchema(op) - if bodySchema == nil || bodySchema.Value == nil { - return nil, "", nil - } - - if len(bodySchema.Value.Properties) == 0 { - return nil, "", nil - } - - bodyMap, err := t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs) - if err != nil { - return nil, "", err - } - - for paramName, prop := range bodySchema.Value.Properties { - value, ok := bodyMap[paramName] - if !ok { - continue - } - - _value, err := encoder.TryFixValueType(paramName, prop, value) + if bodySchema != nil && len(bodySchema.Value.Properties) > 0 { + bodyMap, err = t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs) if err != nil { return nil, "", err } - bodyMap[paramName] = _value + for paramName, prop := range bodySchema.Value.Properties { + value, ok := bodyMap[paramName] + if !ok { + continue + } + + _value, err := encoder.TryFixValueType(paramName, prop, value) + if err != nil { + return nil, "", err + } + + bodyMap[paramName] = _value + } + + body, err = encoder.EncodeBodyWithContentType(contentType, bodyMap) + if err != nil { + return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err) + } } - reqBodyStr, err := encoder.EncodeBodyWithContentType(contentType, bodyMap) - if err != nil { - return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err) + commonBody_ := make([]*common.CommonParamSchema, 0, len(commonBody)) + for _, v := range commonBody { + if _, ok := bodyMap[v.Name]; ok { + continue + } + commonBody_ = append(commonBody_, v) } - return reqBodyStr, contentType, nil + for _, v := range commonBody_ { + body, err = sjson.SetRawBytes(body, v.Name, []byte(v.Value)) + if err != nil { + return nil, "", fmt.Errorf("[buildRequestBody] SetRawBytes failed, err=%v", err) + } + } + + return body, contentType, nil } func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *openapi3.Schema, vals map[string]any) (newVals map[string]any, err error) { @@ -1327,7 +1400,7 @@ func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *op } func (t *toolExecutor) getReqBodySchema(op *model.Openapi3Operation) (string, *openapi3.SchemaRef) { - if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 { + if op.RequestBody == nil || len(op.RequestBody.Value.Content) == 0 { return "", nil } diff --git a/backend/domain/plugin/service/exec_tool_test.go b/backend/domain/plugin/service/exec_tool_test.go new file mode 100644 index 00000000..1b1d28b7 --- /dev/null +++ b/backend/domain/plugin/service/exec_tool_test.go @@ -0,0 +1,49 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "net/http" + "net/url" + "testing" + + . "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" +) + +func TestGenRequestString(t *testing.T) { + PatchConvey("", t, func() { + requestStr, err := genRequestString(&http.Request{ + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + Method: http.MethodPost, + URL: &url.URL{Path: "/test"}, + }, []byte(`{"a": 1}`)) + assert.NoError(t, err) + assert.Equal(t, `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`, requestStr) + }) + + PatchConvey("", t, func() { + var body []byte + requestStr, err := genRequestString(&http.Request{ + URL: &url.URL{Path: "/test"}, + }, body) + assert.NoError(t, err) + assert.Equal(t, `{"header":null,"query":null,"path":"/test","body":null}`, requestStr) + }) +} diff --git a/backend/domain/plugin/service/plugin_draft.go b/backend/domain/plugin/service/plugin_draft.go index 5c56cf9b..3b13d1b7 100644 --- a/backend/domain/plugin/service/plugin_draft.go +++ b/backend/domain/plugin/service/plugin_draft.go @@ -46,6 +46,7 @@ import ( func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDraftPluginRequest) (pluginID int64, err error) { mf := entity.NewDefaultPluginManifest() + mf.CommonParams = map[model.HTTPParamLocation][]*plugin_develop_common.CommonParamSchema{} mf.NameForHuman = req.Name mf.NameForModel = req.Name mf.DescriptionForHuman = req.Desc @@ -65,11 +66,11 @@ func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDr return 0, fmt.Errorf("invalid location '%s'", loc.String()) } for _, param := range params { - mParams := mf.CommonParams[location] - mParams = append(mParams, &plugin_develop_common.CommonParamSchema{ - Name: param.Name, - Value: param.Value, - }) + mf.CommonParams[location] = append(mf.CommonParams[location], + &plugin_develop_common.CommonParamSchema{ + Name: param.Name, + Value: param.Value, + }) } } diff --git a/backend/domain/plugin/service/plugin_oauth.go b/backend/domain/plugin/service/plugin_oauth.go index b4ceae38..a29f3c02 100644 --- a/backend/domain/plugin/service/plugin_oauth.go +++ b/backend/domain/plugin/service/plugin_oauth.go @@ -194,7 +194,7 @@ func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Contex meta := ci.Meta info, exist, err := p.oauthRepo.GetAuthorizationCode(ctx, ci.Meta) if err != nil { - return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%p", + return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%t", meta.UserID, meta.PluginID, meta.IsDraft) } if !exist {