334 lines
9.5 KiB
Go
334 lines
9.5 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package vikingdb
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
)
|
|
|
|
type ManagerConfig struct {
|
|
Service *vikingdb.VikingDBService
|
|
|
|
IndexingConfig *VikingIndexingConfig
|
|
EmbeddingConfig *VikingEmbeddingConfig
|
|
|
|
// TODO: cache viking collection & index client
|
|
}
|
|
|
|
type VikingIndexingConfig struct {
|
|
// vector index config
|
|
Type IndexType // default: hnsw / hnsw_hybrid
|
|
Distance *IndexDistance // default: ip
|
|
Quant *IndexQuant // default: int8
|
|
HnswM *int64 // default: 20
|
|
HnswCef *int64 // default: 400
|
|
HnswSef *int64 // default: 800
|
|
|
|
// others
|
|
CpuQuota int64 // default: 2
|
|
ShardCount int64 // default: 1
|
|
}
|
|
|
|
type VikingEmbeddingConfig struct {
|
|
UseVikingEmbedding bool
|
|
EnableHybrid bool
|
|
|
|
// viking embedding config
|
|
ModelName VikingEmbeddingModelName
|
|
ModelVersion *string
|
|
DenseWeight *float64
|
|
|
|
// builtin embedding config
|
|
BuiltinEmbedding embedding.Embedder
|
|
}
|
|
|
|
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
|
|
if config.Service == nil {
|
|
return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
|
|
}
|
|
if config.EmbeddingConfig == nil {
|
|
return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
|
|
}
|
|
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
|
|
return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
|
|
}
|
|
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
|
|
return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
|
|
}
|
|
if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
|
|
return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
|
|
}
|
|
if config.EmbeddingConfig.UseVikingEmbedding &&
|
|
config.EmbeddingConfig.EnableHybrid &&
|
|
config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
|
|
return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
|
|
}
|
|
if config.IndexingConfig == nil {
|
|
config.IndexingConfig = &VikingIndexingConfig{}
|
|
}
|
|
if config.IndexingConfig.Type == "" {
|
|
if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
|
|
config.IndexingConfig.Type = IndexTypeHNSW
|
|
} else {
|
|
config.IndexingConfig.Type = IndexTypeHNSWHybrid
|
|
}
|
|
}
|
|
if config.IndexingConfig.Distance == nil {
|
|
config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
|
|
}
|
|
if config.IndexingConfig.Quant == nil {
|
|
config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
|
|
}
|
|
if config.IndexingConfig.HnswM == nil {
|
|
config.IndexingConfig.HnswM = ptr.Of(int64(20))
|
|
}
|
|
if config.IndexingConfig.HnswCef == nil {
|
|
config.IndexingConfig.HnswCef = ptr.Of(int64(400))
|
|
}
|
|
if config.IndexingConfig.HnswSef == nil {
|
|
config.IndexingConfig.HnswSef = ptr.Of(int64(800))
|
|
}
|
|
if config.IndexingConfig.CpuQuota == 0 {
|
|
config.IndexingConfig.CpuQuota = 2
|
|
}
|
|
if config.IndexingConfig.ShardCount == 0 {
|
|
config.IndexingConfig.ShardCount = 1
|
|
}
|
|
|
|
return &manager{
|
|
config: config,
|
|
}, nil
|
|
}
|
|
|
|
type manager struct {
|
|
config *ManagerConfig
|
|
}
|
|
|
|
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
|
if err := m.createCollection(ctx, req); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := m.createIndex(ctx, req); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
|
|
if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
|
|
if !strings.Contains(err.Error(), errIndexNotFound) {
|
|
return err
|
|
}
|
|
}
|
|
if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
|
|
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *manager) GetType() searchstore.SearchStoreType {
|
|
return searchstore.TypeVectorStore
|
|
}
|
|
|
|
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
|
|
collection, err := m.config.Service.GetCollection(collectionName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &vkSearchStore{manager: m, collection: collection}, nil
|
|
}
|
|
|
|
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
|
|
svc := m.config.Service
|
|
|
|
collection, err := svc.GetCollection(req.CollectionName)
|
|
if err != nil {
|
|
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
|
return err
|
|
}
|
|
} else if collection != nil {
|
|
return nil
|
|
}
|
|
|
|
fields, vopts, err := m.mapFields(req.Fields)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if vopts != nil {
|
|
_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
|
|
} else {
|
|
_, err = svc.CreateCollection(req.CollectionName, fields, "")
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
|
|
svc := m.config.Service
|
|
index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
|
|
if err != nil {
|
|
if !strings.Contains(err.Error(), errIndexNotFound) {
|
|
return err
|
|
}
|
|
} else if index != nil {
|
|
return nil
|
|
}
|
|
|
|
vectorIndex := &vikingdb.VectorIndexParams{
|
|
IndexType: string(m.config.IndexingConfig.Type),
|
|
Distance: string(ptr.From(m.config.IndexingConfig.Distance)),
|
|
Quant: string(ptr.From(m.config.IndexingConfig.Quant)),
|
|
HnswM: ptr.From(m.config.IndexingConfig.HnswM),
|
|
HnswCef: ptr.From(m.config.IndexingConfig.HnswCef),
|
|
HnswSef: ptr.From(m.config.IndexingConfig.HnswSef),
|
|
}
|
|
|
|
opts := vikingdb.NewIndexOptions().
|
|
SetVectorIndex(vectorIndex).
|
|
SetCpuQuota(m.config.IndexingConfig.CpuQuota).
|
|
SetShardCount(m.config.IndexingConfig.ShardCount)
|
|
|
|
_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
|
|
var (
|
|
foundID bool
|
|
foundCreatorID bool
|
|
dstFields = make([]vikingdb.Field, 0, len(srcFields))
|
|
vectorizeOpts []*vikingdb.VectorizeTuple
|
|
embConfig = m.config.EmbeddingConfig
|
|
)
|
|
|
|
for _, srcField := range srcFields {
|
|
switch srcField.Name {
|
|
case searchstore.FieldID:
|
|
foundID = true
|
|
case searchstore.FieldCreatorID:
|
|
foundCreatorID = true
|
|
default:
|
|
}
|
|
|
|
if srcField.Indexing {
|
|
if srcField.Type != searchstore.FieldTypeText {
|
|
return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
|
|
}
|
|
if embConfig.UseVikingEmbedding {
|
|
vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name, false))
|
|
if embConfig.EnableHybrid {
|
|
vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name, true))
|
|
}
|
|
vectorizeOpts = append(vectorizeOpts, vt)
|
|
} else {
|
|
dstFields = append(dstFields, vikingdb.Field{
|
|
FieldName: denseFieldName(srcField.Name),
|
|
FieldType: vikingdb.Vector,
|
|
DefaultVal: nil,
|
|
Dim: m.getDims(),
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
dstField := vikingdb.Field{
|
|
FieldName: srcField.Name,
|
|
IsPrimaryKey: srcField.IsPrimary,
|
|
}
|
|
switch srcField.Type {
|
|
case searchstore.FieldTypeInt64:
|
|
dstField.FieldType = vikingdb.Int64
|
|
case searchstore.FieldTypeText:
|
|
dstField.FieldType = vikingdb.Text
|
|
case searchstore.FieldTypeDenseVector:
|
|
dstField.FieldType = vikingdb.Vector
|
|
dstField.Dim = m.getDims()
|
|
case searchstore.FieldTypeSparseVector:
|
|
dstField.FieldType = vikingdb.Sparse_Vector
|
|
default:
|
|
return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
|
|
}
|
|
dstFields = append(dstFields, dstField)
|
|
}
|
|
|
|
if !foundID {
|
|
dstFields = append(dstFields, vikingdb.Field{
|
|
FieldName: searchstore.FieldID,
|
|
FieldType: vikingdb.Int64,
|
|
IsPrimaryKey: true,
|
|
})
|
|
}
|
|
|
|
if !foundCreatorID {
|
|
dstFields = append(dstFields, vikingdb.Field{
|
|
FieldName: searchstore.FieldCreatorID,
|
|
FieldType: vikingdb.Int64,
|
|
})
|
|
}
|
|
|
|
return dstFields, vectorizeOpts, nil
|
|
}
|
|
|
|
func (m *manager) newVectorizeModelConf(fieldName string, isSparse bool) *vikingdb.VectorizeModelConf {
|
|
embConfig := m.config.EmbeddingConfig
|
|
vmc := vikingdb.NewVectorizeModelConf().
|
|
SetTextField(fieldName).
|
|
SetModelName(string(embConfig.ModelName))
|
|
if !isSparse {
|
|
vmc = vmc.SetDim(m.getDims())
|
|
}
|
|
if embConfig.ModelVersion != nil {
|
|
vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
|
|
}
|
|
return vmc
|
|
}
|
|
|
|
func (m *manager) getDims() int64 {
|
|
if m.config.EmbeddingConfig.UseVikingEmbedding {
|
|
return m.config.EmbeddingConfig.ModelName.Dimensions()
|
|
}
|
|
|
|
return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
|
|
}
|