feat: support EMBEDDING_MAX_BATCH_SIZE (#311)

This commit is contained in:
N3ko
2025-07-30 15:31:54 +08:00
committed by GitHub
parent f93f26fc48
commit bb74272385
9 changed files with 85 additions and 71 deletions

View File

@@ -280,6 +280,14 @@ func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
}
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
var batchSize int
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
batchSize = 100
} else {
batchSize = int(bs)
}
var emb embedding.Embedder
switch os.Getenv("EMBEDDING_TYPE") {
@@ -318,7 +326,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
openAICfg.Dimensions = ptr.Of(int(reqDims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims)
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
@@ -340,7 +348,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
APIKey: arkEmbeddingAK,
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
}, dims)
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
@@ -360,7 +368,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{
BaseURL: ollamaEmbeddingBaseURL,
Model: ollamaEmbeddingModel,
}, dims)
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
@@ -374,7 +382,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
if err != nil {
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
}
emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims)
emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
}