/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package encoder
 | 
						|
 | 
						|
import (
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"net/url"
 | 
						|
	"strconv"
 | 
						|
 | 
						|
	"github.com/bytedance/sonic"
 | 
						|
	"github.com/getkin/kin-openapi/openapi3"
 | 
						|
	"github.com/shopspring/decimal"
 | 
						|
	"gopkg.in/yaml.v3"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
 | 
						|
)
 | 
						|
 | 
						|
func EncodeBodyWithContentType(contentType string, body map[string]any) ([]byte, error) {
 | 
						|
	switch contentType {
 | 
						|
	case plugin.MediaTypeJson, plugin.MediaTypeProblemJson:
 | 
						|
		return jsonBodyEncoder(body)
 | 
						|
	case plugin.MediaTypeFormURLEncoded:
 | 
						|
		return urlencodedBodyEncoder(body)
 | 
						|
	case plugin.MediaTypeYaml, plugin.MediaTypeXYaml:
 | 
						|
		return yamlBodyEncoder(body)
 | 
						|
	default:
 | 
						|
		return nil, fmt.Errorf("[EncodeBodyWithContentType] unsupported contentType=%s", contentType)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func jsonBodyEncoder(body map[string]any) ([]byte, error) {
 | 
						|
	b, err := sonic.Marshal(body)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("[jsonBodyEncoder] failed to marshal body, err=%v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return b, nil
 | 
						|
}
 | 
						|
 | 
						|
func yamlBodyEncoder(body map[string]any) ([]byte, error) {
 | 
						|
	b, err := yaml.Marshal(body)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("[yamlBodyEncoder] failed to marshal body, err=%v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return b, nil
 | 
						|
}
 | 
						|
 | 
						|
func urlencodedBodyEncoder(body map[string]any) ([]byte, error) {
 | 
						|
	objectStr := ""
 | 
						|
	res := url.Values{}
 | 
						|
	sm := &openapi3.SerializationMethod{
 | 
						|
		Style:   openapi3.SerializationForm,
 | 
						|
		Explode: true,
 | 
						|
	}
 | 
						|
 | 
						|
	for k, value := range body {
 | 
						|
		switch val := value.(type) {
 | 
						|
		case map[string]any:
 | 
						|
			vStr, err := encodeObjectParam(sm, k, val)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			if len(objectStr) > 0 {
 | 
						|
				vStr = "&" + vStr
 | 
						|
			}
 | 
						|
 | 
						|
			objectStr += vStr
 | 
						|
		case []any:
 | 
						|
			vStr, err := encodeArrayParam(sm, k, val)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			if len(objectStr) > 0 {
 | 
						|
				vStr = "&" + vStr
 | 
						|
			}
 | 
						|
 | 
						|
			objectStr += vStr
 | 
						|
		case string:
 | 
						|
			res.Add(k, val)
 | 
						|
		default:
 | 
						|
			res.Add(k, MustString(val))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if len(objectStr) > 0 {
 | 
						|
		return []byte(res.Encode() + "&" + url.QueryEscape(objectStr)), nil
 | 
						|
	}
 | 
						|
 | 
						|
	return []byte(res.Encode()), nil
 | 
						|
}
 | 
						|
 | 
						|
func EncodeParameter(param *openapi3.Parameter, value any) (string, error) {
 | 
						|
	sm, err := param.SerializationMethod()
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	switch v := value.(type) {
 | 
						|
	case map[string]any:
 | 
						|
		return encodeObjectParam(sm, param.Name, v)
 | 
						|
	case []any:
 | 
						|
		return encodeArrayParam(sm, param.Name, v)
 | 
						|
	default:
 | 
						|
		return encodePrimitiveParam(sm, param.Name, v)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func encodePrimitiveParam(sm *openapi3.SerializationMethod, paramName string, val any) (string, error) {
 | 
						|
	var prefix string
 | 
						|
	switch sm.Style {
 | 
						|
	case openapi3.SerializationSimple:
 | 
						|
		// A prefix is empty for style "simple".
 | 
						|
	case openapi3.SerializationLabel:
 | 
						|
		prefix = "."
 | 
						|
	case openapi3.SerializationMatrix:
 | 
						|
		prefix = ";" + url.QueryEscape(paramName) + "="
 | 
						|
	case openapi3.SerializationForm:
 | 
						|
		result := url.QueryEscape(paramName) + "=" + url.QueryEscape(MustString(val))
 | 
						|
		return result, nil
 | 
						|
	default:
 | 
						|
		return "", fmt.Errorf("invalid serialization method: style=%q, explode=%v", sm.Style, sm.Explode)
 | 
						|
	}
 | 
						|
 | 
						|
	raw := MustString(val)
 | 
						|
 | 
						|
	return prefix + raw, nil
 | 
						|
}
 | 
						|
 | 
						|
func encodeArrayParam(sm *openapi3.SerializationMethod, paramName string, arrVal []any) (string, error) {
 | 
						|
	var prefix, delim string
 | 
						|
	switch {
 | 
						|
	case sm.Style == openapi3.SerializationMatrix && !sm.Explode:
 | 
						|
		prefix = ";" + paramName + "="
 | 
						|
		delim = ","
 | 
						|
	case sm.Style == openapi3.SerializationMatrix && sm.Explode:
 | 
						|
		prefix = ";" + paramName + "="
 | 
						|
		delim = ";" + paramName + "="
 | 
						|
	case sm.Style == openapi3.SerializationLabel && !sm.Explode:
 | 
						|
		prefix = "."
 | 
						|
		delim = ","
 | 
						|
	case sm.Style == openapi3.SerializationLabel && sm.Explode:
 | 
						|
		prefix = "."
 | 
						|
		delim = "."
 | 
						|
	case sm.Style == openapi3.SerializationForm && sm.Explode:
 | 
						|
		prefix = paramName + "="
 | 
						|
		delim = "&" + paramName + "="
 | 
						|
	case sm.Style == openapi3.SerializationForm && !sm.Explode:
 | 
						|
		prefix = paramName + "="
 | 
						|
		delim = ","
 | 
						|
	case sm.Style == openapi3.SerializationSimple:
 | 
						|
		delim = ","
 | 
						|
	case sm.Style == openapi3.SerializationSpaceDelimited && !sm.Explode:
 | 
						|
		delim = ","
 | 
						|
	case sm.Style == openapi3.SerializationPipeDelimited && !sm.Explode:
 | 
						|
		delim = "|"
 | 
						|
	default:
 | 
						|
		return "", fmt.Errorf("invalid serialization method: style=%q, explode=%v", sm.Style, sm.Explode)
 | 
						|
	}
 | 
						|
 | 
						|
	res := prefix
 | 
						|
 | 
						|
	for i, val := range arrVal {
 | 
						|
		vStr := MustString(val)
 | 
						|
		res += vStr
 | 
						|
 | 
						|
		if i != len(arrVal)-1 {
 | 
						|
			res += delim
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return res, nil
 | 
						|
}
 | 
						|
 | 
						|
func encodeObjectParam(sm *openapi3.SerializationMethod, paramName string, mapVal map[string]any) (string, error) {
 | 
						|
	var prefix, propsDelim, valueDelim string
 | 
						|
 | 
						|
	switch {
 | 
						|
	case sm.Style == openapi3.SerializationSimple && !sm.Explode:
 | 
						|
		propsDelim = ","
 | 
						|
		valueDelim = ","
 | 
						|
	case sm.Style == openapi3.SerializationSimple && sm.Explode:
 | 
						|
		propsDelim = ","
 | 
						|
		valueDelim = "="
 | 
						|
	case sm.Style == openapi3.SerializationLabel && !sm.Explode:
 | 
						|
		prefix = "."
 | 
						|
		propsDelim = "."
 | 
						|
		valueDelim = "."
 | 
						|
	case sm.Style == openapi3.SerializationLabel && sm.Explode:
 | 
						|
		prefix = "."
 | 
						|
		propsDelim = "."
 | 
						|
		valueDelim = "="
 | 
						|
	case sm.Style == openapi3.SerializationMatrix && !sm.Explode:
 | 
						|
		prefix = ";" + paramName + "="
 | 
						|
		propsDelim = ","
 | 
						|
		valueDelim = ","
 | 
						|
	case sm.Style == openapi3.SerializationMatrix && sm.Explode:
 | 
						|
		prefix = ";"
 | 
						|
		propsDelim = ";"
 | 
						|
		valueDelim = "="
 | 
						|
	case sm.Style == openapi3.SerializationForm && !sm.Explode:
 | 
						|
		prefix = paramName + "="
 | 
						|
		propsDelim = ","
 | 
						|
		valueDelim = ","
 | 
						|
	case sm.Style == openapi3.SerializationForm && sm.Explode:
 | 
						|
		propsDelim = "&"
 | 
						|
		valueDelim = "="
 | 
						|
	case sm.Style == openapi3.SerializationSpaceDelimited && !sm.Explode:
 | 
						|
		propsDelim = " "
 | 
						|
		valueDelim = " "
 | 
						|
	case sm.Style == openapi3.SerializationPipeDelimited && !sm.Explode:
 | 
						|
		propsDelim = "|"
 | 
						|
		valueDelim = "|"
 | 
						|
	case sm.Style == openapi3.SerializationDeepObject && sm.Explode:
 | 
						|
		prefix = paramName + "["
 | 
						|
		propsDelim = "&color["
 | 
						|
		valueDelim = "]="
 | 
						|
	default:
 | 
						|
		return "", fmt.Errorf("invalid serialization method: style=%s, explode=%t", sm.Style, sm.Explode)
 | 
						|
	}
 | 
						|
 | 
						|
	res := prefix
 | 
						|
	for k, val := range mapVal {
 | 
						|
		vStr := MustString(val)
 | 
						|
		res += k + valueDelim + vStr + propsDelim
 | 
						|
	}
 | 
						|
 | 
						|
	if len(mapVal) > 0 && len(res) > 0 {
 | 
						|
		res = res[:len(res)-1]
 | 
						|
	}
 | 
						|
 | 
						|
	return res, nil
 | 
						|
}
 | 
						|
 | 
						|
// MustString converts value to its string form without ever failing.
//
//   - nil yields "";
//   - a string is returned unchanged;
//   - any other value is rendered as JSON via encoding/json;
//   - if JSON marshaling fails, the fmt.Sprint rendering is returned instead
//     of silently producing an empty string.
func MustString(value any) string {
	if value == nil {
		return ""
	}

	if s, ok := value.(string); ok {
		return s
	}

	b, err := json.Marshal(value)
	if err != nil {
		// The signature has no error result, so degrade to a best-effort
		// textual form rather than dropping the value.
		return fmt.Sprint(value)
	}

	return string(b)
}
 | 
						|
 | 
						|
func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
 | 
						|
	if value == nil {
 | 
						|
		return "", fmt.Errorf("value of '%s' is nil", paramName)
 | 
						|
	}
 | 
						|
 | 
						|
	switch schemaRef.Value.Type {
 | 
						|
	case openapi3.TypeString:
 | 
						|
		return tryString(value)
 | 
						|
	case openapi3.TypeNumber:
 | 
						|
		return tryFloat64(value)
 | 
						|
	case openapi3.TypeInteger:
 | 
						|
		return tryInt64(value)
 | 
						|
	case openapi3.TypeBoolean:
 | 
						|
		return tryBool(value)
 | 
						|
	case openapi3.TypeArray:
 | 
						|
		arrVal, ok := value.([]any)
 | 
						|
		if !ok {
 | 
						|
			return nil, fmt.Errorf("[TryFixValueType] value '%s' is not array", paramName)
 | 
						|
		}
 | 
						|
 | 
						|
		for i, v := range arrVal {
 | 
						|
			_v, err := TryFixValueType(paramName, schemaRef.Value.Items, v)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			arrVal[i] = _v
 | 
						|
		}
 | 
						|
 | 
						|
		return arrVal, nil
 | 
						|
	case openapi3.TypeObject:
 | 
						|
		mapVal, ok := value.(map[string]any)
 | 
						|
		if !ok {
 | 
						|
			return nil, fmt.Errorf("[TryFixValueType] value '%s' is not object", paramName)
 | 
						|
		}
 | 
						|
 | 
						|
		for k, v := range mapVal {
 | 
						|
			p, ok := schemaRef.Value.Properties[k]
 | 
						|
			if !ok {
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			_v, err := TryFixValueType(k, p, v)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			mapVal[k] = _v
 | 
						|
		}
 | 
						|
 | 
						|
		return mapVal, nil
 | 
						|
	default:
 | 
						|
		return nil, fmt.Errorf("[TryFixValueType] unsupported schema type '%s'", schemaRef.Value.Type)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func tryString(value any) (string, error) {
 | 
						|
	switch val := value.(type) {
 | 
						|
	case string:
 | 
						|
		return val, nil
 | 
						|
	case int64:
 | 
						|
		return strconv.FormatInt(val, 10), nil
 | 
						|
	case float64:
 | 
						|
		d := decimal.NewFromFloat(val)
 | 
						|
		return d.String(), nil
 | 
						|
	case json.Number:
 | 
						|
		return val.String(), nil
 | 
						|
	default:
 | 
						|
		return "", fmt.Errorf("cannot convert type from '%T' to string", val)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// tryInt64 converts value to an int64.
//
// Accepted inputs are string and json.Number (parsed as base-10 integers),
// int64 (returned as-is) and float64 (converted with int64(), which drops the
// fractional part). A string or json.Number that is not a valid base-10
// int64 yields the parse error instead of a silent zero, in line with
// tryFloat64 and tryBool. Every other dynamic type produces an error naming
// that type.
func tryInt64(value any) (int64, error) {
	switch val := value.(type) {
	case string:
		return strconv.ParseInt(val, 10, 64)
	case int64:
		return val, nil
	case float64:
		return int64(val), nil
	case json.Number:
		return val.Int64()
	default:
		return 0, fmt.Errorf("cannot convert type from '%T' to int64", val)
	}
}
 | 
						|
 | 
						|
// tryBool converts value to a bool.
//
// A string is parsed with strconv.ParseBool (whose error is returned as-is),
// a bool is returned unchanged, and every other dynamic type produces an
// error naming that type.
func tryBool(value any) (bool, error) {
	if s, isString := value.(string); isString {
		return strconv.ParseBool(s)
	}

	if b, isBool := value.(bool); isBool {
		return b, nil
	}

	return false, fmt.Errorf("cannot convert type from '%T' to bool", value)
}
 | 
						|
 | 
						|
// tryFloat64 converts value to a float64.
//
// A float64 is returned unchanged, an int64 is converted directly, and a
// string or json.Number is parsed as a 64-bit float (the parse error is
// returned as-is). Every other dynamic type produces an error naming that
// type.
func tryFloat64(value any) (float64, error) {
	switch v := value.(type) {
	case float64:
		return v, nil
	case int64:
		return float64(v), nil
	case string:
		return strconv.ParseFloat(v, 64)
	case json.Number:
		return v.Float64()
	}

	return 0, fmt.Errorf("cannot convert type from '%T' to float64", value)
}
 |