// coze-studio/backend/domain/plugin/internal/encoder/req_encode.go
//
// 384 lines
// 9.4 KiB
// Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package encoder
import (
"encoding/json"
"fmt"
"net/url"
"strconv"
"github.com/bytedance/sonic"
"github.com/getkin/kin-openapi/openapi3"
"github.com/shopspring/decimal"
"gopkg.in/yaml.v3"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
)
// EncodeBodyWithContentType serializes body with the encoder that matches
// contentType.
//
// plugin.MediaTypeJson and plugin.MediaTypeProblemJson are handled by
// jsonBodyEncoder, plugin.MediaTypeFormURLEncoded by urlencodedBodyEncoder,
// and plugin.MediaTypeYaml / plugin.MediaTypeXYaml by yamlBodyEncoder. Any
// other content type produces an error and a nil payload.
func EncodeBodyWithContentType(contentType string, body map[string]any) ([]byte, error) {
	var encode func(map[string]any) ([]byte, error)

	switch contentType {
	case plugin.MediaTypeJson, plugin.MediaTypeProblemJson:
		encode = jsonBodyEncoder
	case plugin.MediaTypeFormURLEncoded:
		encode = urlencodedBodyEncoder
	case plugin.MediaTypeYaml, plugin.MediaTypeXYaml:
		encode = yamlBodyEncoder
	default:
		return nil, fmt.Errorf("[EncodeBodyWithContentType] unsupported contentType=%s", contentType)
	}

	return encode(body)
}
// jsonBodyEncoder marshals body to JSON using sonic.Marshal.
//
// On failure it returns a nil slice and an error that wraps the underlying
// marshal error (via %w), so callers can inspect it with errors.Is/errors.As.
func jsonBodyEncoder(body map[string]any) ([]byte, error) {
	b, err := sonic.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("[jsonBodyEncoder] failed to marshal body, err=%w", err)
	}
	return b, nil
}
// yamlBodyEncoder marshals body to YAML using yaml.Marshal.
//
// On failure it returns a nil slice and an error that wraps the underlying
// marshal error (via %w), so callers can inspect it with errors.Is/errors.As.
func yamlBodyEncoder(body map[string]any) ([]byte, error) {
	b, err := yaml.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("[yamlBodyEncoder] failed to marshal body, err=%w", err)
	}
	return b, nil
}
func urlencodedBodyEncoder(body map[string]any) ([]byte, error) {
objectStr := ""
res := url.Values{}
sm := &openapi3.SerializationMethod{
Style: openapi3.SerializationForm,
Explode: true,
}
for k, value := range body {
switch val := value.(type) {
case map[string]any:
vStr, err := encodeObjectParam(sm, k, val)
if err != nil {
return nil, err
}
if len(objectStr) > 0 {
vStr = "&" + vStr
}
objectStr += vStr
case []any:
vStr, err := encodeArrayParam(sm, k, val)
if err != nil {
return nil, err
}
if len(objectStr) > 0 {
vStr = "&" + vStr
}
objectStr += vStr
case string:
res.Add(k, val)
default:
res.Add(k, MustString(val))
}
}
if len(objectStr) > 0 {
return []byte(res.Encode() + "&" + url.QueryEscape(objectStr)), nil
}
return []byte(res.Encode()), nil
}
// EncodeParameter serializes value for param according to the serialization
// method reported by param.SerializationMethod().
//
// A map[string]any is encoded with encodeObjectParam, a []any with
// encodeArrayParam, and every other value with encodePrimitiveParam. An error
// from SerializationMethod is returned unchanged.
func EncodeParameter(param *openapi3.Parameter, value any) (string, error) {
	method, err := param.SerializationMethod()
	if err != nil {
		return "", err
	}

	if obj, isObject := value.(map[string]any); isObject {
		return encodeObjectParam(method, param.Name, obj)
	}
	if arr, isArray := value.([]any); isArray {
		return encodeArrayParam(method, param.Name, arr)
	}
	return encodePrimitiveParam(method, param.Name, value)
}
// encodePrimitiveParam renders a single non-array, non-object value for the
// given serialization style.
//
//   - simple: the stringified value with no prefix
//   - label:  "." followed by the value
//   - matrix: ";<escaped name>=" followed by the value
//   - form:   "<escaped name>=<escaped value>"
//
// Only the form style query-escapes the value itself. Any other style results
// in an error.
func encodePrimitiveParam(sm *openapi3.SerializationMethod, paramName string, val any) (string, error) {
	switch sm.Style {
	case openapi3.SerializationForm:
		return url.QueryEscape(paramName) + "=" + url.QueryEscape(MustString(val)), nil
	case openapi3.SerializationSimple:
		return MustString(val), nil
	case openapi3.SerializationLabel:
		return "." + MustString(val), nil
	case openapi3.SerializationMatrix:
		return ";" + url.QueryEscape(paramName) + "=" + MustString(val), nil
	}

	return "", fmt.Errorf("invalid serialization method: style=%q, explode=%v", sm.Style, sm.Explode)
}
// encodeArrayParam renders arrVal for the parameter paramName according to the
// serialization style and explode flag in sm.
//
// The result is an optional prefix followed by the elements (stringified with
// MustString) joined by a style-specific delimiter:
//
//   - matrix:          ";name=a,b"  (explode: ";name=a;name=b")
//   - label:           ".a,b"       (explode: ".a.b")
//   - form:            "name=a,b"   (explode: "name=a&name=b")
//   - simple:          "a,b"
//   - spaceDelimited:  "a b"        (non-exploded only)
//   - pipeDelimited:   "a|b"        (non-exploded only)
//
// Names and values are not percent-escaped here. Unsupported style/explode
// combinations return an error.
func encodeArrayParam(sm *openapi3.SerializationMethod, paramName string, arrVal []any) (string, error) {
	var prefix, delim string
	switch {
	case sm.Style == openapi3.SerializationMatrix && !sm.Explode:
		prefix = ";" + paramName + "="
		delim = ","
	case sm.Style == openapi3.SerializationMatrix && sm.Explode:
		prefix = ";" + paramName + "="
		delim = ";" + paramName + "="
	case sm.Style == openapi3.SerializationLabel && !sm.Explode:
		prefix = "."
		delim = ","
	case sm.Style == openapi3.SerializationLabel && sm.Explode:
		prefix = "."
		delim = "."
	case sm.Style == openapi3.SerializationForm && sm.Explode:
		prefix = paramName + "="
		delim = "&" + paramName + "="
	case sm.Style == openapi3.SerializationForm && !sm.Explode:
		prefix = paramName + "="
		delim = ","
	case sm.Style == openapi3.SerializationSimple:
		delim = ","
	case sm.Style == openapi3.SerializationSpaceDelimited && !sm.Explode:
		// Space-delimited arrays are separated by a space, matching the
		// delimiter encodeObjectParam uses for the same style.
		delim = " "
	case sm.Style == openapi3.SerializationPipeDelimited && !sm.Explode:
		delim = "|"
	default:
		return "", fmt.Errorf("invalid serialization method: style=%q, explode=%v", sm.Style, sm.Explode)
	}

	res := prefix
	for i, val := range arrVal {
		if i > 0 {
			res += delim
		}
		res += MustString(val)
	}

	return res, nil
}
// encodeObjectParam renders mapVal for the parameter paramName according to
// the serialization style and explode flag in sm.
//
// Each property is written as keyPrefix + key + valueDelim + value, properties
// are separated by propsDelim, and the whole string is preceded by prefix:
//
//   - simple:         "k1,v1,k2,v2"        (explode: "k1=v1,k2=v2")
//   - label:          ".k1.v1.k2.v2"       (explode: ".k1=v1.k2=v2")
//   - matrix:         ";name=k1,v1,k2,v2"  (explode: ";k1=v1;k2=v2")
//   - form:           "name=k1,v1,k2,v2"   (explode: "k1=v1&k2=v2")
//   - spaceDelimited: "k1 v1 k2 v2"        (non-exploded only)
//   - pipeDelimited:  "k1|v1|k2|v2"        (non-exploded only)
//   - deepObject:     "name[k1]=v1&name[k2]=v2" (exploded only)
//
// Values are stringified with MustString; nothing is percent-escaped here.
// Properties are emitted in Go map iteration order, which is unspecified.
// Unsupported style/explode combinations return an error.
func encodeObjectParam(sm *openapi3.SerializationMethod, paramName string, mapVal map[string]any) (string, error) {
	var prefix, keyPrefix, propsDelim, valueDelim string
	switch {
	case sm.Style == openapi3.SerializationSimple && !sm.Explode:
		propsDelim = ","
		valueDelim = ","
	case sm.Style == openapi3.SerializationSimple && sm.Explode:
		propsDelim = ","
		valueDelim = "="
	case sm.Style == openapi3.SerializationLabel && !sm.Explode:
		prefix = "."
		propsDelim = "."
		valueDelim = "."
	case sm.Style == openapi3.SerializationLabel && sm.Explode:
		prefix = "."
		propsDelim = "."
		valueDelim = "="
	case sm.Style == openapi3.SerializationMatrix && !sm.Explode:
		prefix = ";" + paramName + "="
		propsDelim = ","
		valueDelim = ","
	case sm.Style == openapi3.SerializationMatrix && sm.Explode:
		prefix = ";"
		propsDelim = ";"
		valueDelim = "="
	case sm.Style == openapi3.SerializationForm && !sm.Explode:
		prefix = paramName + "="
		propsDelim = ","
		valueDelim = ","
	case sm.Style == openapi3.SerializationForm && sm.Explode:
		propsDelim = "&"
		valueDelim = "="
	case sm.Style == openapi3.SerializationSpaceDelimited && !sm.Explode:
		propsDelim = " "
		valueDelim = " "
	case sm.Style == openapi3.SerializationPipeDelimited && !sm.Explode:
		propsDelim = "|"
		valueDelim = "|"
	case sm.Style == openapi3.SerializationDeepObject && sm.Explode:
		// Every property carries the real parameter name ("name[key]=value").
		// Using a per-key prefix also means an empty object encodes to "".
		keyPrefix = paramName + "["
		propsDelim = "&"
		valueDelim = "]="
	default:
		return "", fmt.Errorf("invalid serialization method: style=%s, explode=%t", sm.Style, sm.Explode)
	}

	res := prefix
	first := true
	for k, val := range mapVal {
		// Write the separator before every property except the first, so no
		// trailing delimiter has to be trimmed afterwards.
		if !first {
			res += propsDelim
		}
		first = false
		res += keyPrefix + k + valueDelim + MustString(val)
	}

	return res, nil
}
// MustString converts value to its string form.
//
// A nil interface yields "", a string is returned unchanged, and any other
// value is rendered as its encoding/json representation. The marshal error is
// discarded, so a value that cannot be JSON-encoded also yields "".
func MustString(value any) string {
	if value == nil {
		return ""
	}
	if text, isString := value.(string); isString {
		return text
	}

	encoded, _ := json.Marshal(value)
	return string(encoded)
}
// TryCorrectValueType coerces value to the Go type implied by schemaRef.
//
// Scalars are converted with tryCorrectString, tryCorrectFloat64,
// tryCorrectInt64 or tryCorrectBool depending on the schema type. Arrays must
// be []any and objects must be map[string]any; their elements/properties are
// converted recursively and written back into the same slice/map, which is
// then returned. Object properties that have no entry in the schema's
// Properties are left untouched.
//
// An error is returned when value is nil, when the schema (or a nested
// item/property schema that is actually needed) is missing, when the value
// does not have the container type the schema demands, when a scalar
// conversion fails, or when the schema type is unsupported.
func TryCorrectValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
	if value == nil {
		return "", fmt.Errorf("value of '%s' is nil", paramName)
	}
	// Guard against a missing schema (for example an array schema without
	// Items) instead of dereferencing a nil pointer below.
	if schemaRef == nil || schemaRef.Value == nil {
		return nil, fmt.Errorf("[TryCorrectValueType] schema of '%s' is nil", paramName)
	}

	switch schemaRef.Value.Type {
	case openapi3.TypeString:
		return tryCorrectString(value)
	case openapi3.TypeNumber:
		return tryCorrectFloat64(value)
	case openapi3.TypeInteger:
		return tryCorrectInt64(value)
	case openapi3.TypeBoolean:
		return tryCorrectBool(value)
	case openapi3.TypeArray:
		arrVal, ok := value.([]any)
		if !ok {
			return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not array", paramName)
		}

		for i, v := range arrVal {
			_v, err := TryCorrectValueType(paramName, schemaRef.Value.Items, v)
			if err != nil {
				return nil, err
			}
			// Elements are corrected in place.
			arrVal[i] = _v
		}

		return arrVal, nil
	case openapi3.TypeObject:
		mapVal, ok := value.(map[string]any)
		if !ok {
			return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not object", paramName)
		}

		for k, v := range mapVal {
			p, ok := schemaRef.Value.Properties[k]
			if !ok {
				// No schema for this property: keep the value as provided.
				continue
			}

			_v, err := TryCorrectValueType(k, p, v)
			if err != nil {
				return nil, err
			}
			// Properties are corrected in place.
			mapVal[k] = _v
		}

		return mapVal, nil
	default:
		return nil, fmt.Errorf("[TryCorrectValueType] unsupported schema type '%s'", schemaRef.Value.Type)
	}
}
// tryCorrectString converts value to a string.
//
// Strings pass through unchanged, json.Number uses its String form, int64 is
// formatted in base 10, and float64 is formatted through decimal.NewFromFloat.
// Every other type is serialized with sonic.MarshalString; a failure there is
// returned wrapped.
func tryCorrectString(value any) (string, error) {
	switch typed := value.(type) {
	case string:
		return typed, nil
	case json.Number:
		return typed.String(), nil
	case int64:
		return strconv.FormatInt(typed, 10), nil
	case float64:
		return decimal.NewFromFloat(typed).String(), nil
	}

	text, err := sonic.MarshalString(value)
	if err != nil {
		return "", fmt.Errorf("tryCorrectString failed, err=%w", err)
	}
	return text, nil
}
// tryCorrectInt64 converts value to an int64.
//
// Strings and json.Number values are parsed as base-10 integers; a parse
// failure is reported as an error rather than silently becoming 0, matching
// how tryCorrectFloat64 and tryCorrectBool treat unparsable input. int64
// passes through unchanged and float64 is truncated toward zero by the Go
// conversion. Any other type yields an error.
func tryCorrectInt64(value any) (int64, error) {
	switch val := value.(type) {
	case string:
		n, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("cannot convert %q to int64: %w", val, err)
		}
		return n, nil
	case int64:
		return val, nil
	case float64:
		return int64(val), nil
	case json.Number:
		n, err := val.Int64()
		if err != nil {
			return 0, fmt.Errorf("cannot convert %q to int64: %w", val.String(), err)
		}
		return n, nil
	default:
		return 0, fmt.Errorf("cannot convert type from '%T' to int64", val)
	}
}
// tryCorrectBool converts value to a bool.
//
// A bool passes through unchanged and a string is parsed with
// strconv.ParseBool (whose error is returned as-is). Any other type yields an
// error naming the offending type.
func tryCorrectBool(value any) (bool, error) {
	if flag, isBool := value.(bool); isBool {
		return flag, nil
	}
	if text, isString := value.(string); isString {
		return strconv.ParseBool(text)
	}
	return false, fmt.Errorf("cannot convert type from '%T' to bool", value)
}
// tryCorrectFloat64 converts value to a float64.
//
// float64 passes through unchanged, int64 is converted directly, and strings
// and json.Number values are parsed as 64-bit floats (parse errors are
// returned as-is). Any other type yields an error naming the offending type.
func tryCorrectFloat64(value any) (float64, error) {
	switch num := value.(type) {
	case float64:
		return num, nil
	case int64:
		return float64(num), nil
	case json.Number:
		return num.Float64()
	case string:
		return strconv.ParseFloat(num, 64)
	}

	return 0, fmt.Errorf("cannot convert type from '%T' to float64", value)
}