From ff00dcb31b1d4fe33477eb21420134011cebc2d2 Mon Sep 17 00:00:00 2001 From: Ryo Date: Fri, 15 Aug 2025 10:46:09 +0800 Subject: [PATCH] refactor(knowledge): Move the searchstore manager to app infra (#764) --- backend/application/application.go | 20 +- .../application/base/appinfra/app_infra.go | 288 +++++++++++++++++- backend/application/knowledge/init.go | 285 +---------------- .../searchstore/vikingdb/vikingdb_manager.go | 4 + backend/infra/impl/embedding/ark/ark.go | 9 + docker/.env.example | 1 - 6 files changed, 312 insertions(+), 295 deletions(-) diff --git a/backend/application/application.go b/backend/application/application.go index 74f6a697..d534e0d7 100644 --- a/backend/application/application.go +++ b/backend/application/application.go @@ -128,6 +128,7 @@ func Init(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("Init - initVitalServices failed, err: %v", err) } + crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC)) crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC)) crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC)) @@ -254,16 +255,15 @@ func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents { func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents { return &knowledge.ServiceComponents{ - DB: b.infra.DB, - IDGenSVC: b.infra.IDGenSVC, - Storage: b.infra.TOSClient, - RDB: memoryService.RDBDomainSVC, - ImageX: b.infra.ImageXClient, - ES: b.infra.ESClient, - EventBus: b.eventbus.resourceEventBus, - CacheCli: b.infra.CacheCli, - OCR: b.infra.OCR, - ParserManager: b.infra.ParserManager, + DB: b.infra.DB, + IDGenSVC: b.infra.IDGenSVC, + Storage: b.infra.TOSClient, + RDB: memoryService.RDBDomainSVC, + SearchStoreManagers: b.infra.SearchStoreManagers, + EventBus: b.eventbus.resourceEventBus, + CacheCli: b.infra.CacheCli, + OCR: b.infra.OCR, + ParserManager: b.infra.ParserManager, } } diff --git a/backend/application/base/appinfra/app_infra.go b/backend/application/base/appinfra/app_infra.go index 0d78420f..4cb3f073 100644 --- a/backend/application/base/appinfra/app_infra.go +++ b/backend/application/base/appinfra/app_infra.go @@ -23,9 +23,13 @@ import ( "os" "strconv" "strings" + "time" "gorm.io/gorm" + "github.com/cloudwego/eino-ext/components/embedding/ollama" + "github.com/cloudwego/eino-ext/components/embedding/openai" + "github.com/milvus-io/milvus/client/v2/milvusclient" "github.com/volcengine/volc-sdk-golang/service/visual" "github.com/coze-dev/coze-studio/backend/application/internal" @@ -34,6 +38,8 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" "github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" + "github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore" + "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" "github.com/coze-dev/coze-studio/backend/infra/contract/imagex" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" @@ -41,14 +47,22 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox" "github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr" "github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr" - builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin" + "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin" "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure" + "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch" + "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus" + "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb" + "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark" + embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http" + "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap" "github.com/coze-dev/coze-studio/backend/infra/impl/es" "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" "github.com/coze-dev/coze-studio/backend/infra/impl/idgen" "github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex" "github.com/coze-dev/coze-studio/backend/infra/impl/mysql" "github.com/coze-dev/coze-studio/backend/infra/impl/storage" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" ) @@ -66,6 +80,7 @@ type AppDependencies struct { CodeRunner coderunner.Runner OCR ocr.OCR ParserManager parser.Manager + SearchStoreManagers []searchstore.Manager } func Init(ctx context.Context) (*AppDependencies, error) { @@ -122,15 +137,37 @@ func Init(ctx context.Context) (*AppDependencies, error) { if err != nil { return nil, err } - deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel) + + deps.ParserManager = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel) + + deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient) + if err != nil { + return nil, err + } + + deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient) + if err != nil { + return nil, err + } return deps, nil } +func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) { + // es full text search + esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es}) + + // vector search + mgr, err := getVectorStore(ctx) + if err != nil { + return nil, fmt.Errorf("init vector store failed, err=%w", err) + } + + return []searchstore.Manager{esSearchstoreManager, mgr}, nil +} + func initImageX(ctx context.Context) (imagex.ImageX, error) { - uploadComponentType := os.Getenv(consts.FileUploadComponentType) - if uploadComponentType != consts.FileUploadComponentTypeImagex { return storage.NewImagex(ctx) } @@ -230,12 +267,10 @@ func initOCR() ocr.OCR { return ocr } -func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) { +func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager { var parserManager parser.Manager parserType := os.Getenv(consts.ParserType) switch parserType { - case "builtin": - parserManager = builtinParser.NewManager(storage, ocr, imageAnnotationModel) case "paddleocr": url := os.Getenv(consts.PPStructureAPIURL) client := &http.Client{} @@ -245,8 +280,243 @@ func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationMode } parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel) default: - return nil, fmt.Errorf("unexpected document parser type, type=%s", parserType) + parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel) } - return parserManager, nil + return parserManager +} + +func getVectorStore(ctx context.Context) (searchstore.Manager, error) { + vsType := os.Getenv("VECTOR_STORE_TYPE") + + switch vsType { + case "milvus": + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + milvusAddr := os.Getenv("MILVUS_ADDR") + user := os.Getenv("MILVUS_USER") + password := os.Getenv("MILVUS_PASSWORD") + mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + Username: user, + Password: password, + }) + if err != nil { + return nil, fmt.Errorf("init milvus client failed, err=%w", err) + } + + emb, err := getEmbedding(ctx) + if err != nil { + return nil, fmt.Errorf("init milvus embedding failed, err=%w", err) + } + + mgr, err := milvus.NewManager(&milvus.ManagerConfig{ + Client: mc, + Embedding: emb, + EnableHybrid: ptr.Of(true), + }) + if err != nil { + return nil, fmt.Errorf("init milvus vector store failed, err=%w", err) + } + + return mgr, nil + case "vikingdb": + var ( + host = os.Getenv("VIKING_DB_HOST") + region = os.Getenv("VIKING_DB_REGION") + ak = os.Getenv("VIKING_DB_AK") + sk = os.Getenv("VIKING_DB_SK") + scheme = os.Getenv("VIKING_DB_SCHEME") + modelName = os.Getenv("VIKING_DB_MODEL_NAME") + ) + if ak == "" || sk == "" { + return nil, fmt.Errorf("invalid vikingdb ak / sk") + } + if host == "" { + host = "api-vikingdb.volces.com" + } + if region == "" { + region = "cn-beijing" + } + if scheme == "" { + scheme = "https" + } + + var embConfig *vikingdb.VikingEmbeddingConfig + if modelName != "" { + embName := vikingdb.VikingEmbeddingModelName(modelName) + if embName.Dimensions() == 0 { + return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName) + } + embConfig = &vikingdb.VikingEmbeddingConfig{ + UseVikingEmbedding: true, + EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse, + ModelName: embName, + ModelVersion: embName.ModelVersion(), + DenseWeight: ptr.Of(0.2), + BuiltinEmbedding: nil, + } + } else { + builtinEmbedding, err := getEmbedding(ctx) + if err != nil { + return nil, fmt.Errorf("builtint embedding init failed, err=%w", err) + } + + embConfig = &vikingdb.VikingEmbeddingConfig{ + UseVikingEmbedding: false, + EnableHybrid: false, + BuiltinEmbedding: builtinEmbedding, + } + } + + svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme) + mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{ + Service: svc, + IndexingConfig: nil, // use default config + EmbeddingConfig: embConfig, + }) + if err != nil { + return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err) + } + + return mgr, nil + + default: + return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType) + } +} + +func getEmbedding(ctx context.Context) (embedding.Embedder, error) { + var batchSize int + if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil { + logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100") + batchSize = 100 + } else { + batchSize = int(bs) + } + + var emb embedding.Embedder + + switch os.Getenv("EMBEDDING_TYPE") { + case "openai": + var ( + openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL") + openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL") + openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY") + openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE") + openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION") + openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS") + openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS") + ) + + byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure) + if err != nil { + return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err) + } + + dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err) + } + + openAICfg := &openai.EmbeddingConfig{ + APIKey: openAIEmbeddingApiKey, + ByAzure: byAzure, + BaseURL: openAIEmbeddingBaseURL, + APIVersion: openAIEmbeddingApiVersion, + Model: openAIEmbeddingModel, + // Dimensions: ptr.Of(int(dims)), + } + reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0) + if reqDims > 0 { + // some openai model not support request dims + openAICfg.Dimensions = ptr.Of(int(reqDims)) + } + + emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init openai embedding failed, err=%w", err) + } + + case "ark": + var ( + arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL") + arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL") + arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY") + // deprecated: use ARK_EMBEDDING_API_KEY instead + // ARK_EMBEDDING_AK will be removed in the future + arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK") + arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS") + arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE") + ) + + dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err) + } + + apiType := ark.APITypeText + if arkEmbeddingAPIType != "" { + if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal { + return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t) + } else { + apiType = t + } + } + + emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{ + APIKey: func() string { + if arkEmbeddingApiKey != "" { + return arkEmbeddingApiKey + } + return arkEmbeddingAK + }(), + Model: arkEmbeddingModel, + BaseURL: arkEmbeddingBaseURL, + APIType: &apiType, + }, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) + } + + case "ollama": + var ( + ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL") + ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL") + ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS") + ) + + dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err) + } + + emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{ + BaseURL: ollamaEmbeddingBaseURL, + Model: ollamaEmbeddingModel, + }, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) + } + + case "http": + var ( + httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR") + httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS") + ) + dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init http embedding dims failed, err=%w", err) + } + emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize) + if err != nil { + return nil, fmt.Errorf("init http embedding failed, err=%w", err) + } + + default: + return nil, fmt.Errorf("init knowledge embedding failed, type not configured") + } + + return emb, nil } diff --git a/backend/application/knowledge/init.go b/backend/application/knowledge/init.go index fe3f2153..8397741a 100644 --- a/backend/application/knowledge/init.go +++ b/backend/application/knowledge/init.go @@ -22,16 +22,9 @@ import ( "fmt" "os" "path/filepath" - "strconv" - "time" - "github.com/cloudwego/eino-ext/components/embedding/ark" - ollamaEmb "github.com/cloudwego/eino-ext/components/embedding/ollama" - "github.com/cloudwego/eino-ext/components/embedding/openai" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/schema" - "github.com/milvus-io/milvus/client/v2/milvusclient" - "github.com/volcengine/volc-sdk-golang/service/vikingdb" "gorm.io/gorm" "github.com/coze-dev/coze-studio/backend/application/internal" @@ -42,41 +35,29 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" "github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore" - "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" - "github.com/coze-dev/coze-studio/backend/infra/contract/es" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" - "github.com/coze-dev/coze-studio/backend/infra/contract/imagex" "github.com/coze-dev/coze-studio/backend/infra/contract/messages2query" "github.com/coze-dev/coze-studio/backend/infra/contract/rdb" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel" builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin" "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf" - sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch" - ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus" - ssvikingdb "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb" - arkemb "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark" - "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http" - "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap" "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin" - "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" ) type ServiceComponents struct { - DB *gorm.DB - IDGenSVC idgen.IDGenerator - Storage storage.Storage - RDB rdb.RDB - ImageX imagex.ImageX - ES es.Client - EventBus search.ResourceEventBus - CacheCli cache.Cmdable - OCR ocr.OCR - ParserManager parser.Manager + DB *gorm.DB + IDGenSVC idgen.IDGenerator + Storage storage.Storage + RDB rdb.RDB + EventBus search.ResourceEventBus + CacheCli cache.Cmdable + OCR ocr.OCR + ParserManager parser.Manager + SearchStoreManagers []searchstore.Manager } func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) { @@ -89,18 +70,6 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) { return nil, fmt.Errorf("init knowledge producer failed, err=%w", err) } - var sManagers []searchstore.Manager - - // es full text search - sManagers = append(sManagers, sses.NewManager(&sses.ManagerConfig{Client: c.ES})) - - // vector search - mgr, err := getVectorStore(ctx) - if err != nil { - return nil, fmt.Errorf("init vector store failed, err=%w", err) - } - sManagers = append(sManagers, mgr) - root, err := os.Getwd() if err != nil { logs.Warnf("[InitConfig] Failed to get current working directory: %v", err) @@ -142,7 +111,7 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) { IDGen: c.IDGenSVC, RDB: c.RDB, Producer: knowledgeProducer, - SearchStoreManagers: sManagers, + SearchStoreManagers: c.SearchStoreManagers, ParseManager: c.ParserManager, Storage: c.Storage, Rewriter: rewriter, @@ -163,240 +132,6 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) { return KnowledgeSVC, nil } -func getVectorStore(ctx context.Context) (searchstore.Manager, error) { - vsType := os.Getenv("VECTOR_STORE_TYPE") - - switch vsType { - case "milvus": - cctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - milvusAddr := os.Getenv("MILVUS_ADDR") - user := os.Getenv("MILVUS_USER") - password := os.Getenv("MILVUS_PASSWORD") - mc, err := milvusclient.New(cctx, &milvusclient.ClientConfig{ - Address: milvusAddr, - Username: user, - Password: password, - }) - if err != nil { - return nil, fmt.Errorf("init milvus client failed, err=%w", err) - } - - emb, err := getEmbedding(ctx) - if err != nil { - return nil, fmt.Errorf("init milvus embedding failed, err=%w", err) - } - - mgr, err := ssmilvus.NewManager(&ssmilvus.ManagerConfig{ - Client: mc, - Embedding: emb, - EnableHybrid: ptr.Of(true), - }) - if err != nil { - return nil, fmt.Errorf("init milvus vector store failed, err=%w", err) - } - - return mgr, nil - case "vikingdb": - var ( - host = os.Getenv("VIKING_DB_HOST") - region = os.Getenv("VIKING_DB_REGION") - ak = os.Getenv("VIKING_DB_AK") - sk = os.Getenv("VIKING_DB_SK") - scheme = os.Getenv("VIKING_DB_SCHEME") - modelName = os.Getenv("VIKING_DB_MODEL_NAME") - ) - if ak == "" || sk == "" { - return nil, fmt.Errorf("invalid vikingdb ak / sk") - } - if host == "" { - host = "api-vikingdb.volces.com" - } - if region == "" { - region = "cn-beijing" - } - if scheme == "" { - scheme = "https" - } - - var embConfig *ssvikingdb.VikingEmbeddingConfig - if modelName != "" { - embName := ssvikingdb.VikingEmbeddingModelName(modelName) - if embName.Dimensions() == 0 { - return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName) - } - embConfig = &ssvikingdb.VikingEmbeddingConfig{ - UseVikingEmbedding: true, - EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse, - ModelName: embName, - ModelVersion: embName.ModelVersion(), - DenseWeight: ptr.Of(0.2), - BuiltinEmbedding: nil, - } - } else { - builtinEmbedding, err := getEmbedding(ctx) - if err != nil { - return nil, fmt.Errorf("builtint embedding init failed, err=%w", err) - } - - embConfig = &ssvikingdb.VikingEmbeddingConfig{ - UseVikingEmbedding: false, - EnableHybrid: false, - BuiltinEmbedding: builtinEmbedding, - } - } - svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme) - mgr, err := ssvikingdb.NewManager(&ssvikingdb.ManagerConfig{ - Service: svc, - IndexingConfig: nil, // use default config - EmbeddingConfig: embConfig, - }) - if err != nil { - return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err) - } - - return mgr, nil - - default: - return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType) - } -} - -func getEmbedding(ctx context.Context) (embedding.Embedder, error) { - var batchSize int - if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil { - logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100") - batchSize = 100 - } else { - batchSize = int(bs) - } - - var emb embedding.Embedder - - switch os.Getenv("EMBEDDING_TYPE") { - case "openai": - var ( - openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL") - openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL") - openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY") - openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE") - openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION") - openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS") - openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS") - ) - - byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure) - if err != nil { - return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err) - } - - dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err) - } - - openAICfg := &openai.EmbeddingConfig{ - APIKey: openAIEmbeddingApiKey, - ByAzure: byAzure, - BaseURL: openAIEmbeddingBaseURL, - APIVersion: openAIEmbeddingApiVersion, - Model: openAIEmbeddingModel, - // Dimensions: ptr.Of(int(dims)), - } - reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0) - if reqDims > 0 { - // some openai model not support request dims - openAICfg.Dimensions = ptr.Of(int(reqDims)) - } - - emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init openai embedding failed, err=%w", err) - } - - case "ark": - var ( - arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL") - arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL") - arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY") - // deprecated: use ARK_EMBEDDING_API_KEY instead - // ARK_EMBEDDING_AK will be removed in the future - arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK") - arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS") - arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE") - ) - - dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err) - } - - apiType := ark.APITypeText - if arkEmbeddingAPIType != "" { - if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal { - return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t) - } else { - apiType = t - } - } - - emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{ - APIKey: func() string { - if arkEmbeddingApiKey != "" { - return arkEmbeddingApiKey - } - return arkEmbeddingAK - }(), - Model: arkEmbeddingModel, - BaseURL: arkEmbeddingBaseURL, - APIType: &apiType, - }, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) - } - - case "ollama": - var ( - ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL") - ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL") - ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS") - ) - - dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err) - } - - emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{ - BaseURL: ollamaEmbeddingBaseURL, - Model: ollamaEmbeddingModel, - }, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) - } - - case "http": - var ( - httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR") - httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS") - ) - dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64) - if err != nil { - return nil, fmt.Errorf("init http embedding dims failed, err=%w", err) - } - emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize) - if err != nil { - return nil, fmt.Errorf("init http embedding failed, err=%w", err) - } - - default: - return nil, fmt.Errorf("init knowledge embedding failed, type not configured") - } - - return emb, nil -} - func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) { b, err := os.ReadFile(jsonFilePath) if err != nil { diff --git a/backend/infra/impl/document/searchstore/vikingdb/vikingdb_manager.go b/backend/infra/impl/document/searchstore/vikingdb/vikingdb_manager.go index 1d8767a5..ae9fa159 100644 --- a/backend/infra/impl/document/searchstore/vikingdb/vikingdb_manager.go +++ b/backend/infra/impl/document/searchstore/vikingdb/vikingdb_manager.go @@ -65,6 +65,10 @@ type VikingEmbeddingConfig struct { BuiltinEmbedding embedding.Embedder } +func NewVikingDBService(host string, region string, ak string, sk string, scheme string) *vikingdb.VikingDBService { + return vikingdb.NewVikingDBService(host, region, ak, sk, scheme) +} + func NewManager(config *ManagerConfig) (searchstore.Manager, error) { if config.Service == nil { return nil, fmt.Errorf("[NewManager] vikingdb service is nil") diff --git a/backend/infra/impl/embedding/ark/ark.go b/backend/infra/impl/embedding/ark/ark.go index 02976998..0699f2ed 100644 --- a/backend/infra/impl/embedding/ark/ark.go +++ b/backend/infra/impl/embedding/ark/ark.go @@ -33,6 +33,15 @@ import ( "github.com/coze-dev/coze-studio/backend/types/errno" ) +type EmbeddingConfig = ark.EmbeddingConfig + +type APIType = ark.APIType + +const ( + APITypeText = ark.APITypeText + APITypeMultiModal APIType = ark.APITypeMultiModal +) + func NewArkEmbedder(ctx context.Context, config *ark.EmbeddingConfig, dimensions int64, batchSize int) (contract.Embedder, error) { emb, err := ark.NewEmbedder(ctx, config) if err != nil { diff --git a/docker/.env.example b/docker/.env.example index ea652d0b..3f739cf3 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -3,7 +3,6 @@ export LISTEN_ADDR=":8888" export LOG_LEVEL="debug" export MAX_REQUEST_BODY_SIZE=1073741824 export SERVER_HOST="http://localhost${LISTEN_ADDR}" -export MINIO_PROXY_ENDPOINT="" export USE_SSL="0" export SSL_CERT_FILE="" export SSL_KEY_FILE=""