feat: intergrate gemini embedding (#783)
This commit is contained in:
@@ -25,8 +25,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/genai"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
@@ -499,7 +501,61 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
|
||||
}
|
||||
case "gemini":
|
||||
var (
|
||||
geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL")
|
||||
geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL")
|
||||
geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY")
|
||||
geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS")
|
||||
geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI
|
||||
geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT")
|
||||
geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION")
|
||||
)
|
||||
|
||||
if len(geminiEmbeddingModel) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingApiKey) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingDims) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingBackend) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required")
|
||||
}
|
||||
|
||||
dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr)
|
||||
}
|
||||
|
||||
backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr)
|
||||
}
|
||||
|
||||
geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: geminiEmbeddingApiKey,
|
||||
Backend: genai.Backend(backend),
|
||||
Project: geminiEmbeddingProject,
|
||||
Location: geminiEmbeddingLocation,
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: geminiEmbeddingBaseURL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
|
||||
Client: geminiCli,
|
||||
Model: geminiEmbeddingModel,
|
||||
OutputDimensionality: ptr.Of(int32(dims)),
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
|
||||
}
|
||||
case "http":
|
||||
var (
|
||||
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
|
||||
|
||||
Reference in New Issue
Block a user