feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rrf
import (
"context"
"fmt"
"sort"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func NewRRFReranker(k int64) rerank.Reranker {
if k == 0 {
k = 60
}
return &rrfReranker{k}
}
type rrfReranker struct {
k int64
}
func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
if req == nil || req.Data == nil || len(req.Data) == 0 {
return nil, fmt.Errorf("invalid request: no data provided")
}
id2Score := make(map[string]float64)
id2Data := make(map[string]*rerank.Data)
for _, resultList := range req.Data {
for rank := range resultList {
result := resultList[rank]
if result != nil && result.Document != nil {
score := 1.0 / (float64(rank) + float64(r.k))
if score > id2Score[result.Document.ID] {
id2Score[result.Document.ID] = score
id2Data[result.Document.ID] = result
}
}
}
}
var sorted []*rerank.Data
for _, data := range id2Data {
sorted = append(sorted, data)
}
sort.Slice(sorted, func(i, j int) bool {
return id2Score[sorted[i].Document.ID] > id2Score[sorted[j].Document.ID]
})
topN := int64(len(sorted))
if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
topN = ptr.From(req.TopN)
}
return &rerank.Response{SortedData: sorted[:topN]}, nil
}

View File

@@ -0,0 +1,161 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Config struct {
AK string
SK string
Region string // default cn-north-1
}
func NewReranker(config *Config) rerank.Reranker {
if config.Region == "" {
config.Region = "cn-north-1"
}
return &reranker{config: config}
}
const (
domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
defaultModel = "base-multilingual-rerank"
)
type reranker struct {
config *Config
}
type rerankReq struct {
Datas []rerankData `json:"datas"`
RerankModel string `json:"rerank_model"`
}
type rerankData struct {
Query string `json:"query"`
Content string `json:"content"`
Title *string `json:"title,omitempty"`
}
type rerankResp struct {
Code int64 `json:"code"`
Message string `json:"message"`
Data struct {
Scores []float64 `json:"scores"`
TokenUsage int64 `json:"token_usage"`
} `json:"data"`
}
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
rReq := &rerankReq{
Datas: make([]rerankData, 0, len(req.Data)),
RerankModel: defaultModel,
}
var flat []*rerank.Data
for _, channel := range req.Data {
flat = append(flat, channel...)
}
for _, item := range flat {
rReq.Datas = append(rReq.Datas, rerankData{
Query: req.Query,
Content: item.Document.Content,
})
}
body, err := json.Marshal(rReq)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(r.prepareRequest(body))
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
rResp := rerankResp{}
if err = json.Unmarshal(respBody, &rResp); err != nil {
return nil, err
}
if rResp.Code != 0 {
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
}
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
for i, score := range rResp.Data.Scores {
sorted = append(sorted, &rerank.Data{
Document: flat[i].Document,
Score: score,
})
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Score > sorted[j].Score
})
right := len(sorted)
if req.TopN != nil {
right = min(right, int(*req.TopN))
}
return &rerank.Response{
SortedData: sorted[:right],
TokenUsage: ptr.Of(rResp.Data.TokenUsage),
}, nil
}
func (r *reranker) prepareRequest(body []byte) *http.Request {
u := url.URL{
Scheme: "https",
Host: domain,
Path: "/api/knowledge/service/rerank",
}
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Host", domain)
credential := base.Credentials{
AccessKeyID: r.config.AK,
SecretAccessKey: r.config.SK,
Service: "air",
Region: r.config.Region,
}
req = credential.Sign(req)
return req
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
//func TestRun(t *testing.T) {
// AK := os.Getenv("test_ak")
// SK := os.Getenv("test_sk")
//
// r := NewReranker(&Config{
// AK: AK,
// SK: SK,
// })
// resp, err := r.Rerank(context.Background(), &rerank.Request{
// Data: [][]*knowledge.RetrieveSlice{
// {
// {Slice: &entity.Slice{PlainText: "吉尼斯世界纪录网站数据显示蓝鲸是目前已知世界上最大的动物体长可达30米相当于一架波音737飞机的长度"}},
// {Slice: &entity.Slice{PlainText: "一头成年雌性弓头鲸可以长到22米长而一头雄性鲸鱼可以长到18米长"}},
// },
// },
// Query: "世界上最大的鲸鱼是什么?",
// TopN: nil,
// })
// assert.NoError(t, err)
//
// for _, item := range resp.Sorted {
// fmt.Println(item.Slice.PlainText, item.Score)
// }
// // 吉尼斯世界纪录网站数据显示蓝鲸是目前已知世界上最大的动物体长可达30米相当于一架波音737飞机的长度 0.6209664529733573
// // 一头成年雌性弓头鲸可以长到22米长而一头雄性鲸鱼可以长到18米长 0.4269785303456468
//
// fmt.Println(resp.TokenUsage)
// // 95
//
//}