feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package elasticsearch
|
||||
|
||||
const (
	// topK is the default number of hits returned by Retrieve when the
	// caller does not supply an explicit TopK retriever option.
	topK = 10
)
|
||||
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package elasticsearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
)
|
||||
|
||||
// ManagerConfig carries the dependencies required to build an
// Elasticsearch-backed searchstore.Manager.
type ManagerConfig struct {
	// Client is the ES client used for all index and document operations.
	Client es.Client
}
|
||||
|
||||
func NewManager(config *ManagerConfig) searchstore.Manager {
|
||||
return &esManager{config: config}
|
||||
}
|
||||
|
||||
// esManager implements searchstore.Manager on top of an Elasticsearch
// cluster; one index per collection.
type esManager struct {
	config *ManagerConfig
}
|
||||
|
||||
func (e *esManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
cli := e.config.Client
|
||||
index := req.CollectionName
|
||||
indexExists, err := cli.Exists(ctx, index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if indexExists { // exists
|
||||
return nil
|
||||
}
|
||||
|
||||
properties := make(map[string]any)
|
||||
var foundID, foundCreatorID, foundTextContent bool
|
||||
for _, field := range req.Fields {
|
||||
switch field.Name {
|
||||
case searchstore.FieldID:
|
||||
foundID = true
|
||||
case searchstore.FieldCreatorID:
|
||||
foundCreatorID = true
|
||||
case searchstore.FieldTextContent:
|
||||
foundTextContent = true
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
var property any
|
||||
switch field.Type {
|
||||
case searchstore.FieldTypeInt64:
|
||||
property = cli.Types().NewLongNumberProperty()
|
||||
case searchstore.FieldTypeText:
|
||||
property = cli.Types().NewTextProperty()
|
||||
default:
|
||||
return fmt.Errorf("[Create] es unsupported field type: %d", field.Type)
|
||||
}
|
||||
|
||||
properties[field.Name] = property
|
||||
}
|
||||
|
||||
if !foundID {
|
||||
properties[searchstore.FieldID] = cli.Types().NewLongNumberProperty()
|
||||
}
|
||||
if !foundCreatorID {
|
||||
properties[searchstore.FieldCreatorID] = cli.Types().NewUnsignedLongNumberProperty()
|
||||
}
|
||||
if !foundTextContent {
|
||||
properties[searchstore.FieldTextContent] = cli.Types().NewTextProperty()
|
||||
}
|
||||
|
||||
err = cli.CreateIndex(ctx, index, properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *esManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
|
||||
cli := e.config.Client
|
||||
index := req.CollectionName
|
||||
|
||||
return cli.DeleteIndex(ctx, index)
|
||||
}
|
||||
|
||||
// GetType identifies this manager as a full-text (non-vector) store.
func (e *esManager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeTextStore
}
|
||||
|
||||
func (e *esManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
|
||||
return &esSearchStore{
|
||||
config: e.config,
|
||||
indexName: collectionName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetEmbedding returns nil: the ES text store performs no embedding.
func (e *esManager) GetEmbedding() embedding.Embedder {
	return nil
}
|
||||
@@ -0,0 +1,329 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package elasticsearch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// esSearchStore implements searchstore.SearchStore over a single ES index.
type esSearchStore struct {
	config *ManagerConfig
	// indexName is the ES index backing the collection.
	indexName string
}
|
||||
|
||||
// Store bulk-indexes the given documents into the ES index and returns the
// ids of the submitted documents. Progress is reported through the optional
// ProgressBar indexer option; any error is also forwarded to the progress
// bar before returning (via the deferred hook on the named return err).
func (e *esSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	// Report failure to the progress bar exactly once, on exit.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()
	cli := e.config.Client
	index := e.indexName
	bi, err := cli.NewBulkIndexer(index)
	if err != nil {
		return nil, err
	}
	ids = make([]string, 0, len(docs))
	for _, doc := range docs {
		// Flatten the document (content, creator id, external-storage
		// fields) into the map that becomes the ES source body.
		fieldMapping, err := e.fromDocument(doc)
		if err != nil {
			return nil, err
		}
		body, err := json.Marshal(fieldMapping)
		if err != nil {
			return nil, err
		}

		if err = bi.Add(ctx, es.BulkIndexerItem{
			Index:      e.indexName,
			Action:     "index",
			DocumentID: doc.ID,
			Body:       bytes.NewReader(body),
		}); err != nil {
			return nil, err
		}
		ids = append(ids, doc.ID)
		if implSpecOptions.ProgressBar != nil {
			if err = implSpecOptions.ProgressBar.AddN(1); err != nil {
				return nil, err
			}
		}
	}

	// Close flushes remaining buffered items; only now are all writes
	// guaranteed to have been submitted.
	if err = bi.Close(ctx); err != nil {
		return nil, err
	}

	return ids, nil
}
|
||||
|
||||
// Retrieve runs a full-text query against the index and returns matching
// documents with normalized scores. The query matches the builtin
// text-content field unless a MultiMatch option selects explicit fields;
// any DSL filter carried in the retriever options is translated into the
// bool query. TopK defaults to the package-level topK constant.
func (e *esSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	var (
		cli   = e.config.Client
		index = e.indexName

		options         = retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
		implSpecOptions = retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
		req             = &es.Request{
			Query: &es.Query{
				Bool: &es.BoolQuery{},
			},
			Size: options.TopK,
		}
	)

	// Default: match the query text against the builtin content field.
	if implSpecOptions.MultiMatch == nil {
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMatchQuery(searchstore.FieldTextContent, query))
	} else {
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMultiMatchQuery(implSpecOptions.MultiMatch.Fields, query,
				"best_fields", es.Or))
	}

	// Translate the caller-supplied DSL filter (if any) into bool clauses.
	dsl, err := searchstore.LoadDSL(options.DSLInfo)
	if err != nil {
		return nil, err
	}

	if err = e.travDSL(req.Query, dsl); err != nil {
		return nil, err
	}

	if options.ScoreThreshold != nil {
		req.MinScore = options.ScoreThreshold
	}

	resp, err := cli.Search(ctx, index, req)
	if err != nil {
		return nil, err
	}

	docs, err := e.parseSearchResult(resp)
	if err != nil {
		return nil, err
	}

	return docs, nil
}
|
||||
|
||||
func (e *esSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
bi, err := e.config.Client.NewBulkIndexer(e.indexName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
if err = bi.Add(ctx, es.BulkIndexerItem{
|
||||
Index: e.indexName,
|
||||
Action: "delete",
|
||||
DocumentID: id,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return bi.Close(ctx)
|
||||
}
|
||||
|
||||
func (e *esSearchStore) travDSL(query *es.Query, dsl *searchstore.DSL) error {
|
||||
if dsl == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch dsl.Op {
|
||||
case searchstore.OpEq, searchstore.OpNe:
|
||||
arr := stringifyValue(dsl.Value)
|
||||
v := dsl.Value
|
||||
if len(arr) > 0 {
|
||||
v = arr[0]
|
||||
}
|
||||
|
||||
if dsl.Op == searchstore.OpEq {
|
||||
query.Bool.Must = append(query.Bool.Must,
|
||||
es.NewEqualQuery(dsl.Field, v))
|
||||
} else {
|
||||
query.Bool.MustNot = append(query.Bool.MustNot,
|
||||
es.NewEqualQuery(dsl.Field, v))
|
||||
}
|
||||
case searchstore.OpLike:
|
||||
s, ok := dsl.Value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("[travDSL] OpLike value should be string, but got %v", dsl.Value)
|
||||
}
|
||||
query.Bool.Must = append(query.Bool.Must, es.NewMatchQuery(dsl.Field, s))
|
||||
|
||||
case searchstore.OpIn:
|
||||
query.Bool.Must = append(query.Bool.MustNot,
|
||||
es.NewInQuery(dsl.Field, stringifyValue(dsl.Value)))
|
||||
|
||||
case searchstore.OpAnd, searchstore.OpOr:
|
||||
conds, ok := dsl.Value.([]*searchstore.DSL)
|
||||
if !ok {
|
||||
return fmt.Errorf("[travDSL] value type assertion failed for or")
|
||||
}
|
||||
|
||||
for _, cond := range conds {
|
||||
sub := &es.Query{}
|
||||
if err := e.travDSL(sub, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
if dsl.Op == searchstore.OpOr {
|
||||
query.Bool.Should = append(query.Bool.Should, *sub)
|
||||
} else {
|
||||
query.Bool.Must = append(query.Bool.Must, *sub)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("[trav] unknown op %s", dsl.Op)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *esSearchStore) parseSearchResult(resp *es.Response) (docs []*schema.Document, err error) {
|
||||
docs = make([]*schema.Document, 0, len(resp.Hits.Hits))
|
||||
firstScore := 0.0
|
||||
for i, hit := range resp.Hits.Hits {
|
||||
var src map[string]any
|
||||
d := json.NewDecoder(bytes.NewReader(hit.Source_))
|
||||
d.UseNumber()
|
||||
if err = d.Decode(&src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ext := make(map[string]any)
|
||||
doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
|
||||
|
||||
for field, val := range src {
|
||||
ok := true
|
||||
switch field {
|
||||
case searchstore.FieldTextContent:
|
||||
doc.Content, ok = val.(string)
|
||||
case searchstore.FieldCreatorID:
|
||||
var jn json.Number
|
||||
jn, ok = val.(json.Number)
|
||||
if ok {
|
||||
doc.MetaData[document.MetaDataKeyCreatorID], ok = assertJSONNumber(jn).(int64)
|
||||
}
|
||||
default:
|
||||
if jn, jok := val.(json.Number); jok {
|
||||
ext[field] = assertJSONNumber(jn)
|
||||
} else {
|
||||
ext[field] = val
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[parseSearchResult] type assertion failed, field=%s, val=%v", field, val)
|
||||
}
|
||||
}
|
||||
if hit.Id_ != nil {
|
||||
doc.ID = *hit.Id_
|
||||
}
|
||||
if hit.Score_ == nil { // unexpected
|
||||
return nil, fmt.Errorf("[parseSearchResult] es retrieve score not found")
|
||||
}
|
||||
score := float64(ptr.From(hit.Score_))
|
||||
if i == 0 {
|
||||
firstScore = score
|
||||
}
|
||||
doc.WithScore(score / firstScore)
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (e *esSearchStore) fromDocument(doc *schema.Document) (map[string]any, error) {
|
||||
if doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("[fromDocument] es document meta data is nil")
|
||||
}
|
||||
|
||||
creatorID, ok := doc.MetaData[searchstore.FieldCreatorID].(int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[fromDocument] creator id not found or type invalid")
|
||||
}
|
||||
|
||||
fieldMapping := map[string]any{
|
||||
searchstore.FieldTextContent: doc.Content,
|
||||
searchstore.FieldCreatorID: creatorID,
|
||||
}
|
||||
|
||||
if ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any); ok {
|
||||
for k, v := range ext {
|
||||
fieldMapping[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return fieldMapping, nil
|
||||
}
|
||||
|
||||
// stringifyValue normalizes a DSL operand into a []any for ES queries.
// Slices/arrays are converted element-wise: numeric elements become their
// decimal string form, strings pass through, and anything else keeps its
// underlying Go value. A scalar operand becomes a single-element slice of
// its fmt "%v" rendering.
func stringifyValue(dslValue any) []any {
	value := reflect.ValueOf(dslValue)
	switch value.Kind() {
	case reflect.Slice, reflect.Array:
		length := value.Len()
		slice := make([]any, 0, length)
		for i := 0; i < length; i++ {
			elem := value.Index(i)
			// BUG FIX: elements of []any have Kind() == Interface and fell
			// into the default branch; unwrap so the concrete kind is
			// inspected below.
			if elem.Kind() == reflect.Interface {
				elem = elem.Elem()
			}
			switch elem.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				slice = append(slice, strconv.FormatInt(elem.Int(), 10))
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				slice = append(slice, strconv.FormatUint(elem.Uint(), 10))
			case reflect.Float32, reflect.Float64:
				slice = append(slice, strconv.FormatFloat(elem.Float(), 'f', -1, 64))
			case reflect.String:
				slice = append(slice, elem.String())
			case reflect.Invalid:
				// nil element of an interface-typed slice.
				slice = append(slice, nil)
			default:
				// BUG FIX: the original appended the reflect.Value wrapper
				// itself, leaking a reflect.Value into the query instead of
				// the element's underlying value.
				slice = append(slice, elem.Interface())
			}
		}
		return slice
	default:
		return []any{fmt.Sprintf("%v", value)}
	}
}
|
||||
|
||||
func assertJSONNumber(f json.Number) any {
|
||||
if i64, err := f.Int64(); err == nil {
|
||||
return i64
|
||||
}
|
||||
if f64, err := f.Float64(); err == nil {
|
||||
return f64
|
||||
}
|
||||
return f.String()
|
||||
}
|
||||
22
backend/infra/impl/document/searchstore/milvus/consts.go
Normal file
22
backend/infra/impl/document/searchstore/milvus/consts.go
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
const (
	// batchSize is the default number of rows per milvus write batch.
	batchSize = 100
	// topK is the default number of results returned by retrieval when the
	// caller does not supply an explicit TopK option.
	topK = 4
)
|
||||
119
backend/infra/impl/document/searchstore/milvus/convert.go
Normal file
119
backend/infra/impl/document/searchstore/milvus/convert.go
Normal file
@@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
// denseFieldName derives the collection field name that stores the dense
// vector for the given logical field.
func denseFieldName(name string) string {
	prefixed := fmt.Sprintf("dense_%s", name)
	return prefixed
}
|
||||
|
||||
// denseIndexName derives the index name for a field's dense vector column.
func denseIndexName(name string) string {
	indexName := fmt.Sprintf("index_dense_%s", name)
	return indexName
}
|
||||
|
||||
// sparseFieldName derives the collection field name that stores the sparse
// vector for the given logical field.
func sparseFieldName(name string) string {
	prefixed := fmt.Sprintf("sparse_%s", name)
	return prefixed
}
|
||||
|
||||
// sparseIndexName derives the index name for a field's sparse vector column.
func sparseIndexName(name string) string {
	indexName := fmt.Sprintf("index_sparse_%s", name)
	return indexName
}
|
||||
|
||||
// convertFieldType maps a searchstore field type onto the corresponding
// milvus entity.FieldType. Text maps to VarChar; unknown types return
// FieldTypeNone with an error.
func convertFieldType(typ searchstore.FieldType) (entity.FieldType, error) {
	switch typ {
	case searchstore.FieldTypeInt64:
		return entity.FieldTypeInt64, nil
	case searchstore.FieldTypeText:
		return entity.FieldTypeVarChar, nil
	case searchstore.FieldTypeDenseVector:
		return entity.FieldTypeFloatVector, nil
	case searchstore.FieldTypeSparseVector:
		return entity.FieldTypeSparseVector, nil
	default:
		return entity.FieldTypeNone, fmt.Errorf("[convertFieldType] unknown field type: %v", typ)
	}
}
|
||||
|
||||
func convertDense(dense [][]float64) [][]float32 {
|
||||
return slices.Transform(dense, func(a []float64) []float32 {
|
||||
r := make([]float32, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
r[i] = float32(a[i])
|
||||
}
|
||||
return r
|
||||
})
|
||||
}
|
||||
|
||||
func convertMilvusDenseVector(dense [][]float64) []entity.Vector {
|
||||
return slices.Transform(dense, func(a []float64) entity.Vector {
|
||||
r := make([]float32, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
r[i] = float32(a[i])
|
||||
}
|
||||
return entity.FloatVector(r)
|
||||
})
|
||||
}
|
||||
|
||||
func convertSparse(sparse []map[int]float64) ([]entity.SparseEmbedding, error) {
|
||||
r := make([]entity.SparseEmbedding, 0, len(sparse))
|
||||
for _, s := range sparse {
|
||||
ks := make([]uint32, 0, len(s))
|
||||
vs := make([]float32, 0, len(s))
|
||||
for k, v := range s {
|
||||
ks = append(ks, uint32(k))
|
||||
vs = append(vs, float32(v))
|
||||
}
|
||||
|
||||
se, err := entity.NewSliceSparseEmbedding(ks, vs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r = append(r, se)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func convertMilvusSparseVector(sparse []map[int]float64) ([]entity.Vector, error) {
|
||||
r := make([]entity.Vector, 0, len(sparse))
|
||||
for _, s := range sparse {
|
||||
ks := make([]uint32, 0, len(s))
|
||||
vs := make([]float32, 0, len(s))
|
||||
for k, v := range s {
|
||||
ks = append(ks, uint32(k))
|
||||
vs = append(vs, float32(v))
|
||||
}
|
||||
|
||||
se, err := entity.NewSliceSparseEmbedding(ks, vs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r = append(r, se)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
334
backend/infra/impl/document/searchstore/milvus/milvus_manager.go
Normal file
334
backend/infra/impl/document/searchstore/milvus/milvus_manager.go
Normal file
@@ -0,0 +1,334 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
mentity "github.com/milvus-io/milvus/client/v2/entity"
|
||||
mindex "github.com/milvus-io/milvus/client/v2/index"
|
||||
client "github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// ManagerConfig configures a milvus-backed searchstore.Manager. Client and
// Embedding are required; everything else has the documented default
// applied by NewManager.
type ManagerConfig struct {
	Client    *client.Client     // required
	Embedding embedding.Embedder // required

	EnableHybrid *bool              // optional: default Embedding.SupportStatus() == embedding.SupportDenseAndSparse
	DenseIndex   mindex.Index       // optional: default HNSW, M=30, efConstruction=360
	DenseMetric  mentity.MetricType // optional: default IP
	SparseIndex  mindex.Index       // optional: default SPARSE_INVERTED_INDEX, drop_ratio=0.2
	SparseMetric mentity.MetricType // optional: default IP
	ShardNum     int                // optional: default 1
	BatchSize    int                // optional: default 100
}
|
||||
|
||||
// NewManager validates the required dependencies and fills in defaults on
// config (the struct is mutated in place), then returns a milvus-backed
// searchstore.Manager. Hybrid search is force-disabled when the embedder
// cannot produce sparse vectors, even if the caller requested it.
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
	if config.Client == nil {
		return nil, fmt.Errorf("[NewManager] milvus client not provided")
	}
	if config.Embedding == nil {
		return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
	}

	// Hybrid (dense + sparse) search requires a sparse-capable embedder.
	enableSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
	if config.EnableHybrid == nil {
		config.EnableHybrid = ptr.Of(enableSparse)
	} else if !enableSparse && ptr.From(config.EnableHybrid) {
		logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
		config.EnableHybrid = ptr.Of(false)
	}
	// Defaults: metrics must be set before the index defaults that use them.
	if config.DenseMetric == "" {
		config.DenseMetric = mentity.IP
	}
	if config.DenseIndex == nil {
		config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
	}
	if config.SparseMetric == "" {
		config.SparseMetric = mentity.IP
	}
	if config.SparseIndex == nil {
		config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
	}
	if config.ShardNum == 0 {
		config.ShardNum = 1
	}
	if config.BatchSize == 0 {
		config.BatchSize = 100
	}

	return &milvusManager{config: config}, nil
}
|
||||
|
||||
// milvusManager implements searchstore.Manager on top of a milvus cluster;
// one collection per searchstore collection.
type milvusManager struct {
	config *ManagerConfig
}
|
||||
|
||||
// Create provisions the milvus collection described by req: creates the
// collection (no-op if it already exists), builds the vector indexes, and
// loads the collection into memory so it becomes searchable.
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
	if err := m.createCollection(ctx, req); err != nil {
		return fmt.Errorf("[Create] create collection failed, %w", err)
	}

	if err := m.createIndexes(ctx, req); err != nil {
		return fmt.Errorf("[Create] create indexes failed, %w", err)
	}

	// loadCollection reports exists=false for collections it cannot use.
	if exists, err := m.loadCollection(ctx, req.CollectionName); err != nil {
		return fmt.Errorf("[Create] load collection failed, %w", err)
	} else if !exists {
		return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
	}

	return nil
}
|
||||
|
||||
// Drop deletes the milvus collection backing the given searchstore
// collection.
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
	return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
}
|
||||
|
||||
// GetType identifies this manager as a vector store.
func (m *milvusManager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
|
||||
|
||||
// GetSearchStore ensures the collection is loaded into memory and returns
// a SearchStore bound to it. A missing collection surfaces as a
// non-retryable knowledge error.
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
	if exists, err := m.loadCollection(ctx, collectionName); err != nil {
		return nil, err
	} else if !exists {
		return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
			errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
	}

	return &milvusSearchStore{
		config:         m.config,
		collectionName: collectionName,
	}, nil
}
|
||||
|
||||
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if req.CollectionName == "" || len(req.Fields) == 0 {
|
||||
return fmt.Errorf("[createCollection] invalid request params")
|
||||
}
|
||||
|
||||
cli := m.config.Client
|
||||
collectionName := req.CollectionName
|
||||
has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
|
||||
}
|
||||
if has {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields, err := m.convertFields(req.Fields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
|
||||
CollectionName: collectionName,
|
||||
Description: fmt.Sprintf("created by coze"),
|
||||
AutoID: false,
|
||||
Fields: fields,
|
||||
EnableDynamicField: false,
|
||||
}).WithShardNum(int32(m.config.ShardNum))
|
||||
|
||||
for k, v := range req.CollectionMeta {
|
||||
opt.WithProperty(k, v)
|
||||
}
|
||||
|
||||
if err = cli.CreateCollection(ctx, opt); err != nil {
|
||||
return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createIndexes ensures a dense-vector index (plus a sparse-vector index
// when the embedder supports sparse output) exists for every field marked
// Indexing. Indexes that already exist are skipped.
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
	collectionName := req.CollectionName
	indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
	if err != nil {
		// "index not found" just means no indexes exist yet; anything else
		// is a real failure.
		if !strings.Contains(err.Error(), "index not found") {
			return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
		}
	}

	createdIndexes := sets.FromSlice(indexes)

	// Collect the creation closures first, then run them sequentially.
	var ops []func() error
	for i := range req.Fields {
		f := req.Fields[i]
		if !f.Indexing {
			continue
		}

		ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
		if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
			ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
		}
	}

	for _, op := range ops {
		if err := op(); err != nil {
			return fmt.Errorf("[createIndexes] failed, %w", err)
		}
	}

	return nil
}
|
||||
|
||||
// tryCreateIndex returns a closure that creates the named index on the
// given field unless it is already present in createdIndexes, waiting for
// the creation task to complete before returning.
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
	return func() error {
		if _, found := createdIndexes[indexName]; found {
			// NOTE: the "idx=%v" placeholder receives the index name, not
			// the index object.
			logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
				collectionName, fieldName, indexName, idx.IndexType())

			return nil
		}

		cli := m.config.Client

		task, err := cli.CreateIndex(ctx, client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName))
		if err != nil {
			return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
		}

		// Index creation is asynchronous; block until it finishes.
		if err = task.Await(ctx); err != nil {
			return fmt.Errorf("[tryCreateIndex] await failed, %w", err)
		}

		logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
			collectionName, fieldName, indexName, idx.IndexType())

		return nil
	}
}
|
||||
|
||||
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
|
||||
cli := m.config.Client
|
||||
|
||||
stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
|
||||
}
|
||||
|
||||
switch stat.State {
|
||||
case mentity.LoadStateNotLoad:
|
||||
task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
|
||||
}
|
||||
if err = task.Await(ctx); err != nil {
|
||||
return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
|
||||
}
|
||||
return true, nil
|
||||
case mentity.LoadStateLoaded:
|
||||
return true, nil
|
||||
case mentity.LoadStateLoading:
|
||||
return true, fmt.Errorf("[loadCollection] collection is unloading, retry later, collection=%v", collectionName)
|
||||
case mentity.LoadStateUnloading:
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("[loadCollection] load state unexpected, state=%d", stat)
|
||||
}
|
||||
}
|
||||
|
||||
// convertFields maps searchstore field definitions onto the milvus schema.
// Indexed (text-only) fields are expanded into vector columns — dense
// always, sparse when the embedder supports it — and builtin id /
// creator-id fields are appended with defaults when the caller did not
// supply them.
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
	var foundID, foundCreatorID bool
	resp := make([]*mentity.Field, 0, len(fields))
	for _, f := range fields {
		// Track which builtin fields the caller supplied explicitly.
		switch f.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		default:
		}

		if f.Indexing {
			if f.Type != searchstore.FieldTypeText {
				return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
			}
			// When indexing, only the content field also stores the
			// original text alongside its vectors.
			if f.Name == searchstore.FieldTextContent {
				resp = append(resp, mentity.NewField().
					WithName(f.Name).
					WithDescription(f.Description).
					WithIsPrimaryKey(f.IsPrimary).
					WithNullable(f.Nullable).
					WithDataType(mentity.FieldTypeVarChar).
					WithMaxLength(65535))
			}
			// Dense vector column for the embedded representation.
			resp = append(resp, mentity.NewField().
				WithName(denseFieldName(f.Name)).
				WithDataType(mentity.FieldTypeFloatVector).
				WithDim(m.config.Embedding.Dimensions()))
			// Sparse column only when the embedder can produce it.
			if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
				resp = append(resp, mentity.NewField().
					WithName(sparseFieldName(f.Name)).
					WithDataType(mentity.FieldTypeSparseVector))
			}
		} else {
			// Plain scalar column: direct type conversion.
			mf := mentity.NewField().
				WithName(f.Name).
				WithDescription(f.Description).
				WithIsPrimaryKey(f.IsPrimary).
				WithNullable(f.Nullable)
			typ, err := convertFieldType(f.Type)
			if err != nil {
				return nil, err
			}
			mf.WithDataType(typ)
			if typ == mentity.FieldTypeVarChar {
				mf.WithMaxLength(65535)
			} else if typ == mentity.FieldTypeFloatVector {
				mf.WithDim(m.config.Embedding.Dimensions())
			}
			resp = append(resp, mf)
		}
	}

	// Backfill the builtin primary-key and creator-id columns when absent.
	if !foundID {
		resp = append(resp, mentity.NewField().
			WithName(searchstore.FieldID).
			WithDataType(mentity.FieldTypeInt64).
			WithIsPrimaryKey(true).
			WithNullable(false))
	}

	if !foundCreatorID {
		resp = append(resp, mentity.NewField().
			WithName(searchstore.FieldCreatorID).
			WithDataType(mentity.FieldTypeInt64).
			WithNullable(false))
	}

	return resp, nil
}
|
||||
|
||||
// GetEmbedding returns the embedder this manager was configured with.
func (m *milvusManager) GetEmbedding() embedding.Embedder {
	return m.config.Embedding
}
|
||||
@@ -0,0 +1,600 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
mentity "github.com/milvus-io/milvus/client/v2/entity"
|
||||
mindex "github.com/milvus-io/milvus/client/v2/index"
|
||||
client "github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/slongfield/pyfmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
// milvusSearchStore implements searchstore.SearchStore on top of a single
// Milvus collection.
type milvusSearchStore struct {
	// config carries the shared Milvus client, embedder, and hybrid switch.
	config *ManagerConfig

	// collectionName is the Milvus collection all operations target.
	collectionName string
}
|
||||
|
||||
// Store upserts docs into the Milvus collection in batches and returns the
// string form of the IDs reported by Milvus, in insertion order.
//
// Indexer options control which fields are embedded (IndexingFields), an
// optional target partition (created on demand), and an optional progress
// bar that receives per-batch progress and any terminal error.
func (m *milvusSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}
	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	// Report the terminal error (named return) to the progress bar, if any.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()
	indexingFields := make(sets.Set[string])
	for _, field := range implSpecOptions.IndexingFields {
		indexingFields[field] = struct{}{}
	}

	// Create the target partition lazily if the caller asked for one.
	if implSpecOptions.Partition != nil {
		partition := *implSpecOptions.Partition
		hasPartition, err := m.config.Client.HasPartition(ctx, client.NewHasPartitionOption(m.collectionName, partition))
		if err != nil {
			return nil, fmt.Errorf("[Store] HasPartition failed, %w", err)
		}
		if !hasPartition {
			if err = m.config.Client.CreatePartition(ctx, client.NewCreatePartitionOption(m.collectionName, partition)); err != nil {
				return nil, fmt.Errorf("[Store] CreatePartition failed, %w", err)
			}
		}
	}

	// Upsert in fixed-size batches so embedding and insert payloads stay bounded.
	for _, part := range slices.Chunks(docs, batchSize) {
		columns, err := m.documents2Columns(ctx, part, indexingFields)
		if err != nil {
			return nil, err
		}

		createReq := client.NewColumnBasedInsertOption(m.collectionName, columns...)
		if implSpecOptions.Partition != nil {
			createReq.WithPartition(*implSpecOptions.Partition)
		}

		result, err := m.config.Client.Upsert(ctx, createReq)
		if err != nil {
			return nil, fmt.Errorf("[Store] upsert failed, %w", err)
		}

		// Collect returned IDs; the ID column may be int64 or string typed.
		partIDs := result.IDs
		for i := 0; i < partIDs.Len(); i++ {
			var sid string
			if partIDs.Type() == mentity.FieldTypeInt64 {
				id, err := partIDs.GetAsInt64(i)
				if err != nil {
					return nil, err
				}
				sid = strconv.FormatInt(id, 10)
			} else {
				sid, err = partIDs.GetAsString(i)
				if err != nil {
					return nil, err
				}
			}
			ids = append(ids, sid)
		}
		if implSpecOptions.ProgressBar != nil {
			if err = implSpecOptions.ProgressBar.AddN(len(part)); err != nil {
				return nil, err
			}
		}
	}

	return ids, nil
}
|
||||
|
||||
// Retrieve embeds query and searches the collection.
//
// When the collection has a sparse-vector field and hybrid search is enabled
// (see enableSparse), it issues one ANN sub-request per vector field and
// fuses them with an RRF reranker; otherwise it runs a single dense search
// against the collection's only index and later normalizes scores by that
// index's metric type. An optional DSL filter from the retriever options is
// translated to a Milvus filter expression.
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	cli := m.config.Client
	emb := m.config.Embedding
	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
	implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)

	desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
	if err != nil {
		return nil, err
	}

	var (
		dense  [][]float64
		sparse []map[int]float64
		expr   string
		result []client.ResultSet

		fields       = desc.Schema.Fields
		outputFields []string
		enableSparse = m.enableSparse(fields)
	)

	// Translate the optional retriever DSL into a Milvus filter expression.
	if options.DSLInfo != nil {
		expr, err = m.dsl2Expr(options.DSLInfo)
		if err != nil {
			return nil, err
		}
	}

	if enableSparse {
		dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] EmbedStringsHybrid failed, %w", err)
		}
	} else {
		dense, err = emb.EmbedStrings(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] EmbedStrings failed, %w", err)
		}
	}

	dv := convertMilvusDenseVector(dense)
	sv, err := convertMilvusSparseVector(sparse)
	if err != nil {
		return nil, err
	}

	// Request every schema field back in the result set.
	for _, field := range fields {
		outputFields = append(outputFields, field.Name)
	}

	// Set only for the single-index dense path; nil means "do not normalize".
	var scoreNormType *mindex.MetricType

	if enableSparse {
		// Hybrid path: one ANN sub-request per vector field, fused by RRF.
		var annRequests []*client.AnnRequest
		for _, field := range fields {
			var (
				vector      []mentity.Vector
				metricsType mindex.MetricType
			)
			if field.DataType == mentity.FieldTypeFloatVector {
				vector = dv
				metricsType, err = m.getIndexMetricsType(ctx, denseIndexName(field.Name))
			} else if field.DataType == mentity.FieldTypeSparseVector {
				vector = sv
				metricsType, err = m.getIndexMetricsType(ctx, sparseIndexName(field.Name))
			}
			if err != nil {
				return nil, err
			}
			annRequests = append(annRequests,
				client.NewAnnRequest(field.Name, ptr.From(options.TopK), vector...).
					WithSearchParam(mindex.MetricTypeKey, string(metricsType)).
					WithFilter(expr),
			)
		}

		// NOTE: "WithPartitons" is the SDK's actual (misspelled) method name.
		searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
			WithPartitons(implSpecOptions.Partitions...).
			WithReranker(client.NewRRFReranker()).
			WithOutputFields(outputFields...)

		result, err = cli.HybridSearch(ctx, searchOption)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] HybridSearch failed, %w", err)
		}
	} else {
		// Dense-only path: require exactly one index so the metric type is
		// unambiguous for score normalization.
		indexes, err := cli.ListIndexes(ctx, client.NewListIndexOption(m.collectionName))
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] ListIndexes failed, %w", err)
		}
		if len(indexes) != 1 {
			return nil, fmt.Errorf("[Retrieve] restrict single index ann search, but got %d, collection=%s",
				len(indexes), m.collectionName)
		}
		metricsType, err := m.getIndexMetricsType(ctx, indexes[0])
		if err != nil {
			return nil, err
		}
		scoreNormType = &metricsType
		searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
			WithPartitions(implSpecOptions.Partitions...).
			WithFilter(expr).
			WithOutputFields(outputFields...).
			WithSearchParam(mindex.MetricTypeKey, string(metricsType))
		result, err = cli.Search(ctx, searchOption)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] Search failed, %w", err)
		}
	}

	docs, err := m.resultSet2Document(result, scoreNormType)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] resultSet2Document failed, %w", err)
	}

	return docs, nil
}
|
||||
|
||||
func (m *milvusSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
int64IDs := make([]int64, 0, len(ids))
|
||||
for _, sid := range ids {
|
||||
id, err := strconv.ParseInt(sid, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
int64IDs = append(int64IDs, id)
|
||||
}
|
||||
_, err := m.config.Client.Delete(ctx,
|
||||
client.NewDeleteOption(m.collectionName).WithInt64IDs(searchstore.FieldID, int64IDs))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *milvusSearchStore) documents2Columns(ctx context.Context, docs []*schema.Document, indexingFields sets.Set[string]) (
|
||||
cols []column.Column, err error) {
|
||||
|
||||
var (
|
||||
ids []int64
|
||||
contents []string
|
||||
creatorIDs []int64
|
||||
emptyContents = true
|
||||
)
|
||||
|
||||
colMapping := map[string]any{}
|
||||
colTypeMapping := map[string]searchstore.FieldType{
|
||||
searchstore.FieldID: searchstore.FieldTypeInt64,
|
||||
searchstore.FieldCreatorID: searchstore.FieldTypeInt64,
|
||||
searchstore.FieldTextContent: searchstore.FieldTypeText,
|
||||
}
|
||||
for _, doc := range docs {
|
||||
if doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("[documents2Columns] meta data is nil")
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(doc.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[documents2Columns] parse id failed, %w", err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
contents = append(contents, doc.Content)
|
||||
if doc.Content != "" {
|
||||
emptyContents = false
|
||||
}
|
||||
|
||||
creatorID, err := document.GetDocumentCreatorID(doc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[documents2Columns] creator_id not found or type invalid., %w", err)
|
||||
}
|
||||
creatorIDs = append(creatorIDs, creatorID)
|
||||
|
||||
ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for field := range ext {
|
||||
val := ext[field]
|
||||
container := colMapping[field]
|
||||
switch t := val.(type) {
|
||||
case uint, uint8, uint16, uint32, uint64, uintptr:
|
||||
var c []int64
|
||||
if container == nil {
|
||||
colTypeMapping[field] = searchstore.FieldTypeInt64
|
||||
} else {
|
||||
c, ok = container.([]int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, int64(reflect.ValueOf(t).Uint()))
|
||||
colMapping[field] = c
|
||||
case int, int8, int16, int32, int64:
|
||||
var c []int64
|
||||
if container == nil {
|
||||
colTypeMapping[field] = searchstore.FieldTypeInt64
|
||||
} else {
|
||||
c, ok = container.([]int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, reflect.ValueOf(t).Int())
|
||||
colMapping[field] = c
|
||||
case string:
|
||||
var c []string
|
||||
if container == nil {
|
||||
colTypeMapping[field] = searchstore.FieldTypeText
|
||||
} else {
|
||||
c, ok = container.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, t)
|
||||
colMapping[field] = c
|
||||
case []float64:
|
||||
var c [][]float64
|
||||
if container == nil {
|
||||
container = c
|
||||
colTypeMapping[field] = searchstore.FieldTypeDenseVector
|
||||
} else {
|
||||
c, ok = container.([][]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, t)
|
||||
colMapping[field] = c
|
||||
case map[int]float64:
|
||||
var c []map[int]float64
|
||||
if container == nil {
|
||||
container = c
|
||||
colTypeMapping[field] = searchstore.FieldTypeSparseVector
|
||||
} else {
|
||||
c, ok = container.([]map[int]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, t)
|
||||
colMapping[field] = c
|
||||
default:
|
||||
return nil, fmt.Errorf("[documents2Columns] val type not support, val=%v", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
colMapping[searchstore.FieldID] = ids
|
||||
colMapping[searchstore.FieldCreatorID] = creatorIDs
|
||||
colMapping[searchstore.FieldTextContent] = contents
|
||||
|
||||
for fieldName, container := range colMapping {
|
||||
colType := colTypeMapping[fieldName]
|
||||
switch colType {
|
||||
case searchstore.FieldTypeInt64:
|
||||
c, ok := container.([]int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
cols = append(cols, column.NewColumnInt64(fieldName, c))
|
||||
case searchstore.FieldTypeText:
|
||||
c, ok := container.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not string")
|
||||
}
|
||||
|
||||
if _, indexing := indexingFields[fieldName]; indexing {
|
||||
if fieldName == searchstore.FieldTextContent && !emptyContents {
|
||||
cols = append(cols, column.NewColumnVarChar(fieldName, c))
|
||||
}
|
||||
|
||||
var (
|
||||
emb = m.config.Embedding
|
||||
dense [][]float64
|
||||
sparse []map[int]float64
|
||||
)
|
||||
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
|
||||
dense, sparse, err = emb.EmbedStringsHybrid(ctx, c)
|
||||
} else {
|
||||
dense, err = emb.EmbedStrings(ctx, c)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[slices2Columns] embed failed, %w", err)
|
||||
}
|
||||
|
||||
cols = append(cols, column.NewColumnFloatVector(denseFieldName(fieldName), int(emb.Dimensions()), convertDense(dense)))
|
||||
|
||||
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
|
||||
s, err := convertSparse(sparse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols = append(cols, column.NewColumnSparseVectors(sparseFieldName(fieldName), s))
|
||||
}
|
||||
} else {
|
||||
cols = append(cols, column.NewColumnVarChar(fieldName, c))
|
||||
}
|
||||
|
||||
case searchstore.FieldTypeDenseVector:
|
||||
c, ok := container.([][]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not []float64")
|
||||
}
|
||||
cols = append(cols, column.NewColumnFloatVector(fieldName, int(m.config.Embedding.Dimensions()), convertDense(c)))
|
||||
case searchstore.FieldTypeSparseVector:
|
||||
c, ok := container.([]map[int]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not map[int]float64")
|
||||
}
|
||||
sparse, err := convertSparse(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols = append(cols, column.NewColumnSparseVectors(fieldName, sparse))
|
||||
default:
|
||||
return nil, fmt.Errorf("[documents2Columns] column type not support, type=%d", colType)
|
||||
}
|
||||
}
|
||||
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
// resultSet2Document flattens Milvus result sets into schema.Documents,
// sorts them by raw score descending, and — for the single-index dense path
// only (hybrid or nil metricsType skips this) — normalizes scores into [0,1]
// according to the index's metric type.
func (m *milvusSearchStore) resultSet2Document(result []client.ResultSet, metricsType *mindex.MetricType) (docs []*schema.Document, err error) {
	docs = make([]*schema.Document, 0, len(result))
	minScore := math.MaxFloat64
	maxScore := 0.0

	for _, r := range result {
		for i := 0; i < r.ResultCount; i++ {
			// ext collects all non-builtin fields under the external-storage key.
			ext := make(map[string]any)
			doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
			score := float64(r.Scores[i])
			minScore = min(minScore, score)
			maxScore = max(maxScore, score)
			doc.WithScore(score)

			for _, field := range r.Fields {
				switch field.Name() {
				case searchstore.FieldID:
					id, err := field.GetAsInt64(i)
					if err != nil {
						return nil, err
					}
					doc.ID = strconv.FormatInt(id, 10)
				case searchstore.FieldTextContent:
					doc.Content, err = field.GetAsString(i)
				case searchstore.FieldCreatorID:
					doc.MetaData[document.MetaDataKeyCreatorID], err = field.GetAsInt64(i)
				default:
					ext[field.Name()], err = field.Get(i)
				}
				// Catches the err assigned by the non-ID cases above.
				if err != nil {
					return nil, err
				}
			}

			docs = append(docs, doc)
		}
	}

	sort.Slice(docs, func(i, j int) bool {
		return docs[i].Score() > docs[j].Score()
	})

	// norm score — skipped for hybrid search (RRF scores) or when no metric
	// type was resolved.
	if (m.config.EnableHybrid != nil && *m.config.EnableHybrid) || metricsType == nil {
		return docs, nil
	}

	switch *metricsType {
	case mentity.L2:
		// L2 is a distance: min-max invert so smaller distance → higher score.
		base := maxScore - minScore
		for i := range docs {
			if base == 0 {
				docs[i].WithScore(1.0)
			} else {
				docs[i].WithScore(1.0 - (docs[i].Score()-minScore)/base)
			}
		}
		// Project slices helper; restores descending order after inversion.
		docs = slices.Reverse(docs)
	case mentity.IP, mentity.COSINE:
		// Map similarity from [-1, 1] to [0, 1].
		for i := range docs {
			docs[i].WithScore((docs[i].Score() + 1) / 2)
		}
	default:

	}

	return docs, nil
}
|
||||
|
||||
func (m *milvusSearchStore) enableSparse(fields []*mentity.Field) bool {
|
||||
found := false
|
||||
for _, field := range fields {
|
||||
if field.DataType == mentity.FieldTypeSparseVector {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return found && *m.config.EnableHybrid && m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
|
||||
}
|
||||
|
||||
// dsl2Expr translates a searchstore DSL tree into a Milvus boolean filter
// expression string. A nil src yields the empty expression (no filter).
// Leaf ops map to comparison operators; OpIn marshals the value list to
// JSON; OpAnd/OpOr recurse over their sub-clauses and join with AND/OR.
func (m *milvusSearchStore) dsl2Expr(src map[string]interface{}) (string, error) {
	if src == nil {
		return "", nil
	}

	dsl, err := searchstore.LoadDSL(src)
	if err != nil {
		return "", err
	}

	var travDSL func(dsl *searchstore.DSL) (string, error)
	travDSL = func(dsl *searchstore.DSL) (string, error) {
		// kv feeds pyfmt's {field}/{val} placeholders.
		kv := map[string]interface{}{
			"field": dsl.Field,
			"val":   dsl.Value,
		}

		switch dsl.Op {
		case searchstore.OpEq:
			return pyfmt.Fmt("{field} == {val}", kv)
		case searchstore.OpNe:
			return pyfmt.Fmt("{field} != {val}", kv)
		case searchstore.OpLike:
			return pyfmt.Fmt("{field} LIKE {val}", kv)
		case searchstore.OpIn:
			// Render the membership list as a JSON array, e.g. [1,2,3].
			b, err := json.Marshal(dsl.Value)
			if err != nil {
				return "", err
			}
			kv["val"] = string(b)
			return pyfmt.Fmt("{field} IN {val}", kv)
		case searchstore.OpAnd, searchstore.OpOr:
			sub, ok := dsl.Value.([]*searchstore.DSL)
			if !ok {
				return "", fmt.Errorf("[dsl2Expr] invalid sub dsl")
			}
			var items []string
			for _, s := range sub {
				str, err := travDSL(s)
				if err != nil {
					return "", fmt.Errorf("[dsl2Expr] parse sub failed, %w", err)
				}
				items = append(items, str)
			}

			if dsl.Op == searchstore.OpAnd {
				return strings.Join(items, " AND "), nil
			} else {
				return strings.Join(items, " OR "), nil
			}
		default:
			return "", fmt.Errorf("[dsl2Expr] unknown op type=%s", dsl.Op)
		}
	}

	return travDSL(dsl)
}
|
||||
|
||||
func (m *milvusSearchStore) getIndexMetricsType(ctx context.Context, indexName string) (mindex.MetricType, error) {
|
||||
index, err := m.config.Client.DescribeIndex(ctx, client.NewDescribeIndexOption(m.collectionName, indexName))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[getIndexMetricsType] describe index failed, collection=%s, index=%s, %w",
|
||||
m.collectionName, indexName, err)
|
||||
}
|
||||
|
||||
typ, found := index.Params()[mindex.MetricTypeKey]
|
||||
if !found { // unexpected
|
||||
return "", fmt.Errorf("[getIndexMetricsType] invalid index params, collection=%s, index=%s", m.collectionName, indexName)
|
||||
}
|
||||
|
||||
return mindex.MetricType(typ), nil
|
||||
}
|
||||
122
backend/infra/impl/document/searchstore/vikingdb/consts.go
Normal file
122
backend/infra/impl/document/searchstore/vikingdb/consts.go
Normal file
@@ -0,0 +1,122 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
embcontract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// VikingEmbeddingModelName identifies a VikingDB-hosted embedding model.
type VikingEmbeddingModelName string

// Supported VikingDB embedding models.
const (
	ModelNameDoubaoEmbedding       VikingEmbeddingModelName = "doubao-embedding"
	ModelNameDoubaoEmbeddingLarge  VikingEmbeddingModelName = "doubao-embedding-large"
	ModelNameDoubaoEmbeddingVision VikingEmbeddingModelName = "doubao-embedding-vision"
	ModelNameBGELargeZH            VikingEmbeddingModelName = "bge-large-zh"
	ModelNameBGEM3                 VikingEmbeddingModelName = "bge-m3"
	ModelNameBGEVisualizedM3       VikingEmbeddingModelName = "bge-visualized-m3"

	// Combined models, currently disabled:
	// ModelNameDoubaoEmbeddingAndM3      VikingEmbeddingModelName = "doubao-embedding-and-m3"
	// ModelNameDoubaoEmbeddingLargeAndM3 VikingEmbeddingModelName = "doubao-embedding-large-and-m3"
	// ModelNameBGELargeZHAndM3           VikingEmbeddingModelName = "bge-large-zh-and-m3"
)
|
||||
|
||||
func (v VikingEmbeddingModelName) Dimensions() int64 {
|
||||
switch v {
|
||||
case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingVision:
|
||||
return 2048
|
||||
case ModelNameDoubaoEmbeddingLarge:
|
||||
return 4096
|
||||
case ModelNameBGELargeZH, ModelNameBGEM3, ModelNameBGEVisualizedM3:
|
||||
return 1024
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (v VikingEmbeddingModelName) ModelVersion() *string {
|
||||
switch v {
|
||||
case ModelNameDoubaoEmbedding:
|
||||
return ptr.Of("240515")
|
||||
case ModelNameDoubaoEmbeddingLarge:
|
||||
return ptr.Of("240915")
|
||||
case ModelNameDoubaoEmbeddingVision:
|
||||
return ptr.Of("250328")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (v VikingEmbeddingModelName) SupportStatus() embcontract.SupportStatus {
|
||||
switch v {
|
||||
case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingLarge, ModelNameDoubaoEmbeddingVision, ModelNameBGELargeZH, ModelNameBGEVisualizedM3:
|
||||
return embcontract.SupportDense
|
||||
case ModelNameBGEM3:
|
||||
return embcontract.SupportDenseAndSparse
|
||||
default:
|
||||
return embcontract.SupportDense
|
||||
}
|
||||
}
|
||||
|
||||
// IndexType selects the VikingDB vector index algorithm.
type IndexType string

// Index types mirrored from the vikingdb SDK constants.
const (
	IndexTypeHNSW       IndexType = vikingdb.HNSW
	IndexTypeHNSWHybrid IndexType = vikingdb.HNSW_HYBRID
	IndexTypeFlat       IndexType = vikingdb.FLAT
	IndexTypeIVF        IndexType = vikingdb.IVF
	IndexTypeDiskANN    IndexType = vikingdb.DiskANN
)
|
||||
|
||||
// IndexDistance selects the similarity/distance metric of the vector index.
type IndexDistance string

// Distance metrics mirrored from the vikingdb SDK constants.
const (
	IndexDistanceIP     IndexDistance = vikingdb.IP
	IndexDistanceL2     IndexDistance = vikingdb.L2
	IndexDistanceCosine IndexDistance = vikingdb.COSINE
)
|
||||
|
||||
// IndexQuant selects the vector quantization scheme of the index.
type IndexQuant string

// Quantization schemes mirrored from the vikingdb SDK constants.
const (
	IndexQuantInt8  IndexQuant = vikingdb.Int8
	IndexQuantFloat IndexQuant = vikingdb.Float
	IndexQuantFix16 IndexQuant = vikingdb.Fix16
	IndexQuantPQ    IndexQuant = vikingdb.PQ
)
|
||||
|
||||
// Request/response keys of the VikingDB embedding API, plus the fixed index
// name this package creates per collection.
const (
	vikingEmbeddingUseDense          = "return_dense"
	vikingEmbeddingUseSparse         = "return_sparse"
	vikingEmbeddingRespSentenceDense  = "sentence_dense_embedding"
	vikingEmbeddingRespSentenceSparse = "sentence_sparse_embedding"
	vikingIndexName                   = "opencoze_index"
)

// Error-message substrings used to detect "not found" responses from the
// vikingdb SDK, which exposes these conditions only via the error text.
const (
	errCollectionNotFound = "collection not found"
	errIndexNotFound      = "index not found"
)
|
||||
|
||||
// denseFieldName derives the dense-vector field name for a source field,
// e.g. "text_content" -> "dense_text_content". Plain concatenation replaces
// the original fmt.Sprintf, which boxed the argument and went through the
// reflection-based formatter for a two-part string.
func denseFieldName(name string) string {
	return "dense_" + name
}
|
||||
@@ -0,0 +1,331 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// ManagerConfig configures the VikingDB searchstore manager.
type ManagerConfig struct {
	// Service is the VikingDB SDK client; required.
	Service *vikingdb.VikingDBService

	// IndexingConfig holds vector-index parameters; nil gets defaults in NewManager.
	IndexingConfig *VikingIndexingConfig
	// EmbeddingConfig selects VikingDB-hosted vs builtin embedding; required.
	EmbeddingConfig *VikingEmbeddingConfig

	// TODO: cache viking collection & index client
}
|
||||
|
||||
// VikingIndexingConfig holds the vector-index parameters used when creating
// a VikingDB index. Nil/zero fields are filled with the documented defaults
// by NewManager.
type VikingIndexingConfig struct {
	// vector index config
	Type     IndexType      // default: hnsw / hnsw_hybrid (hybrid when viking embedding + hybrid enabled)
	Distance *IndexDistance // default: ip
	Quant    *IndexQuant    // default: int8
	HnswM    *int64         // default: 20
	HnswCef  *int64         // default: 400
	HnswSef  *int64         // default: 800

	// others
	CpuQuota   int64 // default: 2
	ShardCount int64 // default: 1
}
|
||||
|
||||
// VikingEmbeddingConfig selects how vectors are produced: either by a
// VikingDB-hosted model (UseVikingEmbedding) or by a caller-supplied
// builtin embedder. Hybrid (dense+sparse) retrieval is only available with
// a viking model that supports sparse output.
type VikingEmbeddingConfig struct {
	UseVikingEmbedding bool
	EnableHybrid       bool

	// viking embedding config (used when UseVikingEmbedding is true)
	ModelName    VikingEmbeddingModelName
	ModelVersion *string
	DenseWeight  *float64

	// builtin embedding config (required when UseVikingEmbedding is false)
	BuiltinEmbedding embedding.Embedder
}
|
||||
|
||||
// NewManager validates config, fills indexing defaults, and returns a
// VikingDB-backed searchstore.Manager.
//
// Validation rules: Service and EmbeddingConfig are required; without viking
// embedding a BuiltinEmbedding must be supplied and hybrid is unsupported;
// with viking embedding a model name is required, and hybrid additionally
// requires a model that supports sparse vectors. config is mutated in place
// to apply defaults.
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
	if config.Service == nil {
		return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
	}
	if config.EmbeddingConfig == nil {
		return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
	}
	if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
		return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
	}
	if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
		return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
	}
	if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
		return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
	}
	if config.EmbeddingConfig.UseVikingEmbedding &&
		config.EmbeddingConfig.EnableHybrid &&
		config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
		return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
	}
	// Apply indexing defaults (see VikingIndexingConfig field comments).
	if config.IndexingConfig == nil {
		config.IndexingConfig = &VikingIndexingConfig{}
	}
	if config.IndexingConfig.Type == "" {
		// Hybrid index only when viking embedding with hybrid is in use.
		if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
			config.IndexingConfig.Type = IndexTypeHNSW
		} else {
			config.IndexingConfig.Type = IndexTypeHNSWHybrid
		}
	}
	if config.IndexingConfig.Distance == nil {
		config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
	}
	if config.IndexingConfig.Quant == nil {
		config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
	}
	if config.IndexingConfig.HnswM == nil {
		config.IndexingConfig.HnswM = ptr.Of(int64(20))
	}
	if config.IndexingConfig.HnswCef == nil {
		config.IndexingConfig.HnswCef = ptr.Of(int64(400))
	}
	if config.IndexingConfig.HnswSef == nil {
		config.IndexingConfig.HnswSef = ptr.Of(int64(800))
	}
	if config.IndexingConfig.CpuQuota == 0 {
		config.IndexingConfig.CpuQuota = 2
	}
	if config.IndexingConfig.ShardCount == 0 {
		config.IndexingConfig.ShardCount = 1
	}

	return &manager{
		config: config,
	}, nil
}
|
||||
|
||||
// manager implements searchstore.Manager backed by VikingDB.
type manager struct {
	config *ManagerConfig
}
|
||||
|
||||
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if err := m.createCollection(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.createIndex(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
|
||||
if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
|
||||
if !strings.Contains(err.Error(), errIndexNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
|
||||
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetType identifies this manager as a vector-store implementation.
func (m *manager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
|
||||
|
||||
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
|
||||
collection, err := m.config.Service.GetCollection(collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &vkSearchStore{manager: m, collection: collection}, nil
|
||||
}
|
||||
|
||||
// createCollection creates the collection for req if it does not already
// exist. A "collection not found" lookup error means "create it"; any other
// lookup error is returned. When mapFields yields vectorize options (viking
// embedding in use) they are passed to the SDK so VikingDB embeds server-side.
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
	svc := m.config.Service

	collection, err := svc.GetCollection(req.CollectionName)
	if err != nil {
		if !strings.Contains(err.Error(), errCollectionNotFound) {
			return err
		}
	} else if collection != nil {
		// Already provisioned; nothing to do.
		return nil
	}

	fields, vopts, err := m.mapFields(req.Fields)
	if err != nil {
		return err
	}

	if vopts != nil {
		_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
	} else {
		_, err = svc.CreateCollection(req.CollectionName, fields, "")
	}
	if err != nil {
		return err
	}

	logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)

	return nil
}
|
||||
|
||||
// createIndex creates the search index (vikingIndexName) on the collection if
// it does not already exist, using the manager's IndexingConfig for the
// vector-index parameters. An existing index makes the call a no-op.
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
	svc := m.config.Service
	// Existence check mirrors createCollection: only "index not found" is
	// treated as "proceed to create".
	index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
	if err != nil {
		if !strings.Contains(err.Error(), errIndexNotFound) {
			return err
		}
	} else if index != nil {
		return nil
	}

	// Optional knobs (Distance/Quant/HNSW params) are pointers in the config;
	// ptr.From yields the zero value when unset, which the SDK interprets as
	// "use server default".
	vectorIndex := &vikingdb.VectorIndexParams{
		IndexType: string(m.config.IndexingConfig.Type),
		Distance:  string(ptr.From(m.config.IndexingConfig.Distance)),
		Quant:     string(ptr.From(m.config.IndexingConfig.Quant)),
		HnswM:     ptr.From(m.config.IndexingConfig.HnswM),
		HnswCef:   ptr.From(m.config.IndexingConfig.HnswCef),
		HnswSef:   ptr.From(m.config.IndexingConfig.HnswSef),
	}

	opts := vikingdb.NewIndexOptions().
		SetVectorIndex(vectorIndex).
		SetCpuQuota(m.config.IndexingConfig.CpuQuota).
		SetShardCount(m.config.IndexingConfig.ShardCount)

	_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
	if err != nil {
		return err
	}

	logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)

	return nil
}
|
||||
|
||||
// mapFields translates searchstore field definitions into VikingDB fields and,
// when VikingDB's server-side embedding is enabled, vectorize tuples.
//
// For each field marked Indexing (text only):
//   - server-side embedding: emit a VectorizeTuple (dense, plus sparse when
//     hybrid retrieval is enabled) instead of a vector column;
//   - builtin embedding: emit an extra dense-vector column named by
//     denseFieldName, sized by getDims.
//
// The ID and creator-ID fields are appended with defaults if the caller did
// not declare them, so every collection has a primary key and a creator_id.
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
	var (
		foundID        bool
		foundCreatorID bool
		dstFields      = make([]vikingdb.Field, 0, len(srcFields))
		vectorizeOpts  []*vikingdb.VectorizeTuple
		embConfig      = m.config.EmbeddingConfig
	)

	for _, srcField := range srcFields {
		// Track whether the caller supplied the mandatory bookkeeping fields.
		switch srcField.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		default:
		}

		if srcField.Indexing {
			// Only text fields can be indexed for vector search.
			if srcField.Type != searchstore.FieldTypeText {
				return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
			}
			if embConfig.UseVikingEmbedding {
				// Server-side embedding: VikingDB vectorizes the text field
				// itself, so describe it with a vectorize tuple.
				vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name))
				if embConfig.EnableHybrid {
					vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name))
				}
				vectorizeOpts = append(vectorizeOpts, vt)
			} else {
				// Client-side embedding: add a companion dense-vector column
				// that Store will populate via addEmbedding.
				dstFields = append(dstFields, vikingdb.Field{
					FieldName:  denseFieldName(srcField.Name),
					FieldType:  vikingdb.Vector,
					DefaultVal: nil,
					Dim:        m.getDims(),
				})
			}

		}

		// The source field itself is always materialized as a column,
		// regardless of whether it is indexed.
		dstField := vikingdb.Field{
			FieldName:    srcField.Name,
			IsPrimaryKey: srcField.IsPrimary,
		}
		switch srcField.Type {
		case searchstore.FieldTypeInt64:
			dstField.FieldType = vikingdb.Int64
		case searchstore.FieldTypeText:
			dstField.FieldType = vikingdb.Text
		case searchstore.FieldTypeDenseVector:
			dstField.FieldType = vikingdb.Vector
			dstField.Dim = m.getDims()
		case searchstore.FieldTypeSparseVector:
			dstField.FieldType = vikingdb.Sparse_Vector
		default:
			return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
		}
		dstFields = append(dstFields, dstField)
	}

	// Ensure the primary key column exists.
	if !foundID {
		dstFields = append(dstFields, vikingdb.Field{
			FieldName:    searchstore.FieldID,
			FieldType:    vikingdb.Int64,
			IsPrimaryKey: true,
		})
	}

	// Ensure the creator-ID column exists.
	if !foundCreatorID {
		dstFields = append(dstFields, vikingdb.Field{
			FieldName: searchstore.FieldCreatorID,
			FieldType: vikingdb.Int64,
		})
	}

	return dstFields, vectorizeOpts, nil
}
|
||||
|
||||
func (m *manager) newVectorizeModelConf(fieldName string) *vikingdb.VectorizeModelConf {
|
||||
embConfig := m.config.EmbeddingConfig
|
||||
vmc := vikingdb.NewVectorizeModelConf().
|
||||
SetTextField(fieldName).
|
||||
SetModelName(string(embConfig.ModelName)).
|
||||
SetDim(m.getDims())
|
||||
if embConfig.ModelVersion != nil {
|
||||
vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
|
||||
}
|
||||
return vmc
|
||||
}
|
||||
|
||||
func (m *manager) getDims() int64 {
|
||||
if m.config.EmbeddingConfig.UseVikingEmbedding {
|
||||
return m.config.EmbeddingConfig.ModelName.Dimensions()
|
||||
}
|
||||
|
||||
return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// vkSearchStore implements searchstore.SearchStore on top of a single
// VikingDB collection. It embeds *manager for access to the service client
// and embedding/indexing configuration.
type vkSearchStore struct {
	*manager
	// collection is the resolved VikingDB collection handle.
	collection *vikingdb.Collection
	// index is an optional cached index handle; when nil, Retrieve resolves
	// it lazily from the service.
	index *vikingdb.Index
}
|
||||
|
||||
// Store upserts the given documents into the collection in batches of 100.
// When client-side (builtin) embedding is configured, dense vectors for the
// indexing fields are computed per batch before the upsert. Returned ids are
// the input documents' IDs in order. Any error is also reported to the
// caller-supplied progress bar (if any) via the deferred handler.
func (v *vkSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}

	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)

	// Report failures to the progress bar; uses the named return err so it
	// observes whatever error the function ends with.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				_ = implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()

	// Convert documents to VikingDB rows without vector columns; vectors are
	// added per batch below (or server-side, when VikingDB embedding is on).
	docsWithoutVector, err := slices.TransformWithErrorCheck(docs, v.document2DataWithoutVector)
	if err != nil {
		return nil, fmt.Errorf("[Store] vikingdb failed to transform documents, %w", err)
	}

	indexingFields := sets.FromSlice(implSpecOptions.IndexingFields)
	for _, part := range slices.Chunks(docsWithoutVector, 100) {
		docsWithVector, err := v.addEmbedding(ctx, part, indexingFields)
		if err != nil {
			return nil, err
		}

		if err := v.collection.UpsertData(docsWithVector); err != nil {
			return nil, err
		}
	}

	ids = slices.Transform(docs, func(a *schema.Document) string { return a.ID })

	return
}
|
||||
|
||||
// Retrieve runs a similarity search for query against the collection's index.
// The index handle is resolved lazily when not cached. TopK defaults to 4.
// With server-side embedding the query text is sent as-is (multi-modal
// search); otherwise the query is embedded with the builtin embedder and a
// vector search is performed. DSL and partition constraints are expressed as
// a VikingDB filter.
func (v *vkSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) {
	indexClient := v.index
	if indexClient == nil {
		// Verify the expected index exists on the collection before asking
		// the service for a handle.
		foundIndex := false
		for _, index := range v.collection.Indexes {
			if index.IndexName == vikingIndexName {
				foundIndex = true
				break
			}
		}
		if !foundIndex {
			return nil, fmt.Errorf("[Retrieve] vikingdb index not found, name=%s", vikingIndexName)
		}

		indexClient, err = v.config.Service.GetIndex(v.collection.CollectionName, vikingIndexName)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] vikingdb failed to get index, %w", err)
		}
	}

	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(4)}, opts...)
	implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)

	searchOpts := vikingdb.NewSearchOptions().
		SetLimit(int64(ptr.From(options.TopK))).
		SetText(query).
		SetRetry(true)

	filter, err := v.genFilter(ctx, options, implSpecOptions)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb failed to build filter, %w", err)
	}
	if filter != nil {
		// Cross-partition recall is not supported; a filter is used instead.
		searchOpts = searchOpts.SetFilter(filter)
	}

	var data []*vikingdb.Data

	if v.config.EmbeddingConfig.UseVikingEmbedding {
		// Server-side embedding: VikingDB vectorizes the query text itself.
		data, err = indexClient.SearchWithMultiModal(searchOpts)
	} else {
		// Client-side embedding: embed the query here, then vector-search.
		var dense [][]float64
		dense, err = v.config.EmbeddingConfig.BuiltinEmbedding.EmbedStrings(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] embed failed, %w", err)
		}
		if len(dense) != 1 {
			return nil, fmt.Errorf("[Retrieve] unexpected dense vector size, expected=1, got=%d", len(dense))
		}
		data, err = indexClient.SearchByVector(dense[0], searchOpts)
	}
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb search failed, %w", err)
	}

	docs, err = v.parseSearchResult(data)
	if err != nil {
		return nil, err
	}

	return
}
|
||||
|
||||
func (v *vkSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
for _, part := range slices.Chunks(ids, 100) {
|
||||
if err := v.collection.DeleteData(part); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// document2DataWithoutVector converts a schema.Document into a VikingDB row
// containing the id, creator_id and text-content columns plus any external
// storage key/values carried in the document's metadata. Vector columns are
// intentionally absent — they are filled in later by addEmbedding (or by
// VikingDB itself when server-side embedding is enabled).
// The document ID must be a base-10 int64.
func (v *vkSearchStore) document2DataWithoutVector(doc *schema.Document) (data vikingdb.Data, err error) {
	creatorID, err := document.GetDocumentCreatorID(doc)
	if err != nil {
		return data, err
	}

	docID, err := strconv.ParseInt(doc.ID, 10, 64)
	if err != nil {
		return data, err
	}

	fields := map[string]interface{}{
		searchstore.FieldID:          docID,
		searchstore.FieldCreatorID:   creatorID,
		searchstore.FieldTextContent: doc.Content,
	}

	// External storage is optional metadata; its absence is not an error,
	// so the lookup error is deliberately ignored.
	if ext, err := document.GetDocumentExternalStorage(doc); err == nil { // try load
		for key, val := range ext {
			fields[key] = val
		}
	}
	return vikingdb.Data{
		Id:     doc.ID,
		Fields: fields,
	}, nil
}
|
||||
|
||||
// addEmbedding computes dense vectors for each indexing field of the given
// rows using the builtin embedder and writes them into the rows' companion
// dense-vector columns (denseFieldName). When server-side embedding is
// enabled this is a no-op since VikingDB vectorizes the text itself.
// Rows are modified in place; every row must carry a string value for each
// indexing field.
func (v *vkSearchStore) addEmbedding(ctx context.Context, rows []vikingdb.Data, indexingFields map[string]struct{}) ([]vikingdb.Data, error) {
	if v.config.EmbeddingConfig.UseVikingEmbedding {
		return rows, nil
	}

	emb := v.config.EmbeddingConfig.BuiltinEmbedding

	for indexingField := range indexingFields {
		// Collect the text of this field from every row so the whole batch
		// is embedded in one call.
		values := make([]string, len(rows))
		for i, row := range rows {
			val, found := row.Fields[indexingField]
			if !found {
				return nil, fmt.Errorf("[addEmbedding] indexing field not found in document, field=%s", indexingField)
			}

			strVal, ok := val.(string)
			if !ok {
				return nil, fmt.Errorf("[addEmbedding] val not string, field=%s, val=%v", indexingField, val)
			}

			values[i] = strVal
		}

		dense, err := emb.EmbedStrings(ctx, values)
		if err != nil {
			return nil, fmt.Errorf("[addEmbedding] failed to embed, %w", err)
		}
		if len(dense) != len(values) {
			return nil, fmt.Errorf("[addEmbedding] unexpected dense vector size, expected=%d, got=%d", len(values), len(dense))
		}

		// Write each vector into the row's companion dense column.
		df := denseFieldName(indexingField)
		for i := range dense {
			rows[i].Fields[df] = dense[i]
		}
	}

	return rows, nil
}
|
||||
|
||||
// parseSearchResult converts VikingDB search hits back into schema.Documents.
// Well-known columns (id, creator_id, content) map to the document's typed
// fields; every other column is copied into the document's external-storage
// map, with json.Number values narrowed to int64, float64 or string in that
// order of preference.
func (v *vkSearchStore) parseSearchResult(result []*vikingdb.Data) ([]*schema.Document, error) {
	docs := make([]*schema.Document, 0, len(result))
	for _, data := range result {
		ext := make(map[string]any)
		doc := document.WithDocumentExternalStorage(&schema.Document{MetaData: map[string]any{}}, ext).
			WithScore(data.Score)

		for field, val := range data.Fields {
			switch field {
			case searchstore.FieldID:
				// Numeric columns come back as json.Number from the SDK.
				jn, ok := val.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] id type assertion failed, val=%v", val)
				}
				doc.ID = jn.String()
			case searchstore.FieldCreatorID:
				jn, ok := val.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] creator_id type assertion failed, val=%v", val)
				}
				creatorID, err := jn.Int64()
				if err != nil {
					return nil, fmt.Errorf("[parseSearchResult] creator_id value not int64, val=%v", jn.String())
				}
				doc = document.WithDocumentCreatorID(doc, creatorID)
			case searchstore.FieldTextContent:
				text, ok := val.(string)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] content value not string, val=%v", val)
				}
				doc.Content = text
			default:
				// Unknown columns go to external storage; narrow json.Number
				// to the most specific Go type that parses.
				switch t := val.(type) {
				case json.Number:
					if i64, err := t.Int64(); err == nil {
						ext[field] = i64
					} else if f64, err := t.Float64(); err == nil {
						ext[field] = f64
					} else {
						ext[field] = t.String()
					}
				default:
					ext[field] = val
				}
			}
		}

		docs = append(docs, doc)
	}

	return docs, nil
}
|
||||
|
||||
func (v *vkSearchStore) genFilter(ctx context.Context, co *retriever.Options, ro *searchstore.RetrieverOptions) (map[string]any, error) {
|
||||
filter, err := v.dsl2Filter(ctx, co.DSLInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ro.PartitionKey != nil && len(ro.Partitions) > 0 {
|
||||
var (
|
||||
key = ptr.From(ro.PartitionKey)
|
||||
fieldType = ""
|
||||
conds any
|
||||
)
|
||||
for _, field := range v.collection.Fields {
|
||||
if field.FieldName == key {
|
||||
fieldType = field.FieldType
|
||||
}
|
||||
}
|
||||
if fieldType == "" {
|
||||
return nil, fmt.Errorf("[Retrieve] partition key not found, key=%s", key)
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case vikingdb.Int64:
|
||||
c := make([]int64, 0, len(ro.Partitions))
|
||||
for _, item := range ro.Partitions {
|
||||
i64, err := strconv.ParseInt(item, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Retrieve] partition value parse error, key=%s, val=%v, err=%v", key, item, err)
|
||||
}
|
||||
c = append(c, i64)
|
||||
}
|
||||
conds = c
|
||||
case vikingdb.String:
|
||||
conds = ro.Partitions
|
||||
default:
|
||||
return nil, fmt.Errorf("[Retrieve] invalid field type for partition, key=%s, type=%s", key, fieldType)
|
||||
}
|
||||
|
||||
op := map[string]any{"op": "must", "field": key, "conds": conds}
|
||||
|
||||
if filter != nil {
|
||||
filter = op
|
||||
} else {
|
||||
filter = map[string]any{
|
||||
"op": "and",
|
||||
"conds": []map[string]any{op, filter},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func (v *vkSearchStore) dsl2Filter(ctx context.Context, src map[string]any) (map[string]any, error) {
|
||||
dsl, err := searchstore.LoadDSL(src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dsl == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
toSliceValue := func(val any) any {
|
||||
if reflect.TypeOf(val).Kind() == reflect.Slice {
|
||||
return val
|
||||
}
|
||||
return []any{val}
|
||||
}
|
||||
|
||||
var filter map[string]any
|
||||
|
||||
switch dsl.Op {
|
||||
case searchstore.OpEq, searchstore.OpIn:
|
||||
filter = map[string]any{
|
||||
"op": "must",
|
||||
"field": dsl.Field,
|
||||
"conds": toSliceValue(dsl.Value),
|
||||
}
|
||||
case searchstore.OpNe:
|
||||
filter = map[string]any{
|
||||
"op": "must_not",
|
||||
"field": dsl.Field,
|
||||
"conds": toSliceValue(dsl.Value),
|
||||
}
|
||||
case searchstore.OpLike:
|
||||
logs.CtxWarnf(ctx, "[dsl2Filter] vikingdb invalid dsl type, skip, type=%s", dsl.Op)
|
||||
case searchstore.OpAnd, searchstore.OpOr:
|
||||
var conds []map[string]any
|
||||
sub, ok := dsl.Value.([]map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[dsl2Filter] invalid value for and/or, should be []map[string]any")
|
||||
}
|
||||
for _, subDSL := range sub {
|
||||
cond, err := v.dsl2Filter(ctx, subDSL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conds = append(conds, cond)
|
||||
}
|
||||
op := "and"
|
||||
if dsl.Op == searchstore.OpOr {
|
||||
op = "or"
|
||||
}
|
||||
filter = map[string]any{
|
||||
"op": op,
|
||||
"field": dsl.Field,
|
||||
"conds": conds,
|
||||
}
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
290
backend/infra/impl/document/searchstore/vikingdb/vk_test.go
Normal file
290
backend/infra/impl/document/searchstore/vikingdb/vk_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// TestVikingEmbeddingIntegration exercises the full create/store/retrieve/drop
// lifecycle against a real VikingDB instance using VikingDB's server-side
// (Doubao) embedding. It is gated behind ENABLE_VIKINGDB_INTEGRATION_TEST and
// requires VIKING_DB_AK / VIKING_DB_SK credentials in the environment.
func TestVikingEmbeddingIntegration(t *testing.T) {
	if os.Getenv("ENABLE_VIKINGDB_INTEGRATION_TEST") != "true" {
		return
	}

	ctx := context.Background()
	svc := vikingdb.NewVikingDBService(
		"api-vikingdb.volces.com",
		"cn-beijing",
		os.Getenv("VIKING_DB_AK"),
		os.Getenv("VIKING_DB_SK"),
		"https",
	)

	// Server-side embedding: no builtin embedder, VikingDB vectorizes text.
	cfg := &ManagerConfig{
		Service:        svc,
		IndexingConfig: nil,
		EmbeddingConfig: &VikingEmbeddingConfig{
			UseVikingEmbedding: true,
			EnableHybrid:       false,
			ModelName:          ModelNameDoubaoEmbedding,
			ModelVersion:       ModelNameDoubaoEmbedding.ModelVersion(),
			DenseWeight:        nil,
			BuiltinEmbedding:   nil,
		},
	}

	mgr, err := NewManager(cfg)
	assert.NoError(t, err)

	collectionName := "test_coze_coll_1"

	t.Run("create", func(t *testing.T) {
		err = mgr.Create(ctx, &searchstore.CreateRequest{
			CollectionName: collectionName,
			Fields: []*searchstore.Field{
				{
					Name:      searchstore.FieldID,
					Type:      searchstore.FieldTypeInt64,
					IsPrimary: true,
				},
				{
					Name: searchstore.FieldCreatorID,
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name: "document_id",
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name:     searchstore.FieldTextContent,
					Type:     searchstore.FieldTypeText,
					Indexing: true,
				},
			},
			CollectionMeta: nil,
		})
		assert.NoError(t, err)
	})

	t.Run("store", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)

		// Three landmark documents; document_id doubles as the partition key
		// used in the retrieve subtest below.
		ids, err := ss.Store(ctx, []*schema.Document{
			{
				ID:      "101",
				Content: "埃菲尔铁塔:位于法国巴黎,是世界上最著名的地标之一,由居斯塔夫・埃菲尔设计并建于 1889 年。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "102",
				Content: "长城:位于中国,是世界七大奇迹之一,从秦至明代修筑而成,全长超过 2 万公里",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "103",
				Content: "罗马斗兽场:位于意大利罗马,于公元 70-80 年间建成,是古罗马帝国最大的圆形竞技场。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(568),
					},
				},
			},
		}, searchstore.WithIndexingFields([]string{searchstore.FieldTextContent}))
		assert.NoError(t, err)
		fmt.Println(ids)
	})

	t.Run("retrieve", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)

		// Filter by creator and restrict to partition document_id=567.
		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "creator_id",
			Value: int64(111),
		}
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey("document_id"),
			searchstore.WithPartitions([]string{"567"}),
			retriever.WithDSLInfo(dsl.DSL()),
		}
		resp, err := ss.Retrieve(ctx, "旅游景点推荐", opts...)
		assert.NoError(t, err)
		fmt.Println(resp)
	})

	t.Run("drop", func(t *testing.T) {
		assert.NoError(t, mgr.Drop(ctx, &searchstore.DropRequest{CollectionName: collectionName}))
	})
}
|
||||
|
||||
// TestBuiltinEmbeddingIntegration mirrors TestVikingEmbeddingIntegration but
// uses a client-side (OpenAI/Azure) embedder instead of VikingDB's built-in
// embedding. Gated behind ENABLE_VIKINGDB_INTEGRATION_TEST; additionally
// requires OPENAI_EMBEDDING_* variables alongside the VikingDB credentials.
func TestBuiltinEmbeddingIntegration(t *testing.T) {
	if os.Getenv("ENABLE_VIKINGDB_INTEGRATION_TEST") != "true" {
		return
	}

	ctx := context.Background()
	svc := vikingdb.NewVikingDBService(
		"api-vikingdb.volces.com",
		"cn-beijing",
		os.Getenv("VIKING_DB_AK"),
		os.Getenv("VIKING_DB_SK"),
		"https",
	)

	// Client-side embedder: Azure OpenAI, 1024-dimensional vectors.
	embConfig := &openai.EmbeddingConfig{
		APIKey:     os.Getenv("OPENAI_EMBEDDING_API_KEY"),
		ByAzure:    true,
		BaseURL:    os.Getenv("OPENAI_EMBEDDING_BASE_URL"),
		Model:      os.Getenv("OPENAI_EMBEDDING_MODEL"),
		Dimensions: ptr.Of(1024),
	}
	emb, err := wrap.NewOpenAIEmbedder(ctx, embConfig, 1024)
	assert.NoError(t, err)

	cfg := &ManagerConfig{
		Service:        svc,
		IndexingConfig: nil,
		EmbeddingConfig: &VikingEmbeddingConfig{
			UseVikingEmbedding: false,
			BuiltinEmbedding:   emb,
		},
	}

	mgr, err := NewManager(cfg)
	assert.NoError(t, err)

	collectionName := "test_coze_coll_2"

	t.Run("create", func(t *testing.T) {
		err = mgr.Create(ctx, &searchstore.CreateRequest{
			CollectionName: collectionName,
			Fields: []*searchstore.Field{
				{
					Name:      searchstore.FieldID,
					Type:      searchstore.FieldTypeInt64,
					IsPrimary: true,
				},
				{
					Name: searchstore.FieldCreatorID,
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name: "document_id",
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name:     searchstore.FieldTextContent,
					Type:     searchstore.FieldTypeText,
					Indexing: true,
				},
			},
			CollectionMeta: nil,
		})
		assert.NoError(t, err)
	})

	t.Run("store", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)

		// Same three landmark documents as the Viking-embedding test;
		// document_id doubles as the partition key used below.
		ids, err := ss.Store(ctx, []*schema.Document{
			{
				ID:      "101",
				Content: "埃菲尔铁塔:位于法国巴黎,是世界上最著名的地标之一,由居斯塔夫・埃菲尔设计并建于 1889 年。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "102",
				Content: "长城:位于中国,是世界七大奇迹之一,从秦至明代修筑而成,全长超过 2 万公里",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "103",
				Content: "罗马斗兽场:位于意大利罗马,于公元 70-80 年间建成,是古罗马帝国最大的圆形竞技场。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(568),
					},
				},
			},
		}, searchstore.WithIndexingFields([]string{searchstore.FieldTextContent}))
		assert.NoError(t, err)
		fmt.Println(ids)
	})

	t.Run("retrieve", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)

		// Filter by creator and restrict to partition document_id=567.
		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "creator_id",
			Value: int64(111),
		}
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey("document_id"),
			searchstore.WithPartitions([]string{"567"}),
			retriever.WithDSLInfo(dsl.DSL()),
		}
		resp, err := ss.Retrieve(ctx, "旅游景点推荐", opts...)
		assert.NoError(t, err)
		fmt.Println(resp)
	})

	t.Run("drop", func(t *testing.T) {
		assert.NoError(t, mgr.Drop(ctx, &searchstore.DropRequest{CollectionName: collectionName}))
	})

}
|
||||
Reference in New Issue
Block a user