384 lines
9.4 KiB
Go
384 lines
9.4 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package encoder
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"strconv"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/getkin/kin-openapi/openapi3"
|
|
"github.com/shopspring/decimal"
|
|
"gopkg.in/yaml.v3"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
|
)
|
|
|
|
func EncodeBodyWithContentType(contentType string, body map[string]any) ([]byte, error) {
|
|
switch contentType {
|
|
case plugin.MediaTypeJson, plugin.MediaTypeProblemJson:
|
|
return jsonBodyEncoder(body)
|
|
case plugin.MediaTypeFormURLEncoded:
|
|
return urlencodedBodyEncoder(body)
|
|
case plugin.MediaTypeYaml, plugin.MediaTypeXYaml:
|
|
return yamlBodyEncoder(body)
|
|
default:
|
|
return nil, fmt.Errorf("[EncodeBodyWithContentType] unsupported contentType=%s", contentType)
|
|
}
|
|
}
|
|
|
|
func jsonBodyEncoder(body map[string]any) ([]byte, error) {
|
|
b, err := sonic.Marshal(body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("[jsonBodyEncoder] failed to marshal body, err=%v", err)
|
|
}
|
|
|
|
return b, nil
|
|
}
|
|
|
|
func yamlBodyEncoder(body map[string]any) ([]byte, error) {
|
|
b, err := yaml.Marshal(body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("[yamlBodyEncoder] failed to marshal body, err=%v", err)
|
|
}
|
|
|
|
return b, nil
|
|
}
|
|
|
|
func urlencodedBodyEncoder(body map[string]any) ([]byte, error) {
|
|
objectStr := ""
|
|
res := url.Values{}
|
|
sm := &openapi3.SerializationMethod{
|
|
Style: openapi3.SerializationForm,
|
|
Explode: true,
|
|
}
|
|
|
|
for k, value := range body {
|
|
switch val := value.(type) {
|
|
case map[string]any:
|
|
vStr, err := encodeObjectParam(sm, k, val)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(objectStr) > 0 {
|
|
vStr = "&" + vStr
|
|
}
|
|
|
|
objectStr += vStr
|
|
case []any:
|
|
vStr, err := encodeArrayParam(sm, k, val)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(objectStr) > 0 {
|
|
vStr = "&" + vStr
|
|
}
|
|
|
|
objectStr += vStr
|
|
case string:
|
|
res.Add(k, val)
|
|
default:
|
|
res.Add(k, MustString(val))
|
|
}
|
|
}
|
|
|
|
if len(objectStr) > 0 {
|
|
return []byte(res.Encode() + "&" + url.QueryEscape(objectStr)), nil
|
|
}
|
|
|
|
return []byte(res.Encode()), nil
|
|
}
|
|
|
|
func EncodeParameter(param *openapi3.Parameter, value any) (string, error) {
|
|
sm, err := param.SerializationMethod()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch v := value.(type) {
|
|
case map[string]any:
|
|
return encodeObjectParam(sm, param.Name, v)
|
|
case []any:
|
|
return encodeArrayParam(sm, param.Name, v)
|
|
default:
|
|
return encodePrimitiveParam(sm, param.Name, v)
|
|
}
|
|
}
|
|
|
|
func encodePrimitiveParam(sm *openapi3.SerializationMethod, paramName string, val any) (string, error) {
|
|
var prefix string
|
|
switch sm.Style {
|
|
case openapi3.SerializationSimple:
|
|
// A prefix is empty for style "simple".
|
|
case openapi3.SerializationLabel:
|
|
prefix = "."
|
|
case openapi3.SerializationMatrix:
|
|
prefix = ";" + url.QueryEscape(paramName) + "="
|
|
case openapi3.SerializationForm:
|
|
result := url.QueryEscape(paramName) + "=" + url.QueryEscape(MustString(val))
|
|
return result, nil
|
|
default:
|
|
return "", fmt.Errorf("invalid serialization method: style=%q, explode=%v", sm.Style, sm.Explode)
|
|
}
|
|
|
|
raw := MustString(val)
|
|
|
|
return prefix + raw, nil
|
|
}
|
|
|
|
func encodeArrayParam(sm *openapi3.SerializationMethod, paramName string, arrVal []any) (string, error) {
|
|
var prefix, delim string
|
|
switch {
|
|
case sm.Style == openapi3.SerializationMatrix && !sm.Explode:
|
|
prefix = ";" + paramName + "="
|
|
delim = ","
|
|
case sm.Style == openapi3.SerializationMatrix && sm.Explode:
|
|
prefix = ";" + paramName + "="
|
|
delim = ";" + paramName + "="
|
|
case sm.Style == openapi3.SerializationLabel && !sm.Explode:
|
|
prefix = "."
|
|
delim = ","
|
|
case sm.Style == openapi3.SerializationLabel && sm.Explode:
|
|
prefix = "."
|
|
delim = "."
|
|
case sm.Style == openapi3.SerializationForm && sm.Explode:
|
|
prefix = paramName + "="
|
|
delim = "&" + paramName + "="
|
|
case sm.Style == openapi3.SerializationForm && !sm.Explode:
|
|
prefix = paramName + "="
|
|
delim = ","
|
|
case sm.Style == openapi3.SerializationSimple:
|
|
delim = ","
|
|
case sm.Style == openapi3.SerializationSpaceDelimited && !sm.Explode:
|
|
delim = ","
|
|
case sm.Style == openapi3.SerializationPipeDelimited && !sm.Explode:
|
|
delim = "|"
|
|
default:
|
|
return "", fmt.Errorf("invalid serialization method: style=%q, explode=%v", sm.Style, sm.Explode)
|
|
}
|
|
|
|
res := prefix
|
|
|
|
for i, val := range arrVal {
|
|
vStr := MustString(val)
|
|
res += vStr
|
|
|
|
if i != len(arrVal)-1 {
|
|
res += delim
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func encodeObjectParam(sm *openapi3.SerializationMethod, paramName string, mapVal map[string]any) (string, error) {
|
|
var prefix, propsDelim, valueDelim string
|
|
|
|
switch {
|
|
case sm.Style == openapi3.SerializationSimple && !sm.Explode:
|
|
propsDelim = ","
|
|
valueDelim = ","
|
|
case sm.Style == openapi3.SerializationSimple && sm.Explode:
|
|
propsDelim = ","
|
|
valueDelim = "="
|
|
case sm.Style == openapi3.SerializationLabel && !sm.Explode:
|
|
prefix = "."
|
|
propsDelim = "."
|
|
valueDelim = "."
|
|
case sm.Style == openapi3.SerializationLabel && sm.Explode:
|
|
prefix = "."
|
|
propsDelim = "."
|
|
valueDelim = "="
|
|
case sm.Style == openapi3.SerializationMatrix && !sm.Explode:
|
|
prefix = ";" + paramName + "="
|
|
propsDelim = ","
|
|
valueDelim = ","
|
|
case sm.Style == openapi3.SerializationMatrix && sm.Explode:
|
|
prefix = ";"
|
|
propsDelim = ";"
|
|
valueDelim = "="
|
|
case sm.Style == openapi3.SerializationForm && !sm.Explode:
|
|
prefix = paramName + "="
|
|
propsDelim = ","
|
|
valueDelim = ","
|
|
case sm.Style == openapi3.SerializationForm && sm.Explode:
|
|
propsDelim = "&"
|
|
valueDelim = "="
|
|
case sm.Style == openapi3.SerializationSpaceDelimited && !sm.Explode:
|
|
propsDelim = " "
|
|
valueDelim = " "
|
|
case sm.Style == openapi3.SerializationPipeDelimited && !sm.Explode:
|
|
propsDelim = "|"
|
|
valueDelim = "|"
|
|
case sm.Style == openapi3.SerializationDeepObject && sm.Explode:
|
|
prefix = paramName + "["
|
|
propsDelim = "&color["
|
|
valueDelim = "]="
|
|
default:
|
|
return "", fmt.Errorf("invalid serialization method: style=%s, explode=%t", sm.Style, sm.Explode)
|
|
}
|
|
|
|
res := prefix
|
|
for k, val := range mapVal {
|
|
vStr := MustString(val)
|
|
res += k + valueDelim + vStr + propsDelim
|
|
}
|
|
|
|
if len(mapVal) > 0 && len(res) > 0 {
|
|
res = res[:len(res)-1]
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func MustString(value any) string {
|
|
if value == nil {
|
|
return ""
|
|
}
|
|
|
|
switch val := value.(type) {
|
|
case string:
|
|
return val
|
|
default:
|
|
b, _ := json.Marshal(val)
|
|
return string(b)
|
|
}
|
|
}
|
|
|
|
func TryCorrectValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
|
|
if value == nil {
|
|
return "", fmt.Errorf("value of '%s' is nil", paramName)
|
|
}
|
|
|
|
switch schemaRef.Value.Type {
|
|
case openapi3.TypeString:
|
|
return tryCorrectString(value)
|
|
case openapi3.TypeNumber:
|
|
return tryCorrectFloat64(value)
|
|
case openapi3.TypeInteger:
|
|
return tryCorrectInt64(value)
|
|
case openapi3.TypeBoolean:
|
|
return tryCorrectBool(value)
|
|
case openapi3.TypeArray:
|
|
arrVal, ok := value.([]any)
|
|
if !ok {
|
|
return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not array", paramName)
|
|
}
|
|
|
|
for i, v := range arrVal {
|
|
_v, err := TryCorrectValueType(paramName, schemaRef.Value.Items, v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
arrVal[i] = _v
|
|
}
|
|
|
|
return arrVal, nil
|
|
case openapi3.TypeObject:
|
|
mapVal, ok := value.(map[string]any)
|
|
if !ok {
|
|
return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not object", paramName)
|
|
}
|
|
|
|
for k, v := range mapVal {
|
|
p, ok := schemaRef.Value.Properties[k]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
_v, err := TryCorrectValueType(k, p, v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mapVal[k] = _v
|
|
}
|
|
|
|
return mapVal, nil
|
|
default:
|
|
return nil, fmt.Errorf("[TryCorrectValueType] unsupported schema type '%s'", schemaRef.Value.Type)
|
|
}
|
|
}
|
|
|
|
func tryCorrectString(value any) (string, error) {
|
|
switch val := value.(type) {
|
|
case string:
|
|
return val, nil
|
|
case int64:
|
|
return strconv.FormatInt(val, 10), nil
|
|
case float64:
|
|
d := decimal.NewFromFloat(val)
|
|
return d.String(), nil
|
|
case json.Number:
|
|
return val.String(), nil
|
|
default:
|
|
b, err := sonic.MarshalString(value)
|
|
if err != nil {
|
|
return "", fmt.Errorf("tryCorrectString failed, err=%w", err)
|
|
}
|
|
return b, nil
|
|
}
|
|
}
|
|
|
|
func tryCorrectInt64(value any) (int64, error) {
|
|
switch val := value.(type) {
|
|
case string:
|
|
vi64, _ := strconv.ParseInt(val, 10, 64)
|
|
return vi64, nil
|
|
case int64:
|
|
return val, nil
|
|
case float64:
|
|
return int64(val), nil
|
|
case json.Number:
|
|
vi64, _ := strconv.ParseInt(val.String(), 10, 64)
|
|
return vi64, nil
|
|
default:
|
|
return 0, fmt.Errorf("cannot convert type from '%T' to int64", val)
|
|
}
|
|
}
|
|
|
|
func tryCorrectBool(value any) (bool, error) {
|
|
switch val := value.(type) {
|
|
case string:
|
|
return strconv.ParseBool(val)
|
|
case bool:
|
|
return val, nil
|
|
default:
|
|
return false, fmt.Errorf("cannot convert type from '%T' to bool", val)
|
|
}
|
|
}
|
|
|
|
func tryCorrectFloat64(value any) (float64, error) {
|
|
switch val := value.(type) {
|
|
case string:
|
|
return strconv.ParseFloat(val, 64)
|
|
case float64:
|
|
return val, nil
|
|
case int64:
|
|
return float64(val), nil
|
|
case json.Number:
|
|
return strconv.ParseFloat(val.String(), 64)
|
|
default:
|
|
return 0, fmt.Errorf("cannot convert type from '%T' to float64", val)
|
|
}
|
|
}
|