335 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			335 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package milvus
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	mentity "github.com/milvus-io/milvus/client/v2/entity"
 | 
						|
	mindex "github.com/milvus-io/milvus/client/v2/index"
 | 
						|
	client "github.com/milvus-io/milvus/client/v2/milvusclient"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/errorx"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/types/errno"
 | 
						|
)
 | 
						|
 | 
						|
type ManagerConfig struct {
 | 
						|
	Client    *client.Client     // required
 | 
						|
	Embedding embedding.Embedder // required
 | 
						|
 | 
						|
	EnableHybrid *bool              // optional: default Embedding.SupportStatus() == embedding.SupportDenseAndSparse
 | 
						|
	DenseIndex   mindex.Index       // optional: default HNSW, M=30, efConstruction=360
 | 
						|
	DenseMetric  mentity.MetricType // optional: default IP
 | 
						|
	SparseIndex  mindex.Index       // optional: default SPARSE_INVERTED_INDEX, drop_ratio=0.2
 | 
						|
	SparseMetric mentity.MetricType // optional: default IP
 | 
						|
	ShardNum     int                // optional: default 1
 | 
						|
	BatchSize    int                // optional: default 100
 | 
						|
}
 | 
						|
 | 
						|
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
 | 
						|
	if config.Client == nil {
 | 
						|
		return nil, fmt.Errorf("[NewManager] milvus client not provided")
 | 
						|
	}
 | 
						|
	if config.Embedding == nil {
 | 
						|
		return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
 | 
						|
	}
 | 
						|
 | 
						|
	enableSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
 | 
						|
	if config.EnableHybrid == nil {
 | 
						|
		config.EnableHybrid = ptr.Of(enableSparse)
 | 
						|
	} else if !enableSparse && ptr.From(config.EnableHybrid) {
 | 
						|
		logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
 | 
						|
		config.EnableHybrid = ptr.Of(false)
 | 
						|
	}
 | 
						|
	if config.DenseMetric == "" {
 | 
						|
		config.DenseMetric = mentity.IP
 | 
						|
	}
 | 
						|
	if config.DenseIndex == nil {
 | 
						|
		config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
 | 
						|
	}
 | 
						|
	if config.SparseMetric == "" {
 | 
						|
		config.SparseMetric = mentity.IP
 | 
						|
	}
 | 
						|
	if config.SparseIndex == nil {
 | 
						|
		config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
 | 
						|
	}
 | 
						|
	if config.ShardNum == 0 {
 | 
						|
		config.ShardNum = 1
 | 
						|
	}
 | 
						|
	if config.BatchSize == 0 {
 | 
						|
		config.BatchSize = 100
 | 
						|
	}
 | 
						|
 | 
						|
	return &milvusManager{config: config}, nil
 | 
						|
}
 | 
						|
 | 
						|
type milvusManager struct {
 | 
						|
	config *ManagerConfig
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
 | 
						|
	if err := m.createCollection(ctx, req); err != nil {
 | 
						|
		return fmt.Errorf("[Create] create collection failed, %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.createIndexes(ctx, req); err != nil {
 | 
						|
		return fmt.Errorf("[Create] create indexes failed, %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if exists, err := m.loadCollection(ctx, req.CollectionName); err != nil {
 | 
						|
		return fmt.Errorf("[Create] load collection failed, %w", err)
 | 
						|
	} else if !exists {
 | 
						|
		return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
 | 
						|
	return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) GetType() searchstore.SearchStoreType {
 | 
						|
	return searchstore.TypeVectorStore
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
 | 
						|
	if exists, err := m.loadCollection(ctx, collectionName); err != nil {
 | 
						|
		return nil, err
 | 
						|
	} else if !exists {
 | 
						|
		return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
 | 
						|
			errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
 | 
						|
	}
 | 
						|
 | 
						|
	return &milvusSearchStore{
 | 
						|
		config:         m.config,
 | 
						|
		collectionName: collectionName,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
 | 
						|
	if req.CollectionName == "" || len(req.Fields) == 0 {
 | 
						|
		return fmt.Errorf("[createCollection] invalid request params")
 | 
						|
	}
 | 
						|
 | 
						|
	cli := m.config.Client
 | 
						|
	collectionName := req.CollectionName
 | 
						|
	has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
 | 
						|
	}
 | 
						|
	if has {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	fields, err := m.convertFields(req.Fields)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
 | 
						|
		CollectionName:     collectionName,
 | 
						|
		Description:        fmt.Sprintf("created by coze"),
 | 
						|
		AutoID:             false,
 | 
						|
		Fields:             fields,
 | 
						|
		EnableDynamicField: false,
 | 
						|
	}).WithShardNum(int32(m.config.ShardNum))
 | 
						|
 | 
						|
	for k, v := range req.CollectionMeta {
 | 
						|
		opt.WithProperty(k, v)
 | 
						|
	}
 | 
						|
 | 
						|
	if err = cli.CreateCollection(ctx, opt); err != nil {
 | 
						|
		return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
 | 
						|
	collectionName := req.CollectionName
 | 
						|
	indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
 | 
						|
	if err != nil {
 | 
						|
		if !strings.Contains(err.Error(), "index not found") {
 | 
						|
			return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	createdIndexes := sets.FromSlice(indexes)
 | 
						|
 | 
						|
	var ops []func() error
 | 
						|
	for i := range req.Fields {
 | 
						|
		f := req.Fields[i]
 | 
						|
		if !f.Indexing {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
 | 
						|
		if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
 | 
						|
			ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for _, op := range ops {
 | 
						|
		if err := op(); err != nil {
 | 
						|
			return fmt.Errorf("[createIndexes] failed, %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
 | 
						|
	return func() error {
 | 
						|
		if _, found := createdIndexes[indexName]; found {
 | 
						|
			logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
 | 
						|
				collectionName, fieldName, indexName, idx.IndexType())
 | 
						|
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
 | 
						|
		cli := m.config.Client
 | 
						|
 | 
						|
		task, err := cli.CreateIndex(ctx, client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName))
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		if err = task.Await(ctx); err != nil {
 | 
						|
			return fmt.Errorf("[tryCreateIndex] await failed, %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
 | 
						|
			collectionName, fieldName, indexName, idx.IndexType())
 | 
						|
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
 | 
						|
	cli := m.config.Client
 | 
						|
 | 
						|
	stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
 | 
						|
	if err != nil {
 | 
						|
		return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	switch stat.State {
 | 
						|
	case mentity.LoadStateNotLoad:
 | 
						|
		task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
 | 
						|
		if err != nil {
 | 
						|
			return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
 | 
						|
		}
 | 
						|
		if err = task.Await(ctx); err != nil {
 | 
						|
			return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
 | 
						|
		}
 | 
						|
		return true, nil
 | 
						|
	case mentity.LoadStateLoaded:
 | 
						|
		return true, nil
 | 
						|
	case mentity.LoadStateLoading:
 | 
						|
		return true, fmt.Errorf("[loadCollection] collection is unloading, retry later, collection=%v", collectionName)
 | 
						|
	case mentity.LoadStateUnloading:
 | 
						|
		return false, nil
 | 
						|
	default:
 | 
						|
		return false, fmt.Errorf("[loadCollection] load state unexpected, state=%d", stat)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
 | 
						|
	var foundID, foundCreatorID bool
 | 
						|
	resp := make([]*mentity.Field, 0, len(fields))
 | 
						|
	for _, f := range fields {
 | 
						|
		switch f.Name {
 | 
						|
		case searchstore.FieldID:
 | 
						|
			foundID = true
 | 
						|
		case searchstore.FieldCreatorID:
 | 
						|
			foundCreatorID = true
 | 
						|
		default:
 | 
						|
		}
 | 
						|
 | 
						|
		if f.Indexing {
 | 
						|
			if f.Type != searchstore.FieldTypeText {
 | 
						|
				return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
 | 
						|
			}
 | 
						|
			// indexing 时只有 content 存储原文
 | 
						|
			if f.Name == searchstore.FieldTextContent {
 | 
						|
				resp = append(resp, mentity.NewField().
 | 
						|
					WithName(f.Name).
 | 
						|
					WithDescription(f.Description).
 | 
						|
					WithIsPrimaryKey(f.IsPrimary).
 | 
						|
					WithNullable(f.Nullable).
 | 
						|
					WithDataType(mentity.FieldTypeVarChar).
 | 
						|
					WithMaxLength(65535))
 | 
						|
			}
 | 
						|
			resp = append(resp, mentity.NewField().
 | 
						|
				WithName(denseFieldName(f.Name)).
 | 
						|
				WithDataType(mentity.FieldTypeFloatVector).
 | 
						|
				WithDim(m.config.Embedding.Dimensions()))
 | 
						|
			if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
 | 
						|
				resp = append(resp, mentity.NewField().
 | 
						|
					WithName(sparseFieldName(f.Name)).
 | 
						|
					WithDataType(mentity.FieldTypeSparseVector))
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			mf := mentity.NewField().
 | 
						|
				WithName(f.Name).
 | 
						|
				WithDescription(f.Description).
 | 
						|
				WithIsPrimaryKey(f.IsPrimary).
 | 
						|
				WithNullable(f.Nullable)
 | 
						|
			typ, err := convertFieldType(f.Type)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
			mf.WithDataType(typ)
 | 
						|
			if typ == mentity.FieldTypeVarChar {
 | 
						|
				mf.WithMaxLength(65535)
 | 
						|
			} else if typ == mentity.FieldTypeFloatVector {
 | 
						|
				mf.WithDim(m.config.Embedding.Dimensions())
 | 
						|
			}
 | 
						|
			resp = append(resp, mf)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if !foundID {
 | 
						|
		resp = append(resp, mentity.NewField().
 | 
						|
			WithName(searchstore.FieldID).
 | 
						|
			WithDataType(mentity.FieldTypeInt64).
 | 
						|
			WithIsPrimaryKey(true).
 | 
						|
			WithNullable(false))
 | 
						|
	}
 | 
						|
 | 
						|
	if !foundCreatorID {
 | 
						|
		resp = append(resp, mentity.NewField().
 | 
						|
			WithName(searchstore.FieldCreatorID).
 | 
						|
			WithDataType(mentity.FieldTypeInt64).
 | 
						|
			WithNullable(false))
 | 
						|
	}
 | 
						|
 | 
						|
	return resp, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *milvusManager) GetEmbedding() embedding.Embedder {
 | 
						|
	return m.config.Embedding
 | 
						|
}
 |