feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
70
backend/infra/impl/document/rerank/rrf/rrf.go
Normal file
70
backend/infra/impl/document/rerank/rrf/rrf.go
Normal file
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func NewRRFReranker(k int64) rerank.Reranker {
|
||||
if k == 0 {
|
||||
k = 60
|
||||
}
|
||||
return &rrfReranker{k}
|
||||
}
|
||||
|
||||
type rrfReranker struct {
|
||||
k int64
|
||||
}
|
||||
|
||||
func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
||||
if req == nil || req.Data == nil || len(req.Data) == 0 {
|
||||
return nil, fmt.Errorf("invalid request: no data provided")
|
||||
}
|
||||
id2Score := make(map[string]float64)
|
||||
id2Data := make(map[string]*rerank.Data)
|
||||
for _, resultList := range req.Data {
|
||||
for rank := range resultList {
|
||||
result := resultList[rank]
|
||||
if result != nil && result.Document != nil {
|
||||
score := 1.0 / (float64(rank) + float64(r.k))
|
||||
if score > id2Score[result.Document.ID] {
|
||||
id2Score[result.Document.ID] = score
|
||||
id2Data[result.Document.ID] = result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var sorted []*rerank.Data
|
||||
for _, data := range id2Data {
|
||||
sorted = append(sorted, data)
|
||||
}
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return id2Score[sorted[i].Document.ID] > id2Score[sorted[j].Document.ID]
|
||||
})
|
||||
topN := int64(len(sorted))
|
||||
if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
|
||||
topN = ptr.From(req.TopN)
|
||||
}
|
||||
|
||||
return &rerank.Response{SortedData: sorted[:topN]}, nil
|
||||
}
|
||||
161
backend/infra/impl/document/rerank/vikingdb/vikingdb.go
Normal file
161
backend/infra/impl/document/rerank/vikingdb/vikingdb.go
Normal file
@@ -0,0 +1,161 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
AK string
|
||||
SK string
|
||||
|
||||
Region string // default cn-north-1
|
||||
}
|
||||
|
||||
func NewReranker(config *Config) rerank.Reranker {
|
||||
if config.Region == "" {
|
||||
config.Region = "cn-north-1"
|
||||
}
|
||||
return &reranker{config: config}
|
||||
}
|
||||
|
||||
const (
|
||||
domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
|
||||
defaultModel = "base-multilingual-rerank"
|
||||
)
|
||||
|
||||
type reranker struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
type rerankReq struct {
|
||||
Datas []rerankData `json:"datas"`
|
||||
RerankModel string `json:"rerank_model"`
|
||||
}
|
||||
|
||||
type rerankData struct {
|
||||
Query string `json:"query"`
|
||||
Content string `json:"content"`
|
||||
Title *string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type rerankResp struct {
|
||||
Code int64 `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
Scores []float64 `json:"scores"`
|
||||
TokenUsage int64 `json:"token_usage"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
||||
rReq := &rerankReq{
|
||||
Datas: make([]rerankData, 0, len(req.Data)),
|
||||
RerankModel: defaultModel,
|
||||
}
|
||||
|
||||
var flat []*rerank.Data
|
||||
for _, channel := range req.Data {
|
||||
flat = append(flat, channel...)
|
||||
}
|
||||
|
||||
for _, item := range flat {
|
||||
rReq.Datas = append(rReq.Datas, rerankData{
|
||||
Query: req.Query,
|
||||
Content: item.Document.Content,
|
||||
})
|
||||
}
|
||||
|
||||
body, err := json.Marshal(rReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(r.prepareRequest(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rResp := rerankResp{}
|
||||
if err = json.Unmarshal(respBody, &rResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rResp.Code != 0 {
|
||||
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
|
||||
}
|
||||
|
||||
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
|
||||
for i, score := range rResp.Data.Scores {
|
||||
sorted = append(sorted, &rerank.Data{
|
||||
Document: flat[i].Document,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Score > sorted[j].Score
|
||||
})
|
||||
|
||||
right := len(sorted)
|
||||
if req.TopN != nil {
|
||||
right = min(right, int(*req.TopN))
|
||||
}
|
||||
|
||||
return &rerank.Response{
|
||||
SortedData: sorted[:right],
|
||||
TokenUsage: ptr.Of(rResp.Data.TokenUsage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *reranker) prepareRequest(body []byte) *http.Request {
|
||||
u := url.URL{
|
||||
Scheme: "https",
|
||||
Host: domain,
|
||||
Path: "/api/knowledge/service/rerank",
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Host", domain)
|
||||
credential := base.Credentials{
|
||||
AccessKeyID: r.config.AK,
|
||||
SecretAccessKey: r.config.SK,
|
||||
Service: "air",
|
||||
Region: r.config.Region,
|
||||
}
|
||||
req = credential.Sign(req)
|
||||
return req
|
||||
}
|
||||
48
backend/infra/impl/document/rerank/vikingdb/vikingdb_test.go
Normal file
48
backend/infra/impl/document/rerank/vikingdb/vikingdb_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
//func TestRun(t *testing.T) {
|
||||
// AK := os.Getenv("test_ak")
|
||||
// SK := os.Getenv("test_sk")
|
||||
//
|
||||
// r := NewReranker(&Config{
|
||||
// AK: AK,
|
||||
// SK: SK,
|
||||
// })
|
||||
// resp, err := r.Rerank(context.Background(), &rerank.Request{
|
||||
// Data: [][]*knowledge.RetrieveSlice{
|
||||
// {
|
||||
// {Slice: &entity.Slice{PlainText: "吉尼斯世界纪录网站数据显示,蓝鲸是目前已知世界上最大的动物,体长可达30米,相当于一架波音737飞机的长度"}},
|
||||
// {Slice: &entity.Slice{PlainText: "一头成年雌性弓头鲸可以长到22米长,而一头雄性鲸鱼可以长到18米长"}},
|
||||
// },
|
||||
// },
|
||||
// Query: "世界上最大的鲸鱼是什么?",
|
||||
// TopN: nil,
|
||||
// })
|
||||
// assert.NoError(t, err)
|
||||
//
|
||||
// for _, item := range resp.Sorted {
|
||||
// fmt.Println(item.Slice.PlainText, item.Score)
|
||||
// }
|
||||
// // 吉尼斯世界纪录网站数据显示,蓝鲸是目前已知世界上最大的动物,体长可达30米,相当于一架波音737飞机的长度 0.6209664529733573
|
||||
// // 一头成年雌性弓头鲸可以长到22米长,而一头雄性鲸鱼可以长到18米长 0.4269785303456468
|
||||
//
|
||||
// fmt.Println(resp.TokenUsage)
|
||||
// // 95
|
||||
//
|
||||
//}
|
||||
Reference in New Issue
Block a user