coze-studio/backend/infra/impl/document/searchstore/vikingdb/vikingdb_manager.go

334 lines
9.5 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"fmt"
"strings"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type ManagerConfig struct {
Service *vikingdb.VikingDBService
IndexingConfig *VikingIndexingConfig
EmbeddingConfig *VikingEmbeddingConfig
// TODO: cache viking collection & index client
}
type VikingIndexingConfig struct {
// vector index config
Type IndexType // default: hnsw / hnsw_hybrid
Distance *IndexDistance // default: ip
Quant *IndexQuant // default: int8
HnswM *int64 // default: 20
HnswCef *int64 // default: 400
HnswSef *int64 // default: 800
// others
CpuQuota int64 // default: 2
ShardCount int64 // default: 1
}
type VikingEmbeddingConfig struct {
UseVikingEmbedding bool
EnableHybrid bool
// viking embedding config
ModelName VikingEmbeddingModelName
ModelVersion *string
DenseWeight *float64
// builtin embedding config
BuiltinEmbedding embedding.Embedder
}
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
if config.Service == nil {
return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
}
if config.EmbeddingConfig == nil {
return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
}
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
}
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
}
if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
}
if config.EmbeddingConfig.UseVikingEmbedding &&
config.EmbeddingConfig.EnableHybrid &&
config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
}
if config.IndexingConfig == nil {
config.IndexingConfig = &VikingIndexingConfig{}
}
if config.IndexingConfig.Type == "" {
if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
config.IndexingConfig.Type = IndexTypeHNSW
} else {
config.IndexingConfig.Type = IndexTypeHNSWHybrid
}
}
if config.IndexingConfig.Distance == nil {
config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
}
if config.IndexingConfig.Quant == nil {
config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
}
if config.IndexingConfig.HnswM == nil {
config.IndexingConfig.HnswM = ptr.Of(int64(20))
}
if config.IndexingConfig.HnswCef == nil {
config.IndexingConfig.HnswCef = ptr.Of(int64(400))
}
if config.IndexingConfig.HnswSef == nil {
config.IndexingConfig.HnswSef = ptr.Of(int64(800))
}
if config.IndexingConfig.CpuQuota == 0 {
config.IndexingConfig.CpuQuota = 2
}
if config.IndexingConfig.ShardCount == 0 {
config.IndexingConfig.ShardCount = 1
}
return &manager{
config: config,
}, nil
}
type manager struct {
config *ManagerConfig
}
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
if err := m.createCollection(ctx, req); err != nil {
return err
}
if err := m.createIndex(ctx, req); err != nil {
return err
}
return nil
}
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
if !strings.Contains(err.Error(), errIndexNotFound) {
return err
}
}
if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
if !strings.Contains(err.Error(), errCollectionNotFound) {
return err
}
}
return nil
}
func (m *manager) GetType() searchstore.SearchStoreType {
return searchstore.TypeVectorStore
}
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
collection, err := m.config.Service.GetCollection(collectionName)
if err != nil {
return nil, err
}
return &vkSearchStore{manager: m, collection: collection}, nil
}
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
svc := m.config.Service
collection, err := svc.GetCollection(req.CollectionName)
if err != nil {
if !strings.Contains(err.Error(), errCollectionNotFound) {
return err
}
} else if collection != nil {
return nil
}
fields, vopts, err := m.mapFields(req.Fields)
if err != nil {
return err
}
if vopts != nil {
_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
} else {
_, err = svc.CreateCollection(req.CollectionName, fields, "")
}
if err != nil {
return err
}
logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
return nil
}
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
svc := m.config.Service
index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
if err != nil {
if !strings.Contains(err.Error(), errIndexNotFound) {
return err
}
} else if index != nil {
return nil
}
vectorIndex := &vikingdb.VectorIndexParams{
IndexType: string(m.config.IndexingConfig.Type),
Distance: string(ptr.From(m.config.IndexingConfig.Distance)),
Quant: string(ptr.From(m.config.IndexingConfig.Quant)),
HnswM: ptr.From(m.config.IndexingConfig.HnswM),
HnswCef: ptr.From(m.config.IndexingConfig.HnswCef),
HnswSef: ptr.From(m.config.IndexingConfig.HnswSef),
}
opts := vikingdb.NewIndexOptions().
SetVectorIndex(vectorIndex).
SetCpuQuota(m.config.IndexingConfig.CpuQuota).
SetShardCount(m.config.IndexingConfig.ShardCount)
_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
if err != nil {
return err
}
logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
return nil
}
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
var (
foundID bool
foundCreatorID bool
dstFields = make([]vikingdb.Field, 0, len(srcFields))
vectorizeOpts []*vikingdb.VectorizeTuple
embConfig = m.config.EmbeddingConfig
)
for _, srcField := range srcFields {
switch srcField.Name {
case searchstore.FieldID:
foundID = true
case searchstore.FieldCreatorID:
foundCreatorID = true
default:
}
if srcField.Indexing {
if srcField.Type != searchstore.FieldTypeText {
return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
}
if embConfig.UseVikingEmbedding {
vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name, false))
if embConfig.EnableHybrid {
vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name, true))
}
vectorizeOpts = append(vectorizeOpts, vt)
} else {
dstFields = append(dstFields, vikingdb.Field{
FieldName: denseFieldName(srcField.Name),
FieldType: vikingdb.Vector,
DefaultVal: nil,
Dim: m.getDims(),
})
}
}
dstField := vikingdb.Field{
FieldName: srcField.Name,
IsPrimaryKey: srcField.IsPrimary,
}
switch srcField.Type {
case searchstore.FieldTypeInt64:
dstField.FieldType = vikingdb.Int64
case searchstore.FieldTypeText:
dstField.FieldType = vikingdb.Text
case searchstore.FieldTypeDenseVector:
dstField.FieldType = vikingdb.Vector
dstField.Dim = m.getDims()
case searchstore.FieldTypeSparseVector:
dstField.FieldType = vikingdb.Sparse_Vector
default:
return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
}
dstFields = append(dstFields, dstField)
}
if !foundID {
dstFields = append(dstFields, vikingdb.Field{
FieldName: searchstore.FieldID,
FieldType: vikingdb.Int64,
IsPrimaryKey: true,
})
}
if !foundCreatorID {
dstFields = append(dstFields, vikingdb.Field{
FieldName: searchstore.FieldCreatorID,
FieldType: vikingdb.Int64,
})
}
return dstFields, vectorizeOpts, nil
}
func (m *manager) newVectorizeModelConf(fieldName string, isSparse bool) *vikingdb.VectorizeModelConf {
embConfig := m.config.EmbeddingConfig
vmc := vikingdb.NewVectorizeModelConf().
SetTextField(fieldName).
SetModelName(string(embConfig.ModelName))
if !isSparse {
vmc = vmc.SetDim(m.getDims())
}
if embConfig.ModelVersion != nil {
vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
}
return vmc
}
func (m *manager) getDims() int64 {
if m.config.EmbeddingConfig.UseVikingEmbedding {
return m.config.EmbeddingConfig.ModelName.Dimensions()
}
return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
}