feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
@@ -0,0 +1,331 @@
|
||||
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type ManagerConfig struct {
|
||||
Service *vikingdb.VikingDBService
|
||||
|
||||
IndexingConfig *VikingIndexingConfig
|
||||
EmbeddingConfig *VikingEmbeddingConfig
|
||||
|
||||
// TODO: cache viking collection & index client
|
||||
}
|
||||
|
||||
type VikingIndexingConfig struct {
|
||||
// vector index config
|
||||
Type IndexType // default: hnsw / hnsw_hybrid
|
||||
Distance *IndexDistance // default: ip
|
||||
Quant *IndexQuant // default: int8
|
||||
HnswM *int64 // default: 20
|
||||
HnswCef *int64 // default: 400
|
||||
HnswSef *int64 // default: 800
|
||||
|
||||
// others
|
||||
CpuQuota int64 // default: 2
|
||||
ShardCount int64 // default: 1
|
||||
}
|
||||
|
||||
type VikingEmbeddingConfig struct {
|
||||
UseVikingEmbedding bool
|
||||
EnableHybrid bool
|
||||
|
||||
// viking embedding config
|
||||
ModelName VikingEmbeddingModelName
|
||||
ModelVersion *string
|
||||
DenseWeight *float64
|
||||
|
||||
// builtin embedding config
|
||||
BuiltinEmbedding embedding.Embedder
|
||||
}
|
||||
|
||||
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
|
||||
if config.Service == nil {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
|
||||
}
|
||||
if config.EmbeddingConfig == nil {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
|
||||
}
|
||||
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
|
||||
}
|
||||
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
|
||||
}
|
||||
if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
|
||||
}
|
||||
if config.EmbeddingConfig.UseVikingEmbedding &&
|
||||
config.EmbeddingConfig.EnableHybrid &&
|
||||
config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
|
||||
}
|
||||
if config.IndexingConfig == nil {
|
||||
config.IndexingConfig = &VikingIndexingConfig{}
|
||||
}
|
||||
if config.IndexingConfig.Type == "" {
|
||||
if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
|
||||
config.IndexingConfig.Type = IndexTypeHNSW
|
||||
} else {
|
||||
config.IndexingConfig.Type = IndexTypeHNSWHybrid
|
||||
}
|
||||
}
|
||||
if config.IndexingConfig.Distance == nil {
|
||||
config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
|
||||
}
|
||||
if config.IndexingConfig.Quant == nil {
|
||||
config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
|
||||
}
|
||||
if config.IndexingConfig.HnswM == nil {
|
||||
config.IndexingConfig.HnswM = ptr.Of(int64(20))
|
||||
}
|
||||
if config.IndexingConfig.HnswCef == nil {
|
||||
config.IndexingConfig.HnswCef = ptr.Of(int64(400))
|
||||
}
|
||||
if config.IndexingConfig.HnswSef == nil {
|
||||
config.IndexingConfig.HnswSef = ptr.Of(int64(800))
|
||||
}
|
||||
if config.IndexingConfig.CpuQuota == 0 {
|
||||
config.IndexingConfig.CpuQuota = 2
|
||||
}
|
||||
if config.IndexingConfig.ShardCount == 0 {
|
||||
config.IndexingConfig.ShardCount = 1
|
||||
}
|
||||
|
||||
return &manager{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type manager struct {
|
||||
config *ManagerConfig
|
||||
}
|
||||
|
||||
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if err := m.createCollection(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.createIndex(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
|
||||
if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
|
||||
if !strings.Contains(err.Error(), errIndexNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
|
||||
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) GetType() searchstore.SearchStoreType {
|
||||
return searchstore.TypeVectorStore
|
||||
}
|
||||
|
||||
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
|
||||
collection, err := m.config.Service.GetCollection(collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &vkSearchStore{manager: m, collection: collection}, nil
|
||||
}
|
||||
|
||||
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
svc := m.config.Service
|
||||
|
||||
collection, err := svc.GetCollection(req.CollectionName)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
||||
return err
|
||||
}
|
||||
} else if collection != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields, vopts, err := m.mapFields(req.Fields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if vopts != nil {
|
||||
_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
|
||||
} else {
|
||||
_, err = svc.CreateCollection(req.CollectionName, fields, "")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
svc := m.config.Service
|
||||
index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), errIndexNotFound) {
|
||||
return err
|
||||
}
|
||||
} else if index != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
vectorIndex := &vikingdb.VectorIndexParams{
|
||||
IndexType: string(m.config.IndexingConfig.Type),
|
||||
Distance: string(ptr.From(m.config.IndexingConfig.Distance)),
|
||||
Quant: string(ptr.From(m.config.IndexingConfig.Quant)),
|
||||
HnswM: ptr.From(m.config.IndexingConfig.HnswM),
|
||||
HnswCef: ptr.From(m.config.IndexingConfig.HnswCef),
|
||||
HnswSef: ptr.From(m.config.IndexingConfig.HnswSef),
|
||||
}
|
||||
|
||||
opts := vikingdb.NewIndexOptions().
|
||||
SetVectorIndex(vectorIndex).
|
||||
SetCpuQuota(m.config.IndexingConfig.CpuQuota).
|
||||
SetShardCount(m.config.IndexingConfig.ShardCount)
|
||||
|
||||
_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
|
||||
var (
|
||||
foundID bool
|
||||
foundCreatorID bool
|
||||
dstFields = make([]vikingdb.Field, 0, len(srcFields))
|
||||
vectorizeOpts []*vikingdb.VectorizeTuple
|
||||
embConfig = m.config.EmbeddingConfig
|
||||
)
|
||||
|
||||
for _, srcField := range srcFields {
|
||||
switch srcField.Name {
|
||||
case searchstore.FieldID:
|
||||
foundID = true
|
||||
case searchstore.FieldCreatorID:
|
||||
foundCreatorID = true
|
||||
default:
|
||||
}
|
||||
|
||||
if srcField.Indexing {
|
||||
if srcField.Type != searchstore.FieldTypeText {
|
||||
return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
|
||||
}
|
||||
if embConfig.UseVikingEmbedding {
|
||||
vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name))
|
||||
if embConfig.EnableHybrid {
|
||||
vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name))
|
||||
}
|
||||
vectorizeOpts = append(vectorizeOpts, vt)
|
||||
} else {
|
||||
dstFields = append(dstFields, vikingdb.Field{
|
||||
FieldName: denseFieldName(srcField.Name),
|
||||
FieldType: vikingdb.Vector,
|
||||
DefaultVal: nil,
|
||||
Dim: m.getDims(),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
dstField := vikingdb.Field{
|
||||
FieldName: srcField.Name,
|
||||
IsPrimaryKey: srcField.IsPrimary,
|
||||
}
|
||||
switch srcField.Type {
|
||||
case searchstore.FieldTypeInt64:
|
||||
dstField.FieldType = vikingdb.Int64
|
||||
case searchstore.FieldTypeText:
|
||||
dstField.FieldType = vikingdb.Text
|
||||
case searchstore.FieldTypeDenseVector:
|
||||
dstField.FieldType = vikingdb.Vector
|
||||
dstField.Dim = m.getDims()
|
||||
case searchstore.FieldTypeSparseVector:
|
||||
dstField.FieldType = vikingdb.Sparse_Vector
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
|
||||
}
|
||||
dstFields = append(dstFields, dstField)
|
||||
}
|
||||
|
||||
if !foundID {
|
||||
dstFields = append(dstFields, vikingdb.Field{
|
||||
FieldName: searchstore.FieldID,
|
||||
FieldType: vikingdb.Int64,
|
||||
IsPrimaryKey: true,
|
||||
})
|
||||
}
|
||||
|
||||
if !foundCreatorID {
|
||||
dstFields = append(dstFields, vikingdb.Field{
|
||||
FieldName: searchstore.FieldCreatorID,
|
||||
FieldType: vikingdb.Int64,
|
||||
})
|
||||
}
|
||||
|
||||
return dstFields, vectorizeOpts, nil
|
||||
}
|
||||
|
||||
func (m *manager) newVectorizeModelConf(fieldName string) *vikingdb.VectorizeModelConf {
|
||||
embConfig := m.config.EmbeddingConfig
|
||||
vmc := vikingdb.NewVectorizeModelConf().
|
||||
SetTextField(fieldName).
|
||||
SetModelName(string(embConfig.ModelName)).
|
||||
SetDim(m.getDims())
|
||||
if embConfig.ModelVersion != nil {
|
||||
vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
|
||||
}
|
||||
return vmc
|
||||
}
|
||||
|
||||
func (m *manager) getDims() int64 {
|
||||
if m.config.EmbeddingConfig.UseVikingEmbedding {
|
||||
return m.config.EmbeddingConfig.ModelName.Dimensions()
|
||||
}
|
||||
|
||||
return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
|
||||
}
|
||||
Reference in New Issue
Block a user