334 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			334 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package vikingdb
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/volcengine/volc-sdk-golang/service/vikingdb"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | 
						|
)
 | 
						|
 | 
						|
type ManagerConfig struct {
 | 
						|
	Service *vikingdb.VikingDBService
 | 
						|
 | 
						|
	IndexingConfig  *VikingIndexingConfig
 | 
						|
	EmbeddingConfig *VikingEmbeddingConfig
 | 
						|
 | 
						|
	// TODO: cache viking collection & index client
 | 
						|
}
 | 
						|
 | 
						|
type VikingIndexingConfig struct {
 | 
						|
	// vector index config
 | 
						|
	Type     IndexType      // default: hnsw / hnsw_hybrid
 | 
						|
	Distance *IndexDistance // default: ip
 | 
						|
	Quant    *IndexQuant    // default: int8
 | 
						|
	HnswM    *int64         // default: 20
 | 
						|
	HnswCef  *int64         // default: 400
 | 
						|
	HnswSef  *int64         // default: 800
 | 
						|
 | 
						|
	// others
 | 
						|
	CpuQuota   int64 // default: 2
 | 
						|
	ShardCount int64 // default: 1
 | 
						|
}
 | 
						|
 | 
						|
type VikingEmbeddingConfig struct {
 | 
						|
	UseVikingEmbedding bool
 | 
						|
	EnableHybrid       bool
 | 
						|
 | 
						|
	// viking embedding config
 | 
						|
	ModelName    VikingEmbeddingModelName
 | 
						|
	ModelVersion *string
 | 
						|
	DenseWeight  *float64
 | 
						|
 | 
						|
	// builtin embedding config
 | 
						|
	BuiltinEmbedding embedding.Embedder
 | 
						|
}
 | 
						|
 | 
						|
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
 | 
						|
	if config.Service == nil {
 | 
						|
		return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
 | 
						|
	}
 | 
						|
	if config.EmbeddingConfig == nil {
 | 
						|
		return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
 | 
						|
	}
 | 
						|
	if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
 | 
						|
		return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
 | 
						|
	}
 | 
						|
	if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
 | 
						|
		return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
 | 
						|
	}
 | 
						|
	if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
 | 
						|
		return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
 | 
						|
	}
 | 
						|
	if config.EmbeddingConfig.UseVikingEmbedding &&
 | 
						|
		config.EmbeddingConfig.EnableHybrid &&
 | 
						|
		config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
 | 
						|
		return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
 | 
						|
	}
 | 
						|
	if config.IndexingConfig == nil {
 | 
						|
		config.IndexingConfig = &VikingIndexingConfig{}
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.Type == "" {
 | 
						|
		if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
 | 
						|
			config.IndexingConfig.Type = IndexTypeHNSW
 | 
						|
		} else {
 | 
						|
			config.IndexingConfig.Type = IndexTypeHNSWHybrid
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.Distance == nil {
 | 
						|
		config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.Quant == nil {
 | 
						|
		config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.HnswM == nil {
 | 
						|
		config.IndexingConfig.HnswM = ptr.Of(int64(20))
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.HnswCef == nil {
 | 
						|
		config.IndexingConfig.HnswCef = ptr.Of(int64(400))
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.HnswSef == nil {
 | 
						|
		config.IndexingConfig.HnswSef = ptr.Of(int64(800))
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.CpuQuota == 0 {
 | 
						|
		config.IndexingConfig.CpuQuota = 2
 | 
						|
	}
 | 
						|
	if config.IndexingConfig.ShardCount == 0 {
 | 
						|
		config.IndexingConfig.ShardCount = 1
 | 
						|
	}
 | 
						|
 | 
						|
	return &manager{
 | 
						|
		config: config,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
type manager struct {
 | 
						|
	config *ManagerConfig
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
 | 
						|
	if err := m.createCollection(ctx, req); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.createIndex(ctx, req); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
 | 
						|
	if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
 | 
						|
		if !strings.Contains(err.Error(), errIndexNotFound) {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
 | 
						|
		if !strings.Contains(err.Error(), errCollectionNotFound) {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) GetType() searchstore.SearchStoreType {
 | 
						|
	return searchstore.TypeVectorStore
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
 | 
						|
	collection, err := m.config.Service.GetCollection(collectionName)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &vkSearchStore{manager: m, collection: collection}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
 | 
						|
	svc := m.config.Service
 | 
						|
 | 
						|
	collection, err := svc.GetCollection(req.CollectionName)
 | 
						|
	if err != nil {
 | 
						|
		if !strings.Contains(err.Error(), errCollectionNotFound) {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	} else if collection != nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	fields, vopts, err := m.mapFields(req.Fields)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if vopts != nil {
 | 
						|
		_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
 | 
						|
	} else {
 | 
						|
		_, err = svc.CreateCollection(req.CollectionName, fields, "")
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
 | 
						|
	svc := m.config.Service
 | 
						|
	index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
 | 
						|
	if err != nil {
 | 
						|
		if !strings.Contains(err.Error(), errIndexNotFound) {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	} else if index != nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	vectorIndex := &vikingdb.VectorIndexParams{
 | 
						|
		IndexType: string(m.config.IndexingConfig.Type),
 | 
						|
		Distance:  string(ptr.From(m.config.IndexingConfig.Distance)),
 | 
						|
		Quant:     string(ptr.From(m.config.IndexingConfig.Quant)),
 | 
						|
		HnswM:     ptr.From(m.config.IndexingConfig.HnswM),
 | 
						|
		HnswCef:   ptr.From(m.config.IndexingConfig.HnswCef),
 | 
						|
		HnswSef:   ptr.From(m.config.IndexingConfig.HnswSef),
 | 
						|
	}
 | 
						|
 | 
						|
	opts := vikingdb.NewIndexOptions().
 | 
						|
		SetVectorIndex(vectorIndex).
 | 
						|
		SetCpuQuota(m.config.IndexingConfig.CpuQuota).
 | 
						|
		SetShardCount(m.config.IndexingConfig.ShardCount)
 | 
						|
 | 
						|
	_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
 | 
						|
	var (
 | 
						|
		foundID        bool
 | 
						|
		foundCreatorID bool
 | 
						|
		dstFields      = make([]vikingdb.Field, 0, len(srcFields))
 | 
						|
		vectorizeOpts  []*vikingdb.VectorizeTuple
 | 
						|
		embConfig      = m.config.EmbeddingConfig
 | 
						|
	)
 | 
						|
 | 
						|
	for _, srcField := range srcFields {
 | 
						|
		switch srcField.Name {
 | 
						|
		case searchstore.FieldID:
 | 
						|
			foundID = true
 | 
						|
		case searchstore.FieldCreatorID:
 | 
						|
			foundCreatorID = true
 | 
						|
		default:
 | 
						|
		}
 | 
						|
 | 
						|
		if srcField.Indexing {
 | 
						|
			if srcField.Type != searchstore.FieldTypeText {
 | 
						|
				return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
 | 
						|
			}
 | 
						|
			if embConfig.UseVikingEmbedding {
 | 
						|
				vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name, false))
 | 
						|
				if embConfig.EnableHybrid {
 | 
						|
					vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name, true))
 | 
						|
				}
 | 
						|
				vectorizeOpts = append(vectorizeOpts, vt)
 | 
						|
			} else {
 | 
						|
				dstFields = append(dstFields, vikingdb.Field{
 | 
						|
					FieldName:  denseFieldName(srcField.Name),
 | 
						|
					FieldType:  vikingdb.Vector,
 | 
						|
					DefaultVal: nil,
 | 
						|
					Dim:        m.getDims(),
 | 
						|
				})
 | 
						|
			}
 | 
						|
 | 
						|
		}
 | 
						|
 | 
						|
		dstField := vikingdb.Field{
 | 
						|
			FieldName:    srcField.Name,
 | 
						|
			IsPrimaryKey: srcField.IsPrimary,
 | 
						|
		}
 | 
						|
		switch srcField.Type {
 | 
						|
		case searchstore.FieldTypeInt64:
 | 
						|
			dstField.FieldType = vikingdb.Int64
 | 
						|
		case searchstore.FieldTypeText:
 | 
						|
			dstField.FieldType = vikingdb.Text
 | 
						|
		case searchstore.FieldTypeDenseVector:
 | 
						|
			dstField.FieldType = vikingdb.Vector
 | 
						|
			dstField.Dim = m.getDims()
 | 
						|
		case searchstore.FieldTypeSparseVector:
 | 
						|
			dstField.FieldType = vikingdb.Sparse_Vector
 | 
						|
		default:
 | 
						|
			return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
 | 
						|
		}
 | 
						|
		dstFields = append(dstFields, dstField)
 | 
						|
	}
 | 
						|
 | 
						|
	if !foundID {
 | 
						|
		dstFields = append(dstFields, vikingdb.Field{
 | 
						|
			FieldName:    searchstore.FieldID,
 | 
						|
			FieldType:    vikingdb.Int64,
 | 
						|
			IsPrimaryKey: true,
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	if !foundCreatorID {
 | 
						|
		dstFields = append(dstFields, vikingdb.Field{
 | 
						|
			FieldName: searchstore.FieldCreatorID,
 | 
						|
			FieldType: vikingdb.Int64,
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	return dstFields, vectorizeOpts, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) newVectorizeModelConf(fieldName string, isSparse bool) *vikingdb.VectorizeModelConf {
 | 
						|
	embConfig := m.config.EmbeddingConfig
 | 
						|
	vmc := vikingdb.NewVectorizeModelConf().
 | 
						|
		SetTextField(fieldName).
 | 
						|
		SetModelName(string(embConfig.ModelName))
 | 
						|
	if !isSparse {
 | 
						|
		vmc = vmc.SetDim(m.getDims())
 | 
						|
	}
 | 
						|
	if embConfig.ModelVersion != nil {
 | 
						|
		vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
 | 
						|
	}
 | 
						|
	return vmc
 | 
						|
}
 | 
						|
 | 
						|
func (m *manager) getDims() int64 {
 | 
						|
	if m.config.EmbeddingConfig.UseVikingEmbedding {
 | 
						|
		return m.config.EmbeddingConfig.ModelName.Dimensions()
 | 
						|
	}
 | 
						|
 | 
						|
	return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
 | 
						|
}
 |