feat: support ark embedding api_type (#519)

This commit is contained in:
N3ko
2025-08-04 18:06:38 +08:00
committed by GitHub
parent adc5986d13
commit 36923bd0a4
5 changed files with 61 additions and 46 deletions

View File

@@ -338,8 +338,9 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
// deprecated: use ARK_EMBEDDING_API_KEY instead
// ARK_EMBEDDING_AK will be removed in the future
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
)
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
@@ -347,6 +348,15 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
}
apiType := ark.APITypeText
if arkEmbeddingAPIType != "" {
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
} else {
apiType = t
}
}
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: func() string {
if arkEmbeddingApiKey != "" {
@@ -356,6 +366,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
}(),
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
APIType: &apiType,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)