335 lines
11 KiB
Go
335 lines
11 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package milvus
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
mentity "github.com/milvus-io/milvus/client/v2/entity"
|
|
mindex "github.com/milvus-io/milvus/client/v2/index"
|
|
client "github.com/milvus-io/milvus/client/v2/milvusclient"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
"github.com/coze-dev/coze-studio/backend/types/errno"
|
|
)
|
|
|
|
// ManagerConfig bundles the dependencies and tuning options consumed by
// NewManager. Optional fields left at their zero value are filled in with the
// defaults listed on each field.
type ManagerConfig struct {
	// Client is the Milvus client. Required.
	Client *client.Client
	// Embedding is the embedder whose dimensions and sparse support drive the
	// collection schema. Required.
	Embedding embedding.Embedder

	// EnableHybrid toggles hybrid search. Optional: defaults to
	// Embedding.SupportStatus() == embedding.SupportDenseAndSparse, and is
	// forced to false when the embedder does not support sparse vectors.
	EnableHybrid *bool
	// DenseIndex is the index built on dense vector fields.
	// Optional: defaults to HNSW, M=30, efConstruction=360.
	DenseIndex mindex.Index
	// DenseMetric is the metric type of the default dense index. Optional: defaults to IP.
	DenseMetric mentity.MetricType
	// SparseIndex is the index built on sparse vector fields.
	// Optional: defaults to SPARSE_INVERTED_INDEX, drop_ratio=0.2.
	SparseIndex mindex.Index
	// SparseMetric is the metric type of the default sparse index. Optional: defaults to IP.
	SparseMetric mentity.MetricType
	// ShardNum is the shard count passed when creating a collection. Optional: defaults to 1.
	ShardNum int
	// BatchSize is optional and defaults to 100.
	// NOTE(review): it is not read in this file; presumably consumed by the search store — confirm.
	BatchSize int
}
|
|
|
|
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
|
|
if config.Client == nil {
|
|
return nil, fmt.Errorf("[NewManager] milvus client not provided")
|
|
}
|
|
if config.Embedding == nil {
|
|
return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
|
|
}
|
|
|
|
enableSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
|
|
if config.EnableHybrid == nil {
|
|
config.EnableHybrid = ptr.Of(enableSparse)
|
|
} else if !enableSparse && ptr.From(config.EnableHybrid) {
|
|
logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
|
|
config.EnableHybrid = ptr.Of(false)
|
|
}
|
|
if config.DenseMetric == "" {
|
|
config.DenseMetric = mentity.IP
|
|
}
|
|
if config.DenseIndex == nil {
|
|
config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
|
|
}
|
|
if config.SparseMetric == "" {
|
|
config.SparseMetric = mentity.IP
|
|
}
|
|
if config.SparseIndex == nil {
|
|
config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
|
|
}
|
|
if config.ShardNum == 0 {
|
|
config.ShardNum = 1
|
|
}
|
|
if config.BatchSize == 0 {
|
|
config.BatchSize = 100
|
|
}
|
|
|
|
return &milvusManager{config: config}, nil
|
|
}
|
|
|
|
// milvusManager is the manager value returned by NewManager; its methods
// create, load and drop Milvus collections.
type milvusManager struct {
	// config is the caller-supplied configuration after NewManager applied defaults.
	config *ManagerConfig
}
|
|
|
|
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
|
if err := m.createCollection(ctx, req); err != nil {
|
|
return fmt.Errorf("[Create] create collection failed, %w", err)
|
|
}
|
|
|
|
if err := m.createIndexes(ctx, req); err != nil {
|
|
return fmt.Errorf("[Create] create indexes failed, %w", err)
|
|
}
|
|
|
|
if exists, err := m.loadCollection(ctx, req.CollectionName); err != nil {
|
|
return fmt.Errorf("[Create] load collection failed, %w", err)
|
|
} else if !exists {
|
|
return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
|
|
return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
|
|
}
|
|
|
|
func (m *milvusManager) GetType() searchstore.SearchStoreType {
|
|
return searchstore.TypeVectorStore
|
|
}
|
|
|
|
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
|
|
if exists, err := m.loadCollection(ctx, collectionName); err != nil {
|
|
return nil, err
|
|
} else if !exists {
|
|
return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
|
|
errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
|
|
}
|
|
|
|
return &milvusSearchStore{
|
|
config: m.config,
|
|
collectionName: collectionName,
|
|
}, nil
|
|
}
|
|
|
|
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
|
|
if req.CollectionName == "" || len(req.Fields) == 0 {
|
|
return fmt.Errorf("[createCollection] invalid request params")
|
|
}
|
|
|
|
cli := m.config.Client
|
|
collectionName := req.CollectionName
|
|
has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
|
|
if err != nil {
|
|
return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
|
|
}
|
|
if has {
|
|
return nil
|
|
}
|
|
|
|
fields, err := m.convertFields(req.Fields)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
|
|
CollectionName: collectionName,
|
|
Description: fmt.Sprintf("created by coze"),
|
|
AutoID: false,
|
|
Fields: fields,
|
|
EnableDynamicField: false,
|
|
}).WithShardNum(int32(m.config.ShardNum))
|
|
|
|
for k, v := range req.CollectionMeta {
|
|
opt.WithProperty(k, v)
|
|
}
|
|
|
|
if err = cli.CreateCollection(ctx, opt); err != nil {
|
|
return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
|
|
collectionName := req.CollectionName
|
|
indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
|
|
if err != nil {
|
|
if !strings.Contains(err.Error(), "index not found") {
|
|
return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
|
|
}
|
|
}
|
|
|
|
createdIndexes := sets.FromSlice(indexes)
|
|
|
|
var ops []func() error
|
|
for i := range req.Fields {
|
|
f := req.Fields[i]
|
|
if !f.Indexing {
|
|
continue
|
|
}
|
|
|
|
ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
|
|
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
|
|
ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
|
|
}
|
|
}
|
|
|
|
for _, op := range ops {
|
|
if err := op(); err != nil {
|
|
return fmt.Errorf("[createIndexes] failed, %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
|
|
return func() error {
|
|
if _, found := createdIndexes[indexName]; found {
|
|
logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
|
|
collectionName, fieldName, indexName, idx.IndexType())
|
|
|
|
return nil
|
|
}
|
|
|
|
cli := m.config.Client
|
|
|
|
task, err := cli.CreateIndex(ctx, client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName))
|
|
if err != nil {
|
|
return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
|
|
}
|
|
|
|
if err = task.Await(ctx); err != nil {
|
|
return fmt.Errorf("[tryCreateIndex] await failed, %w", err)
|
|
}
|
|
|
|
logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
|
|
collectionName, fieldName, indexName, idx.IndexType())
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
|
|
cli := m.config.Client
|
|
|
|
stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
|
|
if err != nil {
|
|
return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
|
|
}
|
|
|
|
switch stat.State {
|
|
case mentity.LoadStateNotLoad:
|
|
task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
|
|
if err != nil {
|
|
return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
|
|
}
|
|
if err = task.Await(ctx); err != nil {
|
|
return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
|
|
}
|
|
return true, nil
|
|
case mentity.LoadStateLoaded:
|
|
return true, nil
|
|
case mentity.LoadStateLoading:
|
|
return true, fmt.Errorf("[loadCollection] collection is unloading, retry later, collection=%v", collectionName)
|
|
case mentity.LoadStateUnloading:
|
|
return false, nil
|
|
default:
|
|
return false, fmt.Errorf("[loadCollection] load state unexpected, state=%d", stat)
|
|
}
|
|
}
|
|
|
|
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
|
|
var foundID, foundCreatorID bool
|
|
resp := make([]*mentity.Field, 0, len(fields))
|
|
for _, f := range fields {
|
|
switch f.Name {
|
|
case searchstore.FieldID:
|
|
foundID = true
|
|
case searchstore.FieldCreatorID:
|
|
foundCreatorID = true
|
|
default:
|
|
}
|
|
|
|
if f.Indexing {
|
|
if f.Type != searchstore.FieldTypeText {
|
|
return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
|
|
}
|
|
// indexing 时只有 content 存储原文
|
|
if f.Name == searchstore.FieldTextContent {
|
|
resp = append(resp, mentity.NewField().
|
|
WithName(f.Name).
|
|
WithDescription(f.Description).
|
|
WithIsPrimaryKey(f.IsPrimary).
|
|
WithNullable(f.Nullable).
|
|
WithDataType(mentity.FieldTypeVarChar).
|
|
WithMaxLength(65535))
|
|
}
|
|
resp = append(resp, mentity.NewField().
|
|
WithName(denseFieldName(f.Name)).
|
|
WithDataType(mentity.FieldTypeFloatVector).
|
|
WithDim(m.config.Embedding.Dimensions()))
|
|
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
|
|
resp = append(resp, mentity.NewField().
|
|
WithName(sparseFieldName(f.Name)).
|
|
WithDataType(mentity.FieldTypeSparseVector))
|
|
}
|
|
} else {
|
|
mf := mentity.NewField().
|
|
WithName(f.Name).
|
|
WithDescription(f.Description).
|
|
WithIsPrimaryKey(f.IsPrimary).
|
|
WithNullable(f.Nullable)
|
|
typ, err := convertFieldType(f.Type)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mf.WithDataType(typ)
|
|
if typ == mentity.FieldTypeVarChar {
|
|
mf.WithMaxLength(65535)
|
|
} else if typ == mentity.FieldTypeFloatVector {
|
|
mf.WithDim(m.config.Embedding.Dimensions())
|
|
}
|
|
resp = append(resp, mf)
|
|
}
|
|
}
|
|
|
|
if !foundID {
|
|
resp = append(resp, mentity.NewField().
|
|
WithName(searchstore.FieldID).
|
|
WithDataType(mentity.FieldTypeInt64).
|
|
WithIsPrimaryKey(true).
|
|
WithNullable(false))
|
|
}
|
|
|
|
if !foundCreatorID {
|
|
resp = append(resp, mentity.NewField().
|
|
WithName(searchstore.FieldCreatorID).
|
|
WithDataType(mentity.FieldTypeInt64).
|
|
WithNullable(false))
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (m *milvusManager) GetEmbedding() embedding.Embedder {
|
|
return m.config.Embedding
|
|
}
|