feat(knowledge): Support ark rerank (#852)

This commit is contained in:
liuyunchao-1998 2025-08-22 14:41:58 +08:00 committed by GitHub
parent 19c63a1150
commit 59c1d9aa03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 106 additions and 57 deletions

View File

@ -59,6 +59,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
@ -145,7 +146,7 @@ func Init(ctx context.Context) (*AppDependencies, error) {
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
}
deps.Reranker = rrf.NewRRFReranker(0)
deps.Reranker = initReranker()
deps.Rewriter, err = initRewriter(ctx)
if err != nil {
@ -207,6 +208,26 @@ func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.M
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
}
// initReranker picks the knowledge reranker implementation based on the
// RERANK_TYPE environment variable.
//
// The value "vikingdb" selects the VikingDB reranker, configured through
// getVikingRerankerConfig. The value "rrf", an empty value, and every other
// value select the RRF reranker constructed with argument 0.
func initReranker() rerank.Reranker {
	if os.Getenv("RERANK_TYPE") == "vikingdb" {
		return vikingReranker.NewReranker(getVikingRerankerConfig())
	}

	// "rrf" and all unrecognized values share the same fallback.
	return rrf.NewRRFReranker(0)
}
// getVikingRerankerConfig assembles the VikingDB reranker configuration from
// the VIKINGDB_RERANK_* environment variables. Variables that are not set are
// passed through as empty strings.
func getVikingRerankerConfig() *vikingReranker.Config {
	cfg := new(vikingReranker.Config)

	cfg.AK = os.Getenv("VIKINGDB_RERANK_AK")
	cfg.SK = os.Getenv("VIKINGDB_RERANK_SK")
	cfg.Domain = os.Getenv("VIKINGDB_RERANK_HOST")
	cfg.Region = os.Getenv("VIKINGDB_RERANK_REGION")
	cfg.Model = os.Getenv("VIKINGDB_RERANK_MODEL")

	return cfg
}
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
if err != nil {

View File

@ -390,6 +390,10 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
}
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
}
virtualColumnMap := map[string]*entity.TableColumn{}
for i := range doc.TableInfo.Columns {
virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i]
}
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
if err != nil {
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
@ -423,6 +427,32 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
prefix := "sql:" + sql + ";result:"
d.Content = prefix + string(byteData)
} else {
transferMap := map[string]string{}
for cName, val := range resp.ResultSet.Rows[i] {
column, found := virtualColumnMap[cName]
if !found {
logs.CtxInfof(ctx, "column not found, name: %s", cName)
continue
}
columnData, err := convert.ParseAnyData(column, val)
if err != nil {
logs.CtxErrorf(ctx, "parse any data failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeColumnParseFailCode, errorx.KV("msg", err.Error()))
}
if columnData.Type == document.TableColumnTypeString {
columnData.ValString = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
}
if columnData.Type == document.TableColumnTypeImage {
columnData.ValImage = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
}
transferMap[column.Name] = columnData.GetNullableStringValue()
}
byteData, err := sonic.Marshal(transferMap)
if err != nil {
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
return nil, err
}
d.Content = string(byteData)
d.ID = strconv.FormatInt(id, 10)
}
d.WithScore(1)

View File

@ -35,7 +35,8 @@ import (
// Config carries the settings passed to NewReranker.
type Config struct {
// AK and SK are the credentials for the rerank service
// (presumably access key / secret key; confirm against the request-signing code).
AK string
SK string
// Domain is used as the host of the rerank request URL. NewReranker
// substitutes the package default when it is empty.
Domain string
// Model is sent as the rerank model of each request. NewReranker
// substitutes the package default model when it is empty.
Model string
Region string // default cn-north-1
}
@ -43,6 +44,12 @@ func NewReranker(config *Config) rerank.Reranker {
// Fill unset fields with defaults. The writes go to the caller's Config in
// place, and config must be non-nil because it is dereferenced right away.
if config.Region == "" {
config.Region = "cn-north-1"
}
if config.Domain == "" {
// domain is a package-level default declared outside this view.
config.Domain = domain
}
if config.Model == "" {
// defaultModel is a package-level default declared outside this view.
config.Model = defaultModel
}
// The reranker keeps the same *Config pointer rather than a copy.
return &reranker{config: config}
}
@ -78,12 +85,32 @@ type rerankResp struct {
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
rReq := &rerankReq{
Datas: make([]rerankData, 0, len(req.Data)),
RerankModel: defaultModel,
RerankModel: r.config.Model,
}
sorted := make([]*rerank.Data, 0)
var flat []*rerank.Data
visited := map[string]bool{}
for _, channel := range req.Data {
flat = append(flat, channel...)
if len(channel) == 0 {
continue
}
for _, item := range channel {
if item == nil || item.Document == nil {
continue
}
if item.Document.ID == "" {
sorted = append(sorted, &rerank.Data{
Document: item.Document,
Score: 1,
})
continue
}
if visited[item.Document.ID] {
continue
}
visited[item.Document.ID] = true
flat = append(flat, item)
}
}
for _, item := range flat {
@ -117,7 +144,6 @@ func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Res
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
}
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
for i, score := range rResp.Data.Scores {
sorted = append(sorted, &rerank.Data{
Document: flat[i].Document,
@ -143,7 +169,7 @@ func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Res
func (r *reranker) prepareRequest(body []byte) *http.Request {
u := url.URL{
Scheme: "https",
Host: domain,
Host: r.config.Domain,
Path: "/api/knowledge/service/rerank",
}
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))

View File

@ -1,48 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
//func TestRun(t *testing.T) {
// AK := os.Getenv("test_ak")
// SK := os.Getenv("test_sk")
//
// r := NewReranker(&Config{
// AK: AK,
// SK: SK,
// })
// resp, err := r.Rerank(context.Background(), &rerank.Request{
// Data: [][]*knowledge.RetrieveSlice{
// {
// {Slice: & entity. Slice {PlainText: "According to the Guinness World Records website, the blue whale is currently the largest animal known in the world, with a body length of up to 30 meters, which is equivalent to the length of a Boeing 737 aircraft"}},
// {Slice: & entity. Slice {PlainText: "An adult female bowhead whale can grow to 22 meters long, while a male whale can grow to 18 meters long"}},
// },
// },
// Query: "What is the largest whale in the world?"
// TopN: nil,
// })
// assert.NoError(t, err)
//
// for _, item := range resp.Sorted {
// fmt.Println(item.Slice.PlainText, item.Score)
// }
// According to the Guinness World Records website, the blue whale is the largest known animal in the world, with a body length of up to 30 meters, which is equivalent to the length of a Boeing 737 aircraft 6209664529733573
// //An adult female bowhead whale can grow up to 22 meters long, while a male whale can grow up to 18 meters 4269785303456468
//
// fmt.Println(resp.TokenUsage)
// // 95
//
//}

View File

@ -139,6 +139,16 @@ export GEMINI_EMBEDDING_LOCATION="" # (string, optional) Gemini
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
# Settings for Rerank
# If you want to use the rerank-related functions in the knowledge base feature, you need to set up the rerank configuration.
export RERANK_TYPE="" # currently supported: `vikingdb`, `rrf`; default: rrf
# vikingdb rerank
export VIKINGDB_RERANK_HOST="" # optional,default:api-knowledgebase.mlp.cn-beijing.volces.com
export VIKINGDB_RERANK_REGION="" # optional,default:cn-north-1
export VIKINGDB_RERANK_AK="" # required
export VIKINGDB_RERANK_SK="" # required
export VIKINGDB_RERANK_MODEL="" # optional,default:base-multilingual-rerank,also support m3-v2-rerank
# Settings for OCR
# If you want to use the OCR-related functions in the knowledge base feature, you need to set up the OCR configuration.
# Currently, Coze Studio has built-in Volcano OCR.

View File

@ -137,6 +137,16 @@ export GEMINI_EMBEDDING_LOCATION="" # (string, optional) Gemini
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
# Settings for Rerank
# If you want to use the rerank-related functions in the knowledge base feature, you need to set up the rerank configuration.
export RERANK_TYPE="" # currently supported: `vikingdb`, `rrf`; default: rrf
# vikingdb rerank
export VIKINGDB_RERANK_HOST="" # optional,default:api-knowledgebase.mlp.cn-beijing.volces.com
export VIKINGDB_RERANK_REGION="" # optional,default:cn-north-1
export VIKINGDB_RERANK_AK="" # required
export VIKINGDB_RERANK_SK="" # required
export VIKINGDB_RERANK_MODEL="" # optional,default:base-multilingual-rerank,also support m3-v2-rerank
# Settings for OCR
# If you want to use the OCR-related functions in the knowledge base feature, you need to set up the OCR configuration.
# Currently, Coze Studio has built-in Volcano OCR.