feat(knowledge): Support ark rerank (#852)
This commit is contained in:
parent
19c63a1150
commit
59c1d9aa03
|
|
@ -59,6 +59,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||||
|
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/vikingdb"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
||||||
|
|
@ -145,7 +146,7 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||||
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
|
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
deps.Reranker = rrf.NewRRFReranker(0)
|
deps.Reranker = initReranker()
|
||||||
|
|
||||||
deps.Rewriter, err = initRewriter(ctx)
|
deps.Rewriter, err = initRewriter(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -207,6 +208,26 @@ func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.M
|
||||||
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initReranker() rerank.Reranker {
|
||||||
|
rerankerType := os.Getenv("RERANK_TYPE")
|
||||||
|
switch rerankerType {
|
||||||
|
case "vikingdb":
|
||||||
|
return vikingReranker.NewReranker(getVikingRerankerConfig())
|
||||||
|
case "rrf":
|
||||||
|
return rrf.NewRRFReranker(0)
|
||||||
|
default:
|
||||||
|
return rrf.NewRRFReranker(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func getVikingRerankerConfig() *vikingReranker.Config {
|
||||||
|
return &vikingReranker.Config{
|
||||||
|
AK: os.Getenv("VIKINGDB_RERANK_AK"),
|
||||||
|
SK: os.Getenv("VIKINGDB_RERANK_SK"),
|
||||||
|
Domain: os.Getenv("VIKINGDB_RERANK_HOST"),
|
||||||
|
Region: os.Getenv("VIKINGDB_RERANK_REGION"),
|
||||||
|
Model: os.Getenv("VIKINGDB_RERANK_MODEL"),
|
||||||
|
}
|
||||||
|
}
|
||||||
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
|
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
|
||||||
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
|
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -390,6 +390,10 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||||
}
|
}
|
||||||
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
|
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
|
||||||
}
|
}
|
||||||
|
virtualColumnMap := map[string]*entity.TableColumn{}
|
||||||
|
for i := range doc.TableInfo.Columns {
|
||||||
|
virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i]
|
||||||
|
}
|
||||||
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
|
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
|
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
|
||||||
|
|
@ -423,6 +427,32 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||||
prefix := "sql:" + sql + ";result:"
|
prefix := "sql:" + sql + ";result:"
|
||||||
d.Content = prefix + string(byteData)
|
d.Content = prefix + string(byteData)
|
||||||
} else {
|
} else {
|
||||||
|
transferMap := map[string]string{}
|
||||||
|
for cName, val := range resp.ResultSet.Rows[i] {
|
||||||
|
column, found := virtualColumnMap[cName]
|
||||||
|
if !found {
|
||||||
|
logs.CtxInfof(ctx, "column not found, name: %s", cName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
columnData, err := convert.ParseAnyData(column, val)
|
||||||
|
if err != nil {
|
||||||
|
logs.CtxErrorf(ctx, "parse any data failed: %v", err)
|
||||||
|
return nil, errorx.New(errno.ErrKnowledgeColumnParseFailCode, errorx.KV("msg", err.Error()))
|
||||||
|
}
|
||||||
|
if columnData.Type == document.TableColumnTypeString {
|
||||||
|
columnData.ValString = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
|
||||||
|
}
|
||||||
|
if columnData.Type == document.TableColumnTypeImage {
|
||||||
|
columnData.ValImage = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
|
||||||
|
}
|
||||||
|
transferMap[column.Name] = columnData.GetNullableStringValue()
|
||||||
|
}
|
||||||
|
byteData, err := sonic.Marshal(transferMap)
|
||||||
|
if err != nil {
|
||||||
|
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d.Content = string(byteData)
|
||||||
d.ID = strconv.FormatInt(id, 10)
|
d.ID = strconv.FormatInt(id, 10)
|
||||||
}
|
}
|
||||||
d.WithScore(1)
|
d.WithScore(1)
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,8 @@ import (
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AK string
|
AK string
|
||||||
SK string
|
SK string
|
||||||
|
Domain string
|
||||||
|
Model string
|
||||||
Region string // default cn-north-1
|
Region string // default cn-north-1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -43,6 +44,12 @@ func NewReranker(config *Config) rerank.Reranker {
|
||||||
if config.Region == "" {
|
if config.Region == "" {
|
||||||
config.Region = "cn-north-1"
|
config.Region = "cn-north-1"
|
||||||
}
|
}
|
||||||
|
if config.Domain == "" {
|
||||||
|
config.Domain = domain
|
||||||
|
}
|
||||||
|
if config.Model == "" {
|
||||||
|
config.Model = defaultModel
|
||||||
|
}
|
||||||
return &reranker{config: config}
|
return &reranker{config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -78,12 +85,32 @@ type rerankResp struct {
|
||||||
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
||||||
rReq := &rerankReq{
|
rReq := &rerankReq{
|
||||||
Datas: make([]rerankData, 0, len(req.Data)),
|
Datas: make([]rerankData, 0, len(req.Data)),
|
||||||
RerankModel: defaultModel,
|
RerankModel: r.config.Model,
|
||||||
}
|
}
|
||||||
|
sorted := make([]*rerank.Data, 0)
|
||||||
var flat []*rerank.Data
|
var flat []*rerank.Data
|
||||||
|
visited := map[string]bool{}
|
||||||
for _, channel := range req.Data {
|
for _, channel := range req.Data {
|
||||||
flat = append(flat, channel...)
|
if len(channel) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, item := range channel {
|
||||||
|
if item == nil || item.Document == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if item.Document.ID == "" {
|
||||||
|
sorted = append(sorted, &rerank.Data{
|
||||||
|
Document: item.Document,
|
||||||
|
Score: 1,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if visited[item.Document.ID] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visited[item.Document.ID] = true
|
||||||
|
flat = append(flat, item)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range flat {
|
for _, item := range flat {
|
||||||
|
|
@ -117,7 +144,6 @@ func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Res
|
||||||
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
|
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
|
|
||||||
for i, score := range rResp.Data.Scores {
|
for i, score := range rResp.Data.Scores {
|
||||||
sorted = append(sorted, &rerank.Data{
|
sorted = append(sorted, &rerank.Data{
|
||||||
Document: flat[i].Document,
|
Document: flat[i].Document,
|
||||||
|
|
@ -143,7 +169,7 @@ func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Res
|
||||||
func (r *reranker) prepareRequest(body []byte) *http.Request {
|
func (r *reranker) prepareRequest(body []byte) *http.Request {
|
||||||
u := url.URL{
|
u := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: domain,
|
Host: r.config.Domain,
|
||||||
Path: "/api/knowledge/service/rerank",
|
Path: "/api/knowledge/service/rerank",
|
||||||
}
|
}
|
||||||
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
|
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
|
||||||
|
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2025 coze-dev Authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package vikingdb
|
|
||||||
|
|
||||||
//func TestRun(t *testing.T) {
|
|
||||||
// AK := os.Getenv("test_ak")
|
|
||||||
// SK := os.Getenv("test_sk")
|
|
||||||
//
|
|
||||||
// r := NewReranker(&Config{
|
|
||||||
// AK: AK,
|
|
||||||
// SK: SK,
|
|
||||||
// })
|
|
||||||
// resp, err := r.Rerank(context.Background(), &rerank.Request{
|
|
||||||
// Data: [][]*knowledge.RetrieveSlice{
|
|
||||||
// {
|
|
||||||
// {Slice: & entity. Slice {PlainText: "According to the Guinness World Records website, the blue whale is currently the largest animal known in the world, with a body length of up to 30 meters, which is equivalent to the length of a Boeing 737 aircraft"}},
|
|
||||||
// {Slice: & entity. Slice {PlainText: "An adult female bowhead whale can grow to 22 meters long, while a male whale can grow to 18 meters long"}},
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Query: "What is the largest whale in the world?"
|
|
||||||
// TopN: nil,
|
|
||||||
// })
|
|
||||||
// assert.NoError(t, err)
|
|
||||||
//
|
|
||||||
// for _, item := range resp.Sorted {
|
|
||||||
// fmt.Println(item.Slice.PlainText, item.Score)
|
|
||||||
// }
|
|
||||||
// According to the Guinness World Records website, the blue whale is the largest known animal in the world, with a body length of up to 30 meters, which is equivalent to the length of a Boeing 737 aircraft 6209664529733573
|
|
||||||
// //An adult female bowhead whale can grow up to 22 meters long, while a male whale can grow up to 18 meters 4269785303456468
|
|
||||||
//
|
|
||||||
// fmt.Println(resp.TokenUsage)
|
|
||||||
// // 95
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
|
|
@ -139,6 +139,16 @@ export GEMINI_EMBEDDING_LOCATION="" # (string, optional) Gemini
|
||||||
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
|
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
|
||||||
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
|
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
|
||||||
|
|
||||||
|
# Settings for Rerank
|
||||||
|
# If you want to use the rerank-related functions in the knowledge base feature, you need to set up the rerank configuration.
|
||||||
|
export RERANK_TYPE="" # currently supported: `vikingdb`, `rrf`; default: rrf
|
||||||
|
# vikingdb rerank
|
||||||
|
export VIKINGDB_RERANK_HOST="" # optional,default:api-knowledgebase.mlp.cn-beijing.volces.com
|
||||||
|
export VIKINGDB_RERANK_REGION="" # optional,default:cn-north-1
|
||||||
|
export VIKINGDB_RERANK_AK="" # required
|
||||||
|
export VIKINGDB_RERANK_SK="" # required
|
||||||
|
export VIKINGDB_RERANK_MODEL="" # optional, default: base-multilingual-rerank; m3-v2-rerank is also supported
|
||||||
|
|
||||||
# Settings for OCR
|
# Settings for OCR
|
||||||
# If you want to use the OCR-related functions in the knowledge base feature,You need to set up the OCR configuration.
|
# If you want to use the OCR-related functions in the knowledge base feature,You need to set up the OCR configuration.
|
||||||
# Currently, Coze Studio has built-in Volcano OCR.
|
# Currently, Coze Studio has built-in Volcano OCR.
|
||||||
|
|
|
||||||
|
|
@ -137,6 +137,16 @@ export GEMINI_EMBEDDING_LOCATION="" # (string, optional) Gemini
|
||||||
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
|
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
|
||||||
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
|
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
|
||||||
|
|
||||||
|
# Settings for Rerank
|
||||||
|
# If you want to use the rerank-related functions in the knowledge base feature, you need to set up the rerank configuration.
|
||||||
|
export RERANK_TYPE="" # currently supported: `vikingdb`, `rrf`; default: rrf
|
||||||
|
# vikingdb rerank
|
||||||
|
export VIKINGDB_RERANK_HOST="" # optional,default:api-knowledgebase.mlp.cn-beijing.volces.com
|
||||||
|
export VIKINGDB_RERANK_REGION="" # optional,default:cn-north-1
|
||||||
|
export VIKINGDB_RERANK_AK="" # required
|
||||||
|
export VIKINGDB_RERANK_SK="" # required
|
||||||
|
export VIKINGDB_RERANK_MODEL="" # optional, default: base-multilingual-rerank; m3-v2-rerank is also supported
|
||||||
|
|
||||||
# Settings for OCR
|
# Settings for OCR
|
||||||
# If you want to use the OCR-related functions in the knowledge base feature,You need to set up the OCR configuration.
|
# If you want to use the OCR-related functions in the knowledge base feature,You need to set up the OCR configuration.
|
||||||
# Currently, Coze Studio has built-in Volcano OCR.
|
# Currently, Coze Studio has built-in Volcano OCR.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue