feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,122 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"fmt"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
embcontract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// VikingEmbeddingModelName is the string name of an embedding model. Methods
// on the type map a name to its vector dimensions, its default model version
// and the kind of embeddings it supports.
type VikingEmbeddingModelName string

// Embedding model names recognised by the methods of VikingEmbeddingModelName.
const (
	ModelNameDoubaoEmbedding       VikingEmbeddingModelName = "doubao-embedding"
	ModelNameDoubaoEmbeddingLarge  VikingEmbeddingModelName = "doubao-embedding-large"
	ModelNameDoubaoEmbeddingVision VikingEmbeddingModelName = "doubao-embedding-vision"
	ModelNameBGELargeZH            VikingEmbeddingModelName = "bge-large-zh"
	ModelNameBGEM3                 VikingEmbeddingModelName = "bge-m3"
	ModelNameBGEVisualizedM3       VikingEmbeddingModelName = "bge-visualized-m3"
	// The combined "*-and-m3" model names below are intentionally left disabled.
	//ModelNameDoubaoEmbeddingAndM3 VikingEmbeddingModelName = "doubao-embedding-and-m3"
	//ModelNameDoubaoEmbeddingLargeAndM3 VikingEmbeddingModelName = "doubao-embedding-large-and-m3"
	//ModelNameBGELargeZHAndM3 VikingEmbeddingModelName = "bge-large-zh-and-m3"
)
// Dimensions returns the dense vector size associated with the model name.
// Names that are not one of the declared ModelName* constants yield 0.
func (v VikingEmbeddingModelName) Dimensions() int64 {
	var dims int64 // stays 0 for unrecognised names
	switch v {
	case ModelNameBGELargeZH, ModelNameBGEM3, ModelNameBGEVisualizedM3:
		dims = 1024
	case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingVision:
		dims = 2048
	case ModelNameDoubaoEmbeddingLarge:
		dims = 4096
	}
	return dims
}
// ModelVersion returns a pointer to the version string paired with the model
// name, or nil for models that have no version listed here.
func (v VikingEmbeddingModelName) ModelVersion() *string {
	var version string
	switch v {
	case ModelNameDoubaoEmbedding:
		version = "240515"
	case ModelNameDoubaoEmbeddingLarge:
		version = "240915"
	case ModelNameDoubaoEmbeddingVision:
		version = "250328"
	default:
		// No version is recorded for the remaining models.
		return nil
	}
	return ptr.Of(version)
}
// SupportStatus reports which embedding kinds the model can produce. Only
// ModelNameBGEM3 is reported as dense-and-sparse; every other name, including
// unrecognised ones, is reported as dense only.
func (v VikingEmbeddingModelName) SupportStatus() embcontract.SupportStatus {
	if v == ModelNameBGEM3 {
		return embcontract.SupportDenseAndSparse
	}
	return embcontract.SupportDense
}
// IndexType is the vector index type name; the values mirror the constants
// exported by the vikingdb SDK.
type IndexType string

// Supported index types.
const (
	IndexTypeHNSW       IndexType = vikingdb.HNSW
	IndexTypeHNSWHybrid IndexType = vikingdb.HNSW_HYBRID
	IndexTypeFlat       IndexType = vikingdb.FLAT
	IndexTypeIVF        IndexType = vikingdb.IVF
	IndexTypeDiskANN    IndexType = vikingdb.DiskANN
)
// IndexDistance is the distance metric name used by a vector index; the
// values mirror the constants exported by the vikingdb SDK.
type IndexDistance string

// Supported distance metrics.
const (
	IndexDistanceIP     IndexDistance = vikingdb.IP
	IndexDistanceL2     IndexDistance = vikingdb.L2
	IndexDistanceCosine IndexDistance = vikingdb.COSINE
)
// IndexQuant is the quantization mode name used by a vector index; the values
// mirror the constants exported by the vikingdb SDK.
type IndexQuant string

// Supported quantization modes.
const (
	IndexQuantInt8  IndexQuant = vikingdb.Int8
	IndexQuantFloat IndexQuant = vikingdb.Float
	IndexQuantFix16 IndexQuant = vikingdb.Fix16
	IndexQuantPQ    IndexQuant = vikingdb.PQ
)
const (
	// NOTE(review): the four vikingEmbedding* keys look like request/response
	// field names of the Viking embedding API; they are not referenced in the
	// code visible here — confirm usage before relying on them.
	vikingEmbeddingUseDense           = "return_dense"
	vikingEmbeddingUseSparse          = "return_sparse"
	vikingEmbeddingRespSentenceDense  = "sentence_dense_embedding"
	vikingEmbeddingRespSentenceSparse = "sentence_sparse_embedding"
	// vikingIndexName is the fixed index name used for every collection when
	// creating, looking up and dropping the index.
	vikingIndexName = "opencoze_index"
)
// Substrings searched for in SDK error messages (via strings.Contains) to
// recognise a missing collection or index, which callers treat as non-fatal.
const (
	errCollectionNotFound = "collection not found"
	errIndexNotFound      = "index not found"
)
// denseFieldName derives the name of the dense-vector field that accompanies
// the text field called name, by prefixing it with "dense_".
func denseFieldName(name string) string {
	const format = "dense_%s"
	return fmt.Sprintf(format, name)
}

View File

@@ -0,0 +1,331 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"fmt"
"strings"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// ManagerConfig is the input to NewManager.
type ManagerConfig struct {
	// Service is the VikingDB SDK client used for all collection and index
	// calls. NewManager rejects a nil Service.
	Service *vikingdb.VikingDBService
	// IndexingConfig tunes index creation. It may be nil; NewManager then
	// installs an empty config and fills in the defaults.
	IndexingConfig *VikingIndexingConfig
	// EmbeddingConfig selects how vectors are produced. NewManager rejects a
	// nil EmbeddingConfig.
	EmbeddingConfig *VikingEmbeddingConfig
	// TODO: cache viking collection & index client
}
// VikingIndexingConfig holds the parameters passed to index creation. Zero or
// nil fields are replaced by NewManager with the defaults noted on each field.
type VikingIndexingConfig struct {
	// vector index config
	Type     IndexType      // default: hnsw / hnsw_hybrid
	Distance *IndexDistance // default: ip
	Quant    *IndexQuant    // default: int8
	HnswM    *int64         // default: 20
	HnswCef  *int64         // default: 400
	HnswSef  *int64         // default: 800
	// others
	CpuQuota   int64 // default: 2
	ShardCount int64 // default: 1
}
// VikingEmbeddingConfig selects between VikingDB-side vectorization and a
// caller-supplied ("builtin") embedder.
type VikingEmbeddingConfig struct {
	// UseVikingEmbedding chooses VikingDB-side vectorization when true;
	// otherwise BuiltinEmbedding is used and must be non-nil.
	UseVikingEmbedding bool
	// EnableHybrid requests dense+sparse vectorization. NewManager only
	// accepts it together with UseVikingEmbedding and a model whose
	// SupportStatus is dense-and-sparse.
	EnableHybrid bool
	// viking embedding config
	// ModelName must be non-empty when UseVikingEmbedding is true.
	ModelName VikingEmbeddingModelName
	// ModelVersion, when non-nil, is forwarded to the vectorize model config.
	ModelVersion *string
	// NOTE(review): DenseWeight is not read by the code visible here —
	// confirm where it is consumed.
	DenseWeight *float64
	// builtin embedding config
	BuiltinEmbedding embedding.Embedder
}
// NewManager validates config, fills in indexing defaults and returns a
// searchstore.Manager backed by VikingDB.
//
// It returns an error when:
//   - config, config.Service or config.EmbeddingConfig is nil;
//   - the builtin embedder is selected but not provided, or is combined with
//     EnableHybrid;
//   - Viking embedding is selected without a model name, or with EnableHybrid
//     on a model that does not support dense-and-sparse embedding.
//
// The passed config is retained by the manager and mutated in place: a nil
// IndexingConfig is replaced and unset indexing fields receive defaults.
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
	// Guard against a nil config, which would otherwise panic on the first
	// field access below.
	if config == nil {
		return nil, fmt.Errorf("[NewManager] vikingdb config is nil")
	}
	if config.Service == nil {
		return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
	}

	emb := config.EmbeddingConfig
	if emb == nil {
		return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
	}
	if emb.UseVikingEmbedding {
		if emb.ModelName == "" {
			return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
		}
		if emb.EnableHybrid && emb.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
			return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", emb.ModelName)
		}
	} else {
		if emb.BuiltinEmbedding == nil {
			return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
		}
		if emb.EnableHybrid {
			return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
		}
	}

	if config.IndexingConfig == nil {
		config.IndexingConfig = &VikingIndexingConfig{}
	}
	idx := config.IndexingConfig
	if idx.Type == "" {
		// The hybrid index type is only chosen for Viking embedding with
		// hybrid enabled; everything else gets plain HNSW.
		if emb.UseVikingEmbedding && emb.EnableHybrid {
			idx.Type = IndexTypeHNSWHybrid
		} else {
			idx.Type = IndexTypeHNSW
		}
	}
	if idx.Distance == nil {
		idx.Distance = ptr.Of(IndexDistanceIP)
	}
	if idx.Quant == nil {
		idx.Quant = ptr.Of(IndexQuantInt8)
	}
	if idx.HnswM == nil {
		idx.HnswM = ptr.Of(int64(20))
	}
	if idx.HnswCef == nil {
		idx.HnswCef = ptr.Of(int64(400))
	}
	if idx.HnswSef == nil {
		idx.HnswSef = ptr.Of(int64(800))
	}
	if idx.CpuQuota == 0 {
		idx.CpuQuota = 2
	}
	if idx.ShardCount == 0 {
		idx.ShardCount = 1
	}

	return &manager{config: config}, nil
}
// manager is the searchstore.Manager implementation returned by NewManager.
type manager struct {
	// config is the validated configuration handed to NewManager.
	config *ManagerConfig
}
// Create ensures the collection named in req exists and then ensures its
// index exists. The index step is skipped when the collection step fails.
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
	err := m.createCollection(ctx, req)
	if err == nil {
		err = m.createIndex(ctx, req)
	}
	return err
}
// Drop removes the index and then the collection named in req. An error whose
// message indicates the index or collection does not exist is ignored, so
// dropping something already gone succeeds.
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
	svc := m.config.Service

	// The index goes first; the collection is dropped afterwards.
	if err := svc.DropIndex(req.CollectionName, vikingIndexName); err != nil && !strings.Contains(err.Error(), errIndexNotFound) {
		return err
	}
	if err := svc.DropCollection(req.CollectionName); err != nil && !strings.Contains(err.Error(), errCollectionNotFound) {
		return err
	}
	return nil
}
// GetType reports that this manager provides a vector search store.
func (m *manager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
// GetSearchStore looks up the collection by name and wraps it in a search
// store bound to this manager. The lookup error is returned unchanged.
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
	coll, err := m.config.Service.GetCollection(collectionName)
	if err != nil {
		return nil, err
	}
	store := &vkSearchStore{
		manager:    m,
		collection: coll,
	}
	return store, nil
}
// createCollection creates the collection described by req unless it already
// exists. A lookup error is only tolerated when its message says the
// collection was not found.
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
	service := m.config.Service

	existing, err := service.GetCollection(req.CollectionName)
	switch {
	case err == nil && existing != nil:
		// Already there, nothing to create.
		return nil
	case err != nil && !strings.Contains(err.Error(), errCollectionNotFound):
		return err
	}

	fields, vectorize, err := m.mapFields(req.Fields)
	if err != nil {
		return err
	}

	// The vectorize argument is only passed when mapFields produced one.
	if vectorize == nil {
		_, err = service.CreateCollection(req.CollectionName, fields, "")
	} else {
		_, err = service.CreateCollection(req.CollectionName, fields, "", vectorize)
	}
	if err != nil {
		return err
	}

	logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
	return nil
}
// createIndex creates the fixed-name index on the collection from req unless
// it already exists, using the parameters in the manager's indexing config. A
// lookup error is only tolerated when its message says the index was not
// found.
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
	service := m.config.Service

	existing, err := service.GetIndex(req.CollectionName, vikingIndexName)
	switch {
	case err == nil && existing != nil:
		// Already there, nothing to create.
		return nil
	case err != nil && !strings.Contains(err.Error(), errIndexNotFound):
		return err
	}

	cfg := m.config.IndexingConfig
	params := &vikingdb.VectorIndexParams{
		IndexType: string(cfg.Type),
		Distance:  string(ptr.From(cfg.Distance)),
		Quant:     string(ptr.From(cfg.Quant)),
		HnswM:     ptr.From(cfg.HnswM),
		HnswCef:   ptr.From(cfg.HnswCef),
		HnswSef:   ptr.From(cfg.HnswSef),
	}
	options := vikingdb.NewIndexOptions().
		SetVectorIndex(params).
		SetCpuQuota(cfg.CpuQuota).
		SetShardCount(cfg.ShardCount)

	if _, err = service.CreateIndex(req.CollectionName, vikingIndexName, options); err != nil {
		return err
	}

	logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
	return nil
}
// mapFields converts search-store field definitions into VikingDB collection
// fields plus, when Viking embedding is in use, vectorize tuples.
//
// For every field flagged for indexing (text fields only):
//   - with Viking embedding, a vectorize tuple is produced (dense, plus sparse
//     when hybrid is enabled);
//   - with the builtin embedder, an extra "dense_<name>" vector field is
//     emitted ahead of the field itself.
//
// The ID (primary key) and creator-ID fields are appended as int64 fields if
// the input did not contain them. An error is returned for a non-text
// indexing field or an unknown field type.
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
	emb := m.config.EmbeddingConfig
	out := make([]vikingdb.Field, 0, len(srcFields))
	var tuples []*vikingdb.VectorizeTuple
	hasID, hasCreatorID := false, false

	for _, f := range srcFields {
		if f.Name == searchstore.FieldID {
			hasID = true
		} else if f.Name == searchstore.FieldCreatorID {
			hasCreatorID = true
		}

		if f.Indexing {
			if f.Type != searchstore.FieldTypeText {
				return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", f.Name)
			}
			if !emb.UseVikingEmbedding {
				// Builtin embedder: the vectors are stored by us in a
				// companion dense field.
				out = append(out, vikingdb.Field{
					FieldName:  denseFieldName(f.Name),
					FieldType:  vikingdb.Vector,
					DefaultVal: nil,
					Dim:        m.getDims(),
				})
			} else {
				tuple := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(f.Name))
				if emb.EnableHybrid {
					tuple = tuple.SetSparse(m.newVectorizeModelConf(f.Name))
				}
				tuples = append(tuples, tuple)
			}
		}

		mapped := vikingdb.Field{
			FieldName:    f.Name,
			IsPrimaryKey: f.IsPrimary,
		}
		switch f.Type {
		case searchstore.FieldTypeInt64:
			mapped.FieldType = vikingdb.Int64
		case searchstore.FieldTypeText:
			mapped.FieldType = vikingdb.Text
		case searchstore.FieldTypeDenseVector:
			mapped.FieldType = vikingdb.Vector
			mapped.Dim = m.getDims()
		case searchstore.FieldTypeSparseVector:
			mapped.FieldType = vikingdb.Sparse_Vector
		default:
			return nil, nil, fmt.Errorf("unknown field type: %v", f.Type)
		}
		out = append(out, mapped)
	}

	// Make sure the two mandatory fields are always part of the schema.
	if !hasID {
		out = append(out, vikingdb.Field{
			FieldName:    searchstore.FieldID,
			FieldType:    vikingdb.Int64,
			IsPrimaryKey: true,
		})
	}
	if !hasCreatorID {
		out = append(out, vikingdb.Field{
			FieldName: searchstore.FieldCreatorID,
			FieldType: vikingdb.Int64,
		})
	}

	return out, tuples, nil
}
// newVectorizeModelConf builds the vectorize model config for the text field
// fieldName from the manager's embedding config: model name, dimensions and,
// when configured, the model version.
func (m *manager) newVectorizeModelConf(fieldName string) *vikingdb.VectorizeModelConf {
	emb := m.config.EmbeddingConfig
	conf := vikingdb.NewVectorizeModelConf().
		SetTextField(fieldName).
		SetModelName(string(emb.ModelName)).
		SetDim(m.getDims())
	if version := emb.ModelVersion; version != nil {
		conf = conf.SetModelVersion(ptr.From(version))
	}
	return conf
}
// getDims returns the vector dimension count: taken from the builtin embedder
// when Viking embedding is off, otherwise from the configured model name.
func (m *manager) getDims() int64 {
	emb := m.config.EmbeddingConfig
	if !emb.UseVikingEmbedding {
		return emb.BuiltinEmbedding.Dimensions()
	}
	return emb.ModelName.Dimensions()
}

View File

@@ -0,0 +1,388 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// vkSearchStore is a search store bound to one VikingDB collection. It embeds
// the manager to reach its config and service client.
type vkSearchStore struct {
	*manager
	// collection is the SDK handle used for upserts and deletes.
	collection *vikingdb.Collection
	// index is an optional index handle. When nil, Retrieve fetches the index
	// through the service on each call.
	index *vikingdb.Index
}
// Store upserts docs into the collection in batches of 100 and returns the
// document IDs. When the builtin embedder is configured, vectors for the
// indexing fields named in opts are computed before each upsert. Any error is
// also reported to the progress bar from opts, if one was supplied.
func (v *vkSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}

	specific := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	defer func() {
		// err is the named result, so this sees whatever the function returns.
		if err == nil || specific.ProgressBar == nil {
			return
		}
		_ = specific.ProgressBar.ReportError(err)
	}()

	rows, err := slices.TransformWithErrorCheck(docs, v.document2DataWithoutVector)
	if err != nil {
		return nil, fmt.Errorf("[Store] vikingdb failed to transform documents, %w", err)
	}

	fieldSet := sets.FromSlice(specific.IndexingFields)
	for _, batch := range slices.Chunks(rows, 100) {
		withVectors, embErr := v.addEmbedding(ctx, batch, fieldSet)
		if embErr != nil {
			return nil, embErr
		}
		if upsertErr := v.collection.UpsertData(withVectors); upsertErr != nil {
			return nil, upsertErr
		}
	}

	ids = slices.Transform(docs, func(d *schema.Document) string { return d.ID })
	return ids, nil
}
// Retrieve searches the collection's index for query and returns the matching
// documents. TopK defaults to 4. The DSL and partition options from opts are
// turned into a filter. With Viking embedding the query text is searched
// directly; otherwise it is embedded with the builtin embedder and searched by
// vector.
func (v *vkSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) {
	idx := v.index
	if idx == nil {
		// No cached handle: check the collection lists the index, then fetch it.
		exists := false
		for _, candidate := range v.collection.Indexes {
			if candidate.IndexName == vikingIndexName {
				exists = true
				break
			}
		}
		if !exists {
			return nil, fmt.Errorf("[Retrieve] vikingdb index not found, name=%s", vikingIndexName)
		}
		if idx, err = v.config.Service.GetIndex(v.collection.CollectionName, vikingIndexName); err != nil {
			return nil, fmt.Errorf("[Retrieve] vikingdb failed to get index, %w", err)
		}
	}

	common := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(4)}, opts...)
	specific := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)

	searchOpts := vikingdb.NewSearchOptions().
		SetLimit(int64(ptr.From(common.TopK))).
		SetText(query).
		SetRetry(true)

	filter, err := v.genFilter(ctx, common, specific)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb failed to build filter, %w", err)
	}
	if filter != nil {
		// Cross-partition recall is not supported; a filter is used instead.
		searchOpts = searchOpts.SetFilter(filter)
	}

	var hits []*vikingdb.Data
	if !v.config.EmbeddingConfig.UseVikingEmbedding {
		vectors, embErr := v.config.EmbeddingConfig.BuiltinEmbedding.EmbedStrings(ctx, []string{query})
		if embErr != nil {
			return nil, fmt.Errorf("[Retrieve] embed failed, %w", embErr)
		}
		if len(vectors) != 1 {
			return nil, fmt.Errorf("[Retrieve] unexpected dense vector size, expected=1, got=%d", len(vectors))
		}
		hits, err = idx.SearchByVector(vectors[0], searchOpts)
	} else {
		hits, err = idx.SearchWithMultiModal(searchOpts)
	}
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb search failed, %w", err)
	}

	docs, err = v.parseSearchResult(hits)
	if err != nil {
		return nil, err
	}
	return docs, nil
}
// Delete removes the rows with the given ids from the collection, 100 ids per
// request, stopping at the first failure.
func (v *vkSearchStore) Delete(ctx context.Context, ids []string) error {
	for _, batch := range slices.Chunks(ids, 100) {
		err := v.collection.DeleteData(batch)
		if err != nil {
			return err
		}
	}
	return nil
}
// document2DataWithoutVector converts doc into a VikingDB row holding the
// numeric document ID, the creator ID and the text content, plus every
// external-storage entry attached to the document. It fails when the creator
// ID cannot be read or doc.ID is not a base-10 int64.
func (v *vkSearchStore) document2DataWithoutVector(doc *schema.Document) (data vikingdb.Data, err error) {
	creatorID, err := document.GetDocumentCreatorID(doc)
	if err != nil {
		return data, err
	}
	numericID, err := strconv.ParseInt(doc.ID, 10, 64)
	if err != nil {
		return data, err
	}

	fields := map[string]any{
		searchstore.FieldID:          numericID,
		searchstore.FieldCreatorID:   creatorID,
		searchstore.FieldTextContent: doc.Content,
	}

	// External storage is optional: a load error just means there is nothing
	// extra to copy.
	ext, extErr := document.GetDocumentExternalStorage(doc)
	if extErr == nil {
		for k, val := range ext {
			fields[k] = val
		}
	}

	data = vikingdb.Data{
		Id:     doc.ID,
		Fields: fields,
	}
	return data, nil
}
// addEmbedding attaches dense vectors to rows for each field in
// indexingFields, using the builtin embedder, and stores them under the
// "dense_<field>" key of every row. Rows are modified in place and returned.
// With Viking embedding enabled the rows are returned untouched. It fails if
// a row lacks an indexing field, the field value is not a string, embedding
// fails, or the embedder returns a different number of vectors than inputs.
func (v *vkSearchStore) addEmbedding(ctx context.Context, rows []vikingdb.Data, indexingFields map[string]struct{}) ([]vikingdb.Data, error) {
	embCfg := v.config.EmbeddingConfig
	if embCfg.UseVikingEmbedding {
		return rows, nil
	}

	embedder := embCfg.BuiltinEmbedding
	for field := range indexingFields {
		// Collect this field's text from every row, in row order.
		texts := make([]string, 0, len(rows))
		for i := range rows {
			raw, present := rows[i].Fields[field]
			if !present {
				return nil, fmt.Errorf("[addEmbedding] indexing field not found in document, field=%s", field)
			}
			text, isString := raw.(string)
			if !isString {
				return nil, fmt.Errorf("[addEmbedding] val not string, field=%s, val=%v", field, raw)
			}
			texts = append(texts, text)
		}

		vectors, err := embedder.EmbedStrings(ctx, texts)
		if err != nil {
			return nil, fmt.Errorf("[addEmbedding] failed to embed, %w", err)
		}
		if len(vectors) != len(texts) {
			return nil, fmt.Errorf("[addEmbedding] unexpected dense vector size, expected=%d, got=%d", len(texts), len(vectors))
		}

		target := denseFieldName(field)
		for i, vec := range vectors {
			rows[i].Fields[target] = vec
		}
	}
	return rows, nil
}
// parseSearchResult converts search hits into documents. The ID, creator ID
// and text content fields populate the document itself; every other field is
// placed in the document's external storage, with json.Number values decoded
// as int64 when possible, else float64, else kept as their string form. A
// type mismatch on one of the three known fields is an error.
func (v *vkSearchStore) parseSearchResult(result []*vikingdb.Data) ([]*schema.Document, error) {
	parsed := make([]*schema.Document, 0, len(result))
	for _, hit := range result {
		extra := make(map[string]any)
		doc := document.WithDocumentExternalStorage(&schema.Document{MetaData: map[string]any{}}, extra).
			WithScore(hit.Score)

		for name, raw := range hit.Fields {
			switch name {
			case searchstore.FieldID:
				num, ok := raw.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] id type assertion failed, val=%v", raw)
				}
				doc.ID = num.String()

			case searchstore.FieldCreatorID:
				num, ok := raw.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] creator_id type assertion failed, val=%v", raw)
				}
				creatorID, err := num.Int64()
				if err != nil {
					return nil, fmt.Errorf("[parseSearchResult] creator_id value not int64, val=%v", num.String())
				}
				doc = document.WithDocumentCreatorID(doc, creatorID)

			case searchstore.FieldTextContent:
				content, ok := raw.(string)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] content value not string, val=%v", raw)
				}
				doc.Content = content

			default:
				num, isNumber := raw.(json.Number)
				if !isNumber {
					extra[name] = raw
					continue
				}
				// Prefer the narrowest numeric representation that parses.
				if asInt, err := num.Int64(); err == nil {
					extra[name] = asInt
				} else if asFloat, err := num.Float64(); err == nil {
					extra[name] = asFloat
				} else {
					extra[name] = num.String()
				}
			}
		}
		parsed = append(parsed, doc)
	}
	return parsed, nil
}
// genFilter builds the search filter from the DSL in co and the partition
// options in ro.
//
// The DSL is converted first. If a partition key and at least one partition
// value are given, a "must" condition on that key is built: values are parsed
// as int64 for int64 fields and used as-is for string fields. The result is:
//   - the DSL filter alone when no partition condition applies;
//   - the partition condition alone when there is no DSL filter;
//   - an "and" of both otherwise.
//
// It returns an error when the DSL is invalid, the partition key is not a
// field of the collection, the field type is neither int64 nor string, or a
// partition value cannot be parsed as int64.
func (v *vkSearchStore) genFilter(ctx context.Context, co *retriever.Options, ro *searchstore.RetrieverOptions) (map[string]any, error) {
	filter, err := v.dsl2Filter(ctx, co.DSLInfo)
	if err != nil {
		return nil, err
	}
	if ro.PartitionKey == nil || len(ro.Partitions) == 0 {
		return filter, nil
	}

	key := ptr.From(ro.PartitionKey)
	fieldType := ""
	for _, field := range v.collection.Fields {
		if field.FieldName == key {
			fieldType = field.FieldType
		}
	}
	if fieldType == "" {
		return nil, fmt.Errorf("[Retrieve] partition key not found, key=%s", key)
	}

	var conds any
	switch fieldType {
	case vikingdb.Int64:
		values := make([]int64, 0, len(ro.Partitions))
		for _, item := range ro.Partitions {
			i64, err := strconv.ParseInt(item, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("[Retrieve] partition value parse error, key=%s, val=%v, err=%v", key, item, err)
			}
			values = append(values, i64)
		}
		conds = values
	case vikingdb.String:
		conds = ro.Partitions
	default:
		return nil, fmt.Errorf("[Retrieve] invalid field type for partition, key=%s, type=%s", key, fieldType)
	}

	partition := map[string]any{"op": "must", "field": key, "conds": conds}
	if filter == nil {
		// No DSL filter: the partition condition is the whole filter.
		return partition, nil
	}
	// Both present: every hit must satisfy the partition condition and the DSL filter.
	return map[string]any{
		"op":    "and",
		"conds": []map[string]any{partition, filter},
	}, nil
}
// dsl2Filter converts a search-store DSL map into a VikingDB filter map.
//
// Mapping:
//   - eq / in  -> {"op": "must", ...}
//   - ne       -> {"op": "must_not", ...}
//   - like     -> not supported; a warning is logged and no filter is produced
//   - and / or -> {"op": "and"|"or", "conds": [...]} built recursively
//
// A nil result with a nil error means "no filter". Sub-expressions of and/or
// that produce no filter are left out of the conds list, and an and/or whose
// sub-expressions all produce no filter yields no filter itself. An error is
// returned when the DSL cannot be loaded, a comparison has a nil value, or the
// value of and/or is not a []map[string]any.
func (v *vkSearchStore) dsl2Filter(ctx context.Context, src map[string]any) (map[string]any, error) {
	dsl, err := searchstore.LoadDSL(src)
	if err != nil {
		return nil, err
	}
	if dsl == nil {
		return nil, nil
	}

	// toSliceValue wraps a scalar into a one-element slice and passes slices
	// through. A nil value is rejected up front because reflect.TypeOf(nil)
	// returns a nil Type and calling Kind on it would panic.
	toSliceValue := func(val any) (any, error) {
		if val == nil {
			return nil, fmt.Errorf("[dsl2Filter] nil value, op=%v, field=%v", dsl.Op, dsl.Field)
		}
		if reflect.TypeOf(val).Kind() == reflect.Slice {
			return val, nil
		}
		return []any{val}, nil
	}

	var filter map[string]any
	switch dsl.Op {
	case searchstore.OpEq, searchstore.OpIn:
		conds, err := toSliceValue(dsl.Value)
		if err != nil {
			return nil, err
		}
		filter = map[string]any{
			"op":    "must",
			"field": dsl.Field,
			"conds": conds,
		}
	case searchstore.OpNe:
		conds, err := toSliceValue(dsl.Value)
		if err != nil {
			return nil, err
		}
		filter = map[string]any{
			"op":    "must_not",
			"field": dsl.Field,
			"conds": conds,
		}
	case searchstore.OpLike:
		logs.CtxWarnf(ctx, "[dsl2Filter] vikingdb invalid dsl type, skip, type=%s", dsl.Op)
	case searchstore.OpAnd, searchstore.OpOr:
		sub, ok := dsl.Value.([]map[string]any)
		if !ok {
			return nil, fmt.Errorf("[dsl2Filter] invalid value for and/or, should be []map[string]any")
		}
		conds := make([]map[string]any, 0, len(sub))
		for _, subDSL := range sub {
			cond, err := v.dsl2Filter(ctx, subDSL)
			if err != nil {
				return nil, err
			}
			if cond == nil {
				// Skipped sub-expression (e.g. like): do not emit a nil condition.
				continue
			}
			conds = append(conds, cond)
		}
		if len(conds) == 0 {
			return nil, nil
		}
		op := "and"
		if dsl.Op == searchstore.OpOr {
			op = "or"
		}
		filter = map[string]any{
			"op":    op,
			"field": dsl.Field,
			"conds": conds,
		}
	}
	return filter, nil
}

View File

@@ -0,0 +1,290 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"fmt"
"os"
"testing"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// TestVikingEmbeddingIntegration exercises create / store / retrieve / drop
// against a real VikingDB service using VikingDB-side embedding. It is skipped
// unless ENABLE_VIKINGDB_INTEGRATION_TEST=true; credentials are read from
// VIKING_DB_AK and VIKING_DB_SK.
func TestVikingEmbeddingIntegration(t *testing.T) {
	if os.Getenv("ENABLE_VIKINGDB_INTEGRATION_TEST") != "true" {
		// Report the test as skipped rather than silently passing.
		t.Skip("set ENABLE_VIKINGDB_INTEGRATION_TEST=true to run vikingdb integration tests")
	}

	ctx := context.Background()
	svc := vikingdb.NewVikingDBService(
		"api-vikingdb.volces.com",
		"cn-beijing",
		os.Getenv("VIKING_DB_AK"),
		os.Getenv("VIKING_DB_SK"),
		"https",
	)

	cfg := &ManagerConfig{
		Service:        svc,
		IndexingConfig: nil,
		EmbeddingConfig: &VikingEmbeddingConfig{
			UseVikingEmbedding: true,
			EnableHybrid:       false,
			ModelName:          ModelNameDoubaoEmbedding,
			ModelVersion:       ModelNameDoubaoEmbedding.ModelVersion(),
			DenseWeight:        nil,
			BuiltinEmbedding:   nil,
		},
	}

	mgr, err := NewManager(cfg)
	if !assert.NoError(t, err) {
		// Without a manager every subtest below would dereference nil.
		return
	}

	collectionName := "test_coze_coll_1"

	t.Run("create", func(t *testing.T) {
		err := mgr.Create(ctx, &searchstore.CreateRequest{
			CollectionName: collectionName,
			Fields: []*searchstore.Field{
				{
					Name:      searchstore.FieldID,
					Type:      searchstore.FieldTypeInt64,
					IsPrimary: true,
				},
				{
					Name: searchstore.FieldCreatorID,
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name: "document_id",
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name:     searchstore.FieldTextContent,
					Type:     searchstore.FieldTypeText,
					Indexing: true,
				},
			},
			CollectionMeta: nil,
		})
		assert.NoError(t, err)
	})

	t.Run("store", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		if !assert.NoError(t, err) {
			return
		}

		ids, err := ss.Store(ctx, []*schema.Document{
			{
				ID:      "101",
				Content: "埃菲尔铁塔:位于法国巴黎,是世界上最著名的地标之一,由居斯塔夫・埃菲尔设计并建于 1889 年。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "102",
				Content: "长城:位于中国,是世界七大奇迹之一,从秦至明代修筑而成,全长超过 2 万公里",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "103",
				Content: "罗马斗兽场:位于意大利罗马,于公元 70-80 年间建成,是古罗马帝国最大的圆形竞技场。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(568),
					},
				},
			},
		}, searchstore.WithIndexingFields([]string{searchstore.FieldTextContent}))
		assert.NoError(t, err)

		fmt.Println(ids)
	})

	t.Run("retrieve", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		if !assert.NoError(t, err) {
			return
		}

		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "creator_id",
			Value: int64(111),
		}
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey("document_id"),
			searchstore.WithPartitions([]string{"567"}),
			retriever.WithDSLInfo(dsl.DSL()),
		}

		resp, err := ss.Retrieve(ctx, "旅游景点推荐", opts...)
		assert.NoError(t, err)

		fmt.Println(resp)
	})

	t.Run("drop", func(t *testing.T) {
		assert.NoError(t, mgr.Drop(ctx, &searchstore.DropRequest{CollectionName: collectionName}))
	})
}
// TestBuiltinEmbeddingIntegration exercises create / store / retrieve / drop
// against a real VikingDB service using a caller-side OpenAI embedder. It is
// skipped unless ENABLE_VIKINGDB_INTEGRATION_TEST=true; credentials and the
// embedding endpoint are read from VIKING_DB_AK, VIKING_DB_SK and the
// OPENAI_EMBEDDING_* environment variables.
func TestBuiltinEmbeddingIntegration(t *testing.T) {
	if os.Getenv("ENABLE_VIKINGDB_INTEGRATION_TEST") != "true" {
		// Report the test as skipped rather than silently passing.
		t.Skip("set ENABLE_VIKINGDB_INTEGRATION_TEST=true to run vikingdb integration tests")
	}

	ctx := context.Background()
	svc := vikingdb.NewVikingDBService(
		"api-vikingdb.volces.com",
		"cn-beijing",
		os.Getenv("VIKING_DB_AK"),
		os.Getenv("VIKING_DB_SK"),
		"https",
	)

	embConfig := &openai.EmbeddingConfig{
		APIKey:     os.Getenv("OPENAI_EMBEDDING_API_KEY"),
		ByAzure:    true,
		BaseURL:    os.Getenv("OPENAI_EMBEDDING_BASE_URL"),
		Model:      os.Getenv("OPENAI_EMBEDDING_MODEL"),
		Dimensions: ptr.Of(1024),
	}
	emb, err := wrap.NewOpenAIEmbedder(ctx, embConfig, 1024)
	if !assert.NoError(t, err) {
		// A missing embedder would make NewManager fail anyway.
		return
	}

	cfg := &ManagerConfig{
		Service:        svc,
		IndexingConfig: nil,
		EmbeddingConfig: &VikingEmbeddingConfig{
			UseVikingEmbedding: false,
			BuiltinEmbedding:   emb,
		},
	}

	mgr, err := NewManager(cfg)
	if !assert.NoError(t, err) {
		// Without a manager every subtest below would dereference nil.
		return
	}

	collectionName := "test_coze_coll_2"

	t.Run("create", func(t *testing.T) {
		err := mgr.Create(ctx, &searchstore.CreateRequest{
			CollectionName: collectionName,
			Fields: []*searchstore.Field{
				{
					Name:      searchstore.FieldID,
					Type:      searchstore.FieldTypeInt64,
					IsPrimary: true,
				},
				{
					Name: searchstore.FieldCreatorID,
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name: "document_id",
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name:     searchstore.FieldTextContent,
					Type:     searchstore.FieldTypeText,
					Indexing: true,
				},
			},
			CollectionMeta: nil,
		})
		assert.NoError(t, err)
	})

	t.Run("store", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		if !assert.NoError(t, err) {
			return
		}

		ids, err := ss.Store(ctx, []*schema.Document{
			{
				ID:      "101",
				Content: "埃菲尔铁塔:位于法国巴黎,是世界上最著名的地标之一,由居斯塔夫・埃菲尔设计并建于 1889 年。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "102",
				Content: "长城:位于中国,是世界七大奇迹之一,从秦至明代修筑而成,全长超过 2 万公里",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "103",
				Content: "罗马斗兽场:位于意大利罗马,于公元 70-80 年间建成,是古罗马帝国最大的圆形竞技场。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(568),
					},
				},
			},
		}, searchstore.WithIndexingFields([]string{searchstore.FieldTextContent}))
		assert.NoError(t, err)

		fmt.Println(ids)
	})

	t.Run("retrieve", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		if !assert.NoError(t, err) {
			return
		}

		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "creator_id",
			Value: int64(111),
		}
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey("document_id"),
			searchstore.WithPartitions([]string{"567"}),
			retriever.WithDSLInfo(dsl.DSL()),
		}

		resp, err := ss.Retrieve(ctx, "旅游景点推荐", opts...)
		assert.NoError(t, err)

		fmt.Println(resp)
	})

	t.Run("drop", func(t *testing.T) {
		assert.NoError(t, mgr.Drop(ctx, &searchstore.DropRequest{CollectionName: collectionName}))
	})
}