feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
const (
batchSize = 100
topK = 4
)

View File

@@ -0,0 +1,119 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"fmt"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
func denseFieldName(name string) string {
return fmt.Sprintf("dense_%s", name)
}
func denseIndexName(name string) string {
return fmt.Sprintf("index_dense_%s", name)
}
func sparseFieldName(name string) string {
return fmt.Sprintf("sparse_%s", name)
}
func sparseIndexName(name string) string {
return fmt.Sprintf("index_sparse_%s", name)
}
func convertFieldType(typ searchstore.FieldType) (entity.FieldType, error) {
switch typ {
case searchstore.FieldTypeInt64:
return entity.FieldTypeInt64, nil
case searchstore.FieldTypeText:
return entity.FieldTypeVarChar, nil
case searchstore.FieldTypeDenseVector:
return entity.FieldTypeFloatVector, nil
case searchstore.FieldTypeSparseVector:
return entity.FieldTypeSparseVector, nil
default:
return entity.FieldTypeNone, fmt.Errorf("[convertFieldType] unknown field type: %v", typ)
}
}
func convertDense(dense [][]float64) [][]float32 {
return slices.Transform(dense, func(a []float64) []float32 {
r := make([]float32, len(a))
for i := 0; i < len(a); i++ {
r[i] = float32(a[i])
}
return r
})
}
func convertMilvusDenseVector(dense [][]float64) []entity.Vector {
return slices.Transform(dense, func(a []float64) entity.Vector {
r := make([]float32, len(a))
for i := 0; i < len(a); i++ {
r[i] = float32(a[i])
}
return entity.FloatVector(r)
})
}
func convertSparse(sparse []map[int]float64) ([]entity.SparseEmbedding, error) {
r := make([]entity.SparseEmbedding, 0, len(sparse))
for _, s := range sparse {
ks := make([]uint32, 0, len(s))
vs := make([]float32, 0, len(s))
for k, v := range s {
ks = append(ks, uint32(k))
vs = append(vs, float32(v))
}
se, err := entity.NewSliceSparseEmbedding(ks, vs)
if err != nil {
return nil, err
}
r = append(r, se)
}
return r, nil
}
func convertMilvusSparseVector(sparse []map[int]float64) ([]entity.Vector, error) {
r := make([]entity.Vector, 0, len(sparse))
for _, s := range sparse {
ks := make([]uint32, 0, len(s))
vs := make([]float32, 0, len(s))
for k, v := range s {
ks = append(ks, uint32(k))
vs = append(vs, float32(v))
}
se, err := entity.NewSliceSparseEmbedding(ks, vs)
if err != nil {
return nil, err
}
r = append(r, se)
}
return r, nil
}

View File

@@ -0,0 +1,334 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"context"
"fmt"
"strings"
mentity "github.com/milvus-io/milvus/client/v2/entity"
mindex "github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type ManagerConfig struct {
Client *client.Client // required
Embedding embedding.Embedder // required
EnableHybrid *bool // optional: default Embedding.SupportStatus() == embedding.SupportDenseAndSparse
DenseIndex mindex.Index // optional: default HNSW, M=30, efConstruction=360
DenseMetric mentity.MetricType // optional: default IP
SparseIndex mindex.Index // optional: default SPARSE_INVERTED_INDEX, drop_ratio=0.2
SparseMetric mentity.MetricType // optional: default IP
ShardNum int // optional: default 1
BatchSize int // optional: default 100
}
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
if config.Client == nil {
return nil, fmt.Errorf("[NewManager] milvus client not provided")
}
if config.Embedding == nil {
return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
}
enableSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
if config.EnableHybrid == nil {
config.EnableHybrid = ptr.Of(enableSparse)
} else if !enableSparse && ptr.From(config.EnableHybrid) {
logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
config.EnableHybrid = ptr.Of(false)
}
if config.DenseMetric == "" {
config.DenseMetric = mentity.IP
}
if config.DenseIndex == nil {
config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
}
if config.SparseMetric == "" {
config.SparseMetric = mentity.IP
}
if config.SparseIndex == nil {
config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
}
if config.ShardNum == 0 {
config.ShardNum = 1
}
if config.BatchSize == 0 {
config.BatchSize = 100
}
return &milvusManager{config: config}, nil
}
type milvusManager struct {
config *ManagerConfig
}
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
if err := m.createCollection(ctx, req); err != nil {
return fmt.Errorf("[Create] create collection failed, %w", err)
}
if err := m.createIndexes(ctx, req); err != nil {
return fmt.Errorf("[Create] create indexes failed, %w", err)
}
if exists, err := m.loadCollection(ctx, req.CollectionName); err != nil {
return fmt.Errorf("[Create] load collection failed, %w", err)
} else if !exists {
return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
}
return nil
}
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
}
func (m *milvusManager) GetType() searchstore.SearchStoreType {
return searchstore.TypeVectorStore
}
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
if exists, err := m.loadCollection(ctx, collectionName); err != nil {
return nil, err
} else if !exists {
return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
}
return &milvusSearchStore{
config: m.config,
collectionName: collectionName,
}, nil
}
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
if req.CollectionName == "" || len(req.Fields) == 0 {
return fmt.Errorf("[createCollection] invalid request params")
}
cli := m.config.Client
collectionName := req.CollectionName
has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
if err != nil {
return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
}
if has {
return nil
}
fields, err := m.convertFields(req.Fields)
if err != nil {
return err
}
opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
CollectionName: collectionName,
Description: fmt.Sprintf("created by coze"),
AutoID: false,
Fields: fields,
EnableDynamicField: false,
}).WithShardNum(int32(m.config.ShardNum))
for k, v := range req.CollectionMeta {
opt.WithProperty(k, v)
}
if err = cli.CreateCollection(ctx, opt); err != nil {
return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
}
return nil
}
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
collectionName := req.CollectionName
indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
if err != nil {
if !strings.Contains(err.Error(), "index not found") {
return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
}
}
createdIndexes := sets.FromSlice(indexes)
var ops []func() error
for i := range req.Fields {
f := req.Fields[i]
if !f.Indexing {
continue
}
ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
}
}
for _, op := range ops {
if err := op(); err != nil {
return fmt.Errorf("[createIndexes] failed, %w", err)
}
}
return nil
}
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
return func() error {
if _, found := createdIndexes[indexName]; found {
logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
collectionName, fieldName, indexName, idx.IndexType())
return nil
}
cli := m.config.Client
task, err := cli.CreateIndex(ctx, client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName))
if err != nil {
return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
}
if err = task.Await(ctx); err != nil {
return fmt.Errorf("[tryCreateIndex] await failed, %w", err)
}
logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
collectionName, fieldName, indexName, idx.IndexType())
return nil
}
}
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
cli := m.config.Client
stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
if err != nil {
return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
}
switch stat.State {
case mentity.LoadStateNotLoad:
task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
if err != nil {
return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
}
if err = task.Await(ctx); err != nil {
return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
}
return true, nil
case mentity.LoadStateLoaded:
return true, nil
case mentity.LoadStateLoading:
return true, fmt.Errorf("[loadCollection] collection is unloading, retry later, collection=%v", collectionName)
case mentity.LoadStateUnloading:
return false, nil
default:
return false, fmt.Errorf("[loadCollection] load state unexpected, state=%d", stat)
}
}
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
var foundID, foundCreatorID bool
resp := make([]*mentity.Field, 0, len(fields))
for _, f := range fields {
switch f.Name {
case searchstore.FieldID:
foundID = true
case searchstore.FieldCreatorID:
foundCreatorID = true
default:
}
if f.Indexing {
if f.Type != searchstore.FieldTypeText {
return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
}
// indexing 时只有 content 存储原文
if f.Name == searchstore.FieldTextContent {
resp = append(resp, mentity.NewField().
WithName(f.Name).
WithDescription(f.Description).
WithIsPrimaryKey(f.IsPrimary).
WithNullable(f.Nullable).
WithDataType(mentity.FieldTypeVarChar).
WithMaxLength(65535))
}
resp = append(resp, mentity.NewField().
WithName(denseFieldName(f.Name)).
WithDataType(mentity.FieldTypeFloatVector).
WithDim(m.config.Embedding.Dimensions()))
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
resp = append(resp, mentity.NewField().
WithName(sparseFieldName(f.Name)).
WithDataType(mentity.FieldTypeSparseVector))
}
} else {
mf := mentity.NewField().
WithName(f.Name).
WithDescription(f.Description).
WithIsPrimaryKey(f.IsPrimary).
WithNullable(f.Nullable)
typ, err := convertFieldType(f.Type)
if err != nil {
return nil, err
}
mf.WithDataType(typ)
if typ == mentity.FieldTypeVarChar {
mf.WithMaxLength(65535)
} else if typ == mentity.FieldTypeFloatVector {
mf.WithDim(m.config.Embedding.Dimensions())
}
resp = append(resp, mf)
}
}
if !foundID {
resp = append(resp, mentity.NewField().
WithName(searchstore.FieldID).
WithDataType(mentity.FieldTypeInt64).
WithIsPrimaryKey(true).
WithNullable(false))
}
if !foundCreatorID {
resp = append(resp, mentity.NewField().
WithName(searchstore.FieldCreatorID).
WithDataType(mentity.FieldTypeInt64).
WithNullable(false))
}
return resp, nil
}
func (m *milvusManager) GetEmbedding() embedding.Embedder {
return m.config.Embedding
}

View File

@@ -0,0 +1,600 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"context"
"encoding/json"
"fmt"
"math"
"reflect"
"sort"
"strconv"
"strings"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/column"
mentity "github.com/milvus-io/milvus/client/v2/entity"
mindex "github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/slongfield/pyfmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
type milvusSearchStore struct {
config *ManagerConfig
collectionName string
}
func (m *milvusSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
if len(docs) == 0 {
return nil, nil
}
implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
defer func() {
if err != nil {
if implSpecOptions.ProgressBar != nil {
implSpecOptions.ProgressBar.ReportError(err)
}
}
}()
indexingFields := make(sets.Set[string])
for _, field := range implSpecOptions.IndexingFields {
indexingFields[field] = struct{}{}
}
if implSpecOptions.Partition != nil {
partition := *implSpecOptions.Partition
hasPartition, err := m.config.Client.HasPartition(ctx, client.NewHasPartitionOption(m.collectionName, partition))
if err != nil {
return nil, fmt.Errorf("[Store] HasPartition failed, %w", err)
}
if !hasPartition {
if err = m.config.Client.CreatePartition(ctx, client.NewCreatePartitionOption(m.collectionName, partition)); err != nil {
return nil, fmt.Errorf("[Store] CreatePartition failed, %w", err)
}
}
}
for _, part := range slices.Chunks(docs, batchSize) {
columns, err := m.documents2Columns(ctx, part, indexingFields)
if err != nil {
return nil, err
}
createReq := client.NewColumnBasedInsertOption(m.collectionName, columns...)
if implSpecOptions.Partition != nil {
createReq.WithPartition(*implSpecOptions.Partition)
}
result, err := m.config.Client.Upsert(ctx, createReq)
if err != nil {
return nil, fmt.Errorf("[Store] upsert failed, %w", err)
}
partIDs := result.IDs
for i := 0; i < partIDs.Len(); i++ {
var sid string
if partIDs.Type() == mentity.FieldTypeInt64 {
id, err := partIDs.GetAsInt64(i)
if err != nil {
return nil, err
}
sid = strconv.FormatInt(id, 10)
} else {
sid, err = partIDs.GetAsString(i)
if err != nil {
return nil, err
}
}
ids = append(ids, sid)
}
if implSpecOptions.ProgressBar != nil {
if err = implSpecOptions.ProgressBar.AddN(len(part)); err != nil {
return nil, err
}
}
}
return ids, nil
}
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
cli := m.config.Client
emb := m.config.Embedding
options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
if err != nil {
return nil, err
}
var (
dense [][]float64
sparse []map[int]float64
expr string
result []client.ResultSet
fields = desc.Schema.Fields
outputFields []string
enableSparse = m.enableSparse(fields)
)
if options.DSLInfo != nil {
expr, err = m.dsl2Expr(options.DSLInfo)
if err != nil {
return nil, err
}
}
if enableSparse {
dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query})
if err != nil {
return nil, fmt.Errorf("[Retrieve] EmbedStringsHybrid failed, %w", err)
}
} else {
dense, err = emb.EmbedStrings(ctx, []string{query})
if err != nil {
return nil, fmt.Errorf("[Retrieve] EmbedStrings failed, %w", err)
}
}
dv := convertMilvusDenseVector(dense)
sv, err := convertMilvusSparseVector(sparse)
if err != nil {
return nil, err
}
for _, field := range fields {
outputFields = append(outputFields, field.Name)
}
var scoreNormType *mindex.MetricType
if enableSparse {
var annRequests []*client.AnnRequest
for _, field := range fields {
var (
vector []mentity.Vector
metricsType mindex.MetricType
)
if field.DataType == mentity.FieldTypeFloatVector {
vector = dv
metricsType, err = m.getIndexMetricsType(ctx, denseIndexName(field.Name))
} else if field.DataType == mentity.FieldTypeSparseVector {
vector = sv
metricsType, err = m.getIndexMetricsType(ctx, sparseIndexName(field.Name))
}
if err != nil {
return nil, err
}
annRequests = append(annRequests,
client.NewAnnRequest(field.Name, ptr.From(options.TopK), vector...).
WithSearchParam(mindex.MetricTypeKey, string(metricsType)).
WithFilter(expr),
)
}
searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
WithPartitons(implSpecOptions.Partitions...).
WithReranker(client.NewRRFReranker()).
WithOutputFields(outputFields...)
result, err = cli.HybridSearch(ctx, searchOption)
if err != nil {
return nil, fmt.Errorf("[Retrieve] HybridSearch failed, %w", err)
}
} else {
indexes, err := cli.ListIndexes(ctx, client.NewListIndexOption(m.collectionName))
if err != nil {
return nil, fmt.Errorf("[Retrieve] ListIndexes failed, %w", err)
}
if len(indexes) != 1 {
return nil, fmt.Errorf("[Retrieve] restrict single index ann search, but got %d, collection=%s",
len(indexes), m.collectionName)
}
metricsType, err := m.getIndexMetricsType(ctx, indexes[0])
if err != nil {
return nil, err
}
scoreNormType = &metricsType
searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
WithPartitions(implSpecOptions.Partitions...).
WithFilter(expr).
WithOutputFields(outputFields...).
WithSearchParam(mindex.MetricTypeKey, string(metricsType))
result, err = cli.Search(ctx, searchOption)
if err != nil {
return nil, fmt.Errorf("[Retrieve] Search failed, %w", err)
}
}
docs, err := m.resultSet2Document(result, scoreNormType)
if err != nil {
return nil, fmt.Errorf("[Retrieve] resultSet2Document failed, %w", err)
}
return docs, nil
}
func (m *milvusSearchStore) Delete(ctx context.Context, ids []string) error {
int64IDs := make([]int64, 0, len(ids))
for _, sid := range ids {
id, err := strconv.ParseInt(sid, 10, 64)
if err != nil {
return err
}
int64IDs = append(int64IDs, id)
}
_, err := m.config.Client.Delete(ctx,
client.NewDeleteOption(m.collectionName).WithInt64IDs(searchstore.FieldID, int64IDs))
return err
}
func (m *milvusSearchStore) documents2Columns(ctx context.Context, docs []*schema.Document, indexingFields sets.Set[string]) (
cols []column.Column, err error) {
var (
ids []int64
contents []string
creatorIDs []int64
emptyContents = true
)
colMapping := map[string]any{}
colTypeMapping := map[string]searchstore.FieldType{
searchstore.FieldID: searchstore.FieldTypeInt64,
searchstore.FieldCreatorID: searchstore.FieldTypeInt64,
searchstore.FieldTextContent: searchstore.FieldTypeText,
}
for _, doc := range docs {
if doc.MetaData == nil {
return nil, fmt.Errorf("[documents2Columns] meta data is nil")
}
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
return nil, fmt.Errorf("[documents2Columns] parse id failed, %w", err)
}
ids = append(ids, id)
contents = append(contents, doc.Content)
if doc.Content != "" {
emptyContents = false
}
creatorID, err := document.GetDocumentCreatorID(doc)
if err != nil {
return nil, fmt.Errorf("[documents2Columns] creator_id not found or type invalid., %w", err)
}
creatorIDs = append(creatorIDs, creatorID)
ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any)
if !ok {
continue
}
for field := range ext {
val := ext[field]
container := colMapping[field]
switch t := val.(type) {
case uint, uint8, uint16, uint32, uint64, uintptr:
var c []int64
if container == nil {
colTypeMapping[field] = searchstore.FieldTypeInt64
} else {
c, ok = container.([]int64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, int64(reflect.ValueOf(t).Uint()))
colMapping[field] = c
case int, int8, int16, int32, int64:
var c []int64
if container == nil {
colTypeMapping[field] = searchstore.FieldTypeInt64
} else {
c, ok = container.([]int64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, reflect.ValueOf(t).Int())
colMapping[field] = c
case string:
var c []string
if container == nil {
colTypeMapping[field] = searchstore.FieldTypeText
} else {
c, ok = container.([]string)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, t)
colMapping[field] = c
case []float64:
var c [][]float64
if container == nil {
container = c
colTypeMapping[field] = searchstore.FieldTypeDenseVector
} else {
c, ok = container.([][]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, t)
colMapping[field] = c
case map[int]float64:
var c []map[int]float64
if container == nil {
container = c
colTypeMapping[field] = searchstore.FieldTypeSparseVector
} else {
c, ok = container.([]map[int]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, t)
colMapping[field] = c
default:
return nil, fmt.Errorf("[documents2Columns] val type not support, val=%v", val)
}
}
}
colMapping[searchstore.FieldID] = ids
colMapping[searchstore.FieldCreatorID] = creatorIDs
colMapping[searchstore.FieldTextContent] = contents
for fieldName, container := range colMapping {
colType := colTypeMapping[fieldName]
switch colType {
case searchstore.FieldTypeInt64:
c, ok := container.([]int64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
cols = append(cols, column.NewColumnInt64(fieldName, c))
case searchstore.FieldTypeText:
c, ok := container.([]string)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not string")
}
if _, indexing := indexingFields[fieldName]; indexing {
if fieldName == searchstore.FieldTextContent && !emptyContents {
cols = append(cols, column.NewColumnVarChar(fieldName, c))
}
var (
emb = m.config.Embedding
dense [][]float64
sparse []map[int]float64
)
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
dense, sparse, err = emb.EmbedStringsHybrid(ctx, c)
} else {
dense, err = emb.EmbedStrings(ctx, c)
}
if err != nil {
return nil, fmt.Errorf("[slices2Columns] embed failed, %w", err)
}
cols = append(cols, column.NewColumnFloatVector(denseFieldName(fieldName), int(emb.Dimensions()), convertDense(dense)))
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
s, err := convertSparse(sparse)
if err != nil {
return nil, err
}
cols = append(cols, column.NewColumnSparseVectors(sparseFieldName(fieldName), s))
}
} else {
cols = append(cols, column.NewColumnVarChar(fieldName, c))
}
case searchstore.FieldTypeDenseVector:
c, ok := container.([][]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not []float64")
}
cols = append(cols, column.NewColumnFloatVector(fieldName, int(m.config.Embedding.Dimensions()), convertDense(c)))
case searchstore.FieldTypeSparseVector:
c, ok := container.([]map[int]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not map[int]float64")
}
sparse, err := convertSparse(c)
if err != nil {
return nil, err
}
cols = append(cols, column.NewColumnSparseVectors(fieldName, sparse))
default:
return nil, fmt.Errorf("[documents2Columns] column type not support, type=%d", colType)
}
}
return cols, nil
}
func (m *milvusSearchStore) resultSet2Document(result []client.ResultSet, metricsType *mindex.MetricType) (docs []*schema.Document, err error) {
docs = make([]*schema.Document, 0, len(result))
minScore := math.MaxFloat64
maxScore := 0.0
for _, r := range result {
for i := 0; i < r.ResultCount; i++ {
ext := make(map[string]any)
doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
score := float64(r.Scores[i])
minScore = min(minScore, score)
maxScore = max(maxScore, score)
doc.WithScore(score)
for _, field := range r.Fields {
switch field.Name() {
case searchstore.FieldID:
id, err := field.GetAsInt64(i)
if err != nil {
return nil, err
}
doc.ID = strconv.FormatInt(id, 10)
case searchstore.FieldTextContent:
doc.Content, err = field.GetAsString(i)
case searchstore.FieldCreatorID:
doc.MetaData[document.MetaDataKeyCreatorID], err = field.GetAsInt64(i)
default:
ext[field.Name()], err = field.Get(i)
}
if err != nil {
return nil, err
}
}
docs = append(docs, doc)
}
}
sort.Slice(docs, func(i, j int) bool {
return docs[i].Score() > docs[j].Score()
})
// norm score
if (m.config.EnableHybrid != nil && *m.config.EnableHybrid) || metricsType == nil {
return docs, nil
}
switch *metricsType {
case mentity.L2:
base := maxScore - minScore
for i := range docs {
if base == 0 {
docs[i].WithScore(1.0)
} else {
docs[i].WithScore(1.0 - (docs[i].Score()-minScore)/base)
}
}
docs = slices.Reverse(docs)
case mentity.IP, mentity.COSINE:
for i := range docs {
docs[i].WithScore((docs[i].Score() + 1) / 2)
}
default:
}
return docs, nil
}
func (m *milvusSearchStore) enableSparse(fields []*mentity.Field) bool {
found := false
for _, field := range fields {
if field.DataType == mentity.FieldTypeSparseVector {
found = true
break
}
}
return found && *m.config.EnableHybrid && m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
}
func (m *milvusSearchStore) dsl2Expr(src map[string]interface{}) (string, error) {
if src == nil {
return "", nil
}
dsl, err := searchstore.LoadDSL(src)
if err != nil {
return "", err
}
var travDSL func(dsl *searchstore.DSL) (string, error)
travDSL = func(dsl *searchstore.DSL) (string, error) {
kv := map[string]interface{}{
"field": dsl.Field,
"val": dsl.Value,
}
switch dsl.Op {
case searchstore.OpEq:
return pyfmt.Fmt("{field} == {val}", kv)
case searchstore.OpNe:
return pyfmt.Fmt("{field} != {val}", kv)
case searchstore.OpLike:
return pyfmt.Fmt("{field} LIKE {val}", kv)
case searchstore.OpIn:
b, err := json.Marshal(dsl.Value)
if err != nil {
return "", err
}
kv["val"] = string(b)
return pyfmt.Fmt("{field} IN {val}", kv)
case searchstore.OpAnd, searchstore.OpOr:
sub, ok := dsl.Value.([]*searchstore.DSL)
if !ok {
return "", fmt.Errorf("[dsl2Expr] invalid sub dsl")
}
var items []string
for _, s := range sub {
str, err := travDSL(s)
if err != nil {
return "", fmt.Errorf("[dsl2Expr] parse sub failed, %w", err)
}
items = append(items, str)
}
if dsl.Op == searchstore.OpAnd {
return strings.Join(items, " AND "), nil
} else {
return strings.Join(items, " OR "), nil
}
default:
return "", fmt.Errorf("[dsl2Expr] unknown op type=%s", dsl.Op)
}
}
return travDSL(dsl)
}
func (m *milvusSearchStore) getIndexMetricsType(ctx context.Context, indexName string) (mindex.MetricType, error) {
index, err := m.config.Client.DescribeIndex(ctx, client.NewDescribeIndexOption(m.collectionName, indexName))
if err != nil {
return "", fmt.Errorf("[getIndexMetricsType] describe index failed, collection=%s, index=%s, %w",
m.collectionName, indexName, err)
}
typ, found := index.Params()[mindex.MetricTypeKey]
if !found { // unexpected
return "", fmt.Errorf("[getIndexMetricsType] invalid index params, collection=%s, index=%s", m.collectionName, indexName)
}
return mindex.MetricType(typ), nil
}