188 lines
4.2 KiB
Go
188 lines
4.2 KiB
Go
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package vikingdb
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"sort"
|
|
|
|
"github.com/volcengine/volc-sdk-golang/base"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
)
|
|
|
|
// Config carries the settings used to build a reranker with NewReranker.
type Config struct {
	// AK and SK are the access key ID and secret access key used to sign
	// every outgoing request.
	AK, SK string
	// Domain is the host the rerank endpoint is called on. NewReranker
	// substitutes a built-in default when it is empty.
	Domain string
	// Model is the rerank model name sent with each request. NewReranker
	// substitutes a built-in default when it is empty.
	Model string
	// Region is the signing region.
	Region string // default cn-north-1
}
|
|
|
|
func NewReranker(config *Config) rerank.Reranker {
|
|
if config.Region == "" {
|
|
config.Region = "cn-north-1"
|
|
}
|
|
if config.Domain == "" {
|
|
config.Domain = domain
|
|
}
|
|
if config.Model == "" {
|
|
config.Model = defaultModel
|
|
}
|
|
return &reranker{config: config}
|
|
}
|
|
|
|
// Fallback values applied by NewReranker when the matching Config field is
// left empty.
const (
	// defaultModel is the rerank model name used when Config.Model is empty.
	defaultModel = "base-multilingual-rerank"
	// domain is the API host used when Config.Domain is empty.
	domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
)
|
|
|
|
type reranker struct {
|
|
config *Config
|
|
}
|
|
|
|
type rerankReq struct {
|
|
Datas []rerankData `json:"datas"`
|
|
RerankModel string `json:"rerank_model"`
|
|
}
|
|
|
|
// rerankData is a single entry of rerankReq.Datas: one query paired with one
// document.
type rerankData struct {
	// Query is the text the document is scored against.
	Query string `json:"query"`
	// Content is the document text.
	Content string `json:"content"`
	// Title is optional and left out of the JSON when nil.
	Title *string `json:"title,omitempty"`
}
|
|
|
|
// rerankResp is the JSON response body decoded from the rerank endpoint.
type rerankResp struct {
	// Code is the service result code; Rerank treats any non-zero value as a
	// failure.
	Code int64 `json:"code"`
	// Message is the text accompanying Code.
	Message string `json:"message"`
	// Data is the result payload.
	Data struct {
		// Scores holds one score per submitted entry; Rerank matches them to
		// the submitted documents by index.
		Scores []float64 `json:"scores"`
		// TokenUsage is the token count reported by the service.
		TokenUsage int64 `json:"token_usage"`
	} `json:"data"`
}
|
|
|
|
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
|
rReq := &rerankReq{
|
|
Datas: make([]rerankData, 0, len(req.Data)),
|
|
RerankModel: r.config.Model,
|
|
}
|
|
sorted := make([]*rerank.Data, 0)
|
|
var flat []*rerank.Data
|
|
visited := map[string]bool{}
|
|
for _, channel := range req.Data {
|
|
if len(channel) == 0 {
|
|
continue
|
|
}
|
|
for _, item := range channel {
|
|
if item == nil || item.Document == nil {
|
|
continue
|
|
}
|
|
if item.Document.ID == "" {
|
|
sorted = append(sorted, &rerank.Data{
|
|
Document: item.Document,
|
|
Score: 1,
|
|
})
|
|
continue
|
|
}
|
|
if visited[item.Document.ID] {
|
|
continue
|
|
}
|
|
visited[item.Document.ID] = true
|
|
flat = append(flat, item)
|
|
}
|
|
}
|
|
|
|
for _, item := range flat {
|
|
rReq.Datas = append(rReq.Datas, rerankData{
|
|
Query: req.Query,
|
|
Content: item.Document.Content,
|
|
})
|
|
}
|
|
|
|
body, err := json.Marshal(rReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(r.prepareRequest(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rResp := rerankResp{}
|
|
if err = json.Unmarshal(respBody, &rResp); err != nil {
|
|
return nil, err
|
|
}
|
|
if rResp.Code != 0 {
|
|
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
|
|
}
|
|
|
|
for i, score := range rResp.Data.Scores {
|
|
sorted = append(sorted, &rerank.Data{
|
|
Document: flat[i].Document,
|
|
Score: score,
|
|
})
|
|
}
|
|
|
|
sort.Slice(sorted, func(i, j int) bool {
|
|
return sorted[i].Score > sorted[j].Score
|
|
})
|
|
|
|
right := len(sorted)
|
|
if req.TopN != nil {
|
|
right = min(right, int(*req.TopN))
|
|
}
|
|
|
|
return &rerank.Response{
|
|
SortedData: sorted[:right],
|
|
TokenUsage: ptr.Of(rResp.Data.TokenUsage),
|
|
}, nil
|
|
}
|
|
|
|
func (r *reranker) prepareRequest(body []byte) *http.Request {
|
|
u := url.URL{
|
|
Scheme: "https",
|
|
Host: r.config.Domain,
|
|
Path: "/api/knowledge/service/rerank",
|
|
}
|
|
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
|
|
req.Header.Add("Accept", "application/json")
|
|
req.Header.Add("Content-Type", "application/json")
|
|
req.Header.Add("Host", domain)
|
|
credential := base.Credentials{
|
|
AccessKeyID: r.config.AK,
|
|
SecretAccessKey: r.config.SK,
|
|
Service: "air",
|
|
Region: r.config.Region,
|
|
}
|
|
req = credential.Sign(req)
|
|
return req
|
|
}
|