162 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package vikingdb
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"sort"
 | 
						|
 | 
						|
	"github.com/volcengine/volc-sdk-golang/base"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
)
 | 
						|
 | 
						|
type Config struct {
 | 
						|
	AK string
 | 
						|
	SK string
 | 
						|
 | 
						|
	Region string // default cn-north-1
 | 
						|
}
 | 
						|
 | 
						|
func NewReranker(config *Config) rerank.Reranker {
 | 
						|
	if config.Region == "" {
 | 
						|
		config.Region = "cn-north-1"
 | 
						|
	}
 | 
						|
	return &reranker{config: config}
 | 
						|
}
 | 
						|
 | 
						|
const (
 | 
						|
	domain       = "api-knowledgebase.mlp.cn-beijing.volces.com"
 | 
						|
	defaultModel = "base-multilingual-rerank"
 | 
						|
)
 | 
						|
 | 
						|
type reranker struct {
 | 
						|
	config *Config
 | 
						|
}
 | 
						|
 | 
						|
type rerankReq struct {
 | 
						|
	Datas       []rerankData `json:"datas"`
 | 
						|
	RerankModel string       `json:"rerank_model"`
 | 
						|
}
 | 
						|
 | 
						|
type rerankData struct {
 | 
						|
	Query   string  `json:"query"`
 | 
						|
	Content string  `json:"content"`
 | 
						|
	Title   *string `json:"title,omitempty"`
 | 
						|
}
 | 
						|
 | 
						|
type rerankResp struct {
 | 
						|
	Code    int64  `json:"code"`
 | 
						|
	Message string `json:"message"`
 | 
						|
	Data    struct {
 | 
						|
		Scores     []float64 `json:"scores"`
 | 
						|
		TokenUsage int64     `json:"token_usage"`
 | 
						|
	} `json:"data"`
 | 
						|
}
 | 
						|
 | 
						|
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
 | 
						|
	rReq := &rerankReq{
 | 
						|
		Datas:       make([]rerankData, 0, len(req.Data)),
 | 
						|
		RerankModel: defaultModel,
 | 
						|
	}
 | 
						|
 | 
						|
	var flat []*rerank.Data
 | 
						|
	for _, channel := range req.Data {
 | 
						|
		flat = append(flat, channel...)
 | 
						|
	}
 | 
						|
 | 
						|
	for _, item := range flat {
 | 
						|
		rReq.Datas = append(rReq.Datas, rerankData{
 | 
						|
			Query:   req.Query,
 | 
						|
			Content: item.Document.Content,
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	body, err := json.Marshal(rReq)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	resp, err := http.DefaultClient.Do(r.prepareRequest(body))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	defer resp.Body.Close()
 | 
						|
	respBody, err := io.ReadAll(resp.Body)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	rResp := rerankResp{}
 | 
						|
	if err = json.Unmarshal(respBody, &rResp); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	if rResp.Code != 0 {
 | 
						|
		return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
 | 
						|
	}
 | 
						|
 | 
						|
	sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
 | 
						|
	for i, score := range rResp.Data.Scores {
 | 
						|
		sorted = append(sorted, &rerank.Data{
 | 
						|
			Document: flat[i].Document,
 | 
						|
			Score:    score,
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Slice(sorted, func(i, j int) bool {
 | 
						|
		return sorted[i].Score > sorted[j].Score
 | 
						|
	})
 | 
						|
 | 
						|
	right := len(sorted)
 | 
						|
	if req.TopN != nil {
 | 
						|
		right = min(right, int(*req.TopN))
 | 
						|
	}
 | 
						|
 | 
						|
	return &rerank.Response{
 | 
						|
		SortedData: sorted[:right],
 | 
						|
		TokenUsage: ptr.Of(rResp.Data.TokenUsage),
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (r *reranker) prepareRequest(body []byte) *http.Request {
 | 
						|
	u := url.URL{
 | 
						|
		Scheme: "https",
 | 
						|
		Host:   domain,
 | 
						|
		Path:   "/api/knowledge/service/rerank",
 | 
						|
	}
 | 
						|
	req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
 | 
						|
	req.Header.Add("Accept", "application/json")
 | 
						|
	req.Header.Add("Content-Type", "application/json")
 | 
						|
	req.Header.Add("Host", domain)
 | 
						|
	credential := base.Credentials{
 | 
						|
		AccessKeyID:     r.config.AK,
 | 
						|
		SecretAccessKey: r.config.SK,
 | 
						|
		Service:         "air",
 | 
						|
		Region:          r.config.Region,
 | 
						|
	}
 | 
						|
	req = credential.Sign(req)
 | 
						|
	return req
 | 
						|
}
 |