188 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			188 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package vikingdb
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 	"net/url"
 | |
| 	"sort"
 | |
| 
 | |
| 	"github.com/volcengine/volc-sdk-golang/base"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | |
| )
 | |
| 
 | |
| type Config struct {
 | |
| 	AK     string
 | |
| 	SK     string
 | |
| 	Domain string
 | |
| 	Model  string
 | |
| 	Region string // default cn-north-1
 | |
| }
 | |
| 
 | |
| func NewReranker(config *Config) rerank.Reranker {
 | |
| 	if config.Region == "" {
 | |
| 		config.Region = "cn-north-1"
 | |
| 	}
 | |
| 	if config.Domain == "" {
 | |
| 		config.Domain = domain
 | |
| 	}
 | |
| 	if config.Model == "" {
 | |
| 		config.Model = defaultModel
 | |
| 	}
 | |
| 	return &reranker{config: config}
 | |
| }
 | |
| 
 | |
| const (
 | |
| 	domain       = "api-knowledgebase.mlp.cn-beijing.volces.com"
 | |
| 	defaultModel = "base-multilingual-rerank"
 | |
| )
 | |
| 
 | |
| type reranker struct {
 | |
| 	config *Config
 | |
| }
 | |
| 
 | |
| type rerankReq struct {
 | |
| 	Datas       []rerankData `json:"datas"`
 | |
| 	RerankModel string       `json:"rerank_model"`
 | |
| }
 | |
| 
 | |
| type rerankData struct {
 | |
| 	Query   string  `json:"query"`
 | |
| 	Content string  `json:"content"`
 | |
| 	Title   *string `json:"title,omitempty"`
 | |
| }
 | |
| 
 | |
| type rerankResp struct {
 | |
| 	Code    int64  `json:"code"`
 | |
| 	Message string `json:"message"`
 | |
| 	Data    struct {
 | |
| 		Scores     []float64 `json:"scores"`
 | |
| 		TokenUsage int64     `json:"token_usage"`
 | |
| 	} `json:"data"`
 | |
| }
 | |
| 
 | |
| func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
 | |
| 	rReq := &rerankReq{
 | |
| 		Datas:       make([]rerankData, 0, len(req.Data)),
 | |
| 		RerankModel: r.config.Model,
 | |
| 	}
 | |
| 	sorted := make([]*rerank.Data, 0)
 | |
| 	var flat []*rerank.Data
 | |
| 	visited := map[string]bool{}
 | |
| 	for _, channel := range req.Data {
 | |
| 		if len(channel) == 0 {
 | |
| 			continue
 | |
| 		}
 | |
| 		for _, item := range channel {
 | |
| 			if item == nil || item.Document == nil {
 | |
| 				continue
 | |
| 			}
 | |
| 			if item.Document.ID == "" {
 | |
| 				sorted = append(sorted, &rerank.Data{
 | |
| 					Document: item.Document,
 | |
| 					Score:    1,
 | |
| 				})
 | |
| 				continue
 | |
| 			}
 | |
| 			if visited[item.Document.ID] {
 | |
| 				continue
 | |
| 			}
 | |
| 			visited[item.Document.ID] = true
 | |
| 			flat = append(flat, item)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for _, item := range flat {
 | |
| 		rReq.Datas = append(rReq.Datas, rerankData{
 | |
| 			Query:   req.Query,
 | |
| 			Content: item.Document.Content,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	body, err := json.Marshal(rReq)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	resp, err := http.DefaultClient.Do(r.prepareRequest(body))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	defer resp.Body.Close()
 | |
| 	respBody, err := io.ReadAll(resp.Body)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	rResp := rerankResp{}
 | |
| 	if err = json.Unmarshal(respBody, &rResp); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if rResp.Code != 0 {
 | |
| 		return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
 | |
| 	}
 | |
| 
 | |
| 	for i, score := range rResp.Data.Scores {
 | |
| 		sorted = append(sorted, &rerank.Data{
 | |
| 			Document: flat[i].Document,
 | |
| 			Score:    score,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(sorted, func(i, j int) bool {
 | |
| 		return sorted[i].Score > sorted[j].Score
 | |
| 	})
 | |
| 
 | |
| 	right := len(sorted)
 | |
| 	if req.TopN != nil {
 | |
| 		right = min(right, int(*req.TopN))
 | |
| 	}
 | |
| 
 | |
| 	return &rerank.Response{
 | |
| 		SortedData: sorted[:right],
 | |
| 		TokenUsage: ptr.Of(rResp.Data.TokenUsage),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (r *reranker) prepareRequest(body []byte) *http.Request {
 | |
| 	u := url.URL{
 | |
| 		Scheme: "https",
 | |
| 		Host:   r.config.Domain,
 | |
| 		Path:   "/api/knowledge/service/rerank",
 | |
| 	}
 | |
| 	req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
 | |
| 	req.Header.Add("Accept", "application/json")
 | |
| 	req.Header.Add("Content-Type", "application/json")
 | |
| 	req.Header.Add("Host", domain)
 | |
| 	credential := base.Credentials{
 | |
| 		AccessKeyID:     r.config.AK,
 | |
| 		SecretAccessKey: r.config.SK,
 | |
| 		Service:         "air",
 | |
| 		Region:          r.config.Region,
 | |
| 	}
 | |
| 	req = credential.Sign(req)
 | |
| 	return req
 | |
| }
 |