Skip to content

Commit c7cceee

Browse files
bwook00 and Pouyanpi authored
feat: add Google embedding integration (#1304)
Co-authored-by: Pouyanpi <[email protected]>
1 parent 4de5b2f commit c7cceee

File tree

7 files changed

+558
-1
lines changed

7 files changed

+558
-1
lines changed

docs/user-guides/configuration-guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ The following table lists the supported embedding providers:
540540
| NVIDIA AI Endpoints | `nvidia_ai_endpoints` | `nv-embed-v1`, etc. |
541541
| AzureOpenAI | `AzureOpenAI` | `text-embedding-ada-002`, etc. |
542542
| Cohere | `cohere` | `embed-multilingual-v3.0`, etc. |
543+
| Google Gemini | `google` | `gemini-embedding-001`, etc. |
543544

544545
```{note}
545546
You can use any of the supported models for any of the supported embedding providers.

nemoguardrails/embeddings/providers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing import Optional, Type
2020

21-
from . import azureopenai, cohere, fastembed, nim, openai, sentence_transformers
21+
from . import azureopenai, cohere, fastembed, google, nim, openai, sentence_transformers
2222
from .base import EmbeddingModel
2323
from .registry import EmbeddingProviderRegistry
2424

@@ -69,6 +69,7 @@ def register_embedding_provider(
6969
register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel)
7070
register_embedding_provider(nim.NIMEmbeddingModel)
7171
register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel)
72+
register_embedding_provider(google.GoogleEmbeddingModel)
7273
register_embedding_provider(cohere.CohereEmbeddingModel)
7374

7475

nemoguardrails/embeddings/providers/google.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
from typing import List, Optional
18+
19+
from .base import EmbeddingModel
20+
21+
22+
class GoogleEmbeddingModel(EmbeddingModel):
    """Embedding model using Gemini API.

    This class is a wrapper for using embedding models powered by Gemini API.

    To use, you must have either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the ``api_key`` kwarg, which is forwarded to
           ``genai.Client()``.

    Args:
        embedding_model (str): The name of the embedding model to be used.
        **kwargs: Additional keyword arguments. Supports:
            - output_dimensionality (int, optional): Desired output dimensions
              (128-3072 for gemini-embedding-001). Recommended values: 768, 1536,
              or 3072. If not specified, API defaults to 3072. It is removed from
              ``kwargs`` and sent with every embedding request instead of being
              passed to the client constructor.
            - api_key (str, optional): API key for authentication (or use
              GOOGLE_API_KEY env var).
            - Other arguments passed to genai.Client() constructor.

    Attributes:
        model (str): The name of the embedding model.
        output_dimensionality (Optional[int]): Requested embedding size, or
            ``None`` to use the API default.
        embedding_size (int): The size of the embeddings.
    """

    engine_name = "google"

    def __init__(self, embedding_model: str, **kwargs):
        # Imported lazily so the package can be imported without google-genai
        # installed; the dependency is only needed when this provider is used.
        try:
            from google import genai
        except ImportError as e:
            raise ImportError(
                "Could not import google-genai, please install it with "
                "`pip install google-genai`."
            ) from e

        self.model = embedding_model
        # Must be popped before building the client: it is a per-request
        # option, not a genai.Client() constructor argument.
        self.output_dimensionality: Optional[int] = kwargs.pop(
            "output_dimensionality", None
        )

        self.client = genai.Client(**kwargs)

        # Default output sizes of the models known to this wrapper.
        embedding_size_dict = {
            "gemini-embedding-001": 3072,
        }

        if self.model in embedding_size_dict:
            self._embedding_size = (
                self.output_dimensionality
                if self.output_dimensionality is not None
                else embedding_size_dict[self.model]
            )
        else:
            # Unknown model: the size is discovered on first access of
            # ``embedding_size``.
            self._embedding_size = None

    @property
    def embedding_size(self) -> int:
        """Return the embedding size, probing the API once if it is not known."""
        if self._embedding_size is None:
            # Embed a throwaway document and cache the resulting vector length.
            self._embedding_size = len(self.encode(["test"])[0])
        return self._embedding_size

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into their corresponding sentence embeddings.

        The synchronous ``encode`` is run in the running event loop's default
        executor and its result is awaited.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The list of sentence embeddings, where each embedding is a list of floats.

        Raises:
            RuntimeError: If the embedding request fails.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into their corresponding sentence embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The list of sentence embeddings, where each embedding is a list of floats.
            An empty input yields an empty list without issuing a request.

        Raises:
            RuntimeError: If the embedding request fails.
        """
        if not documents:
            # Nothing to embed; skip the API round-trip.
            return []

        try:
            embed_kwargs = {"model": self.model, "contents": documents}
            if self.output_dimensionality is not None:
                # ``embed_content`` only accepts ``model``, ``contents`` and
                # ``config``; request options such as the output size must be
                # carried inside ``config`` (a plain dict is accepted).
                embed_kwargs["config"] = {
                    "output_dimensionality": self.output_dimensionality
                }

            results = self.client.models.embed_content(**embed_kwargs)
            return [emb.values for emb in results.embeddings]
        except Exception as e:
            raise RuntimeError(f"Failed to retrieve embeddings: {e}") from e
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
define user ask capabilities
2+
"What can you do?"
3+
"What can you help me with?"
4+
"tell me what you can do"
5+
"tell me about you"
6+
7+
define bot inform capabilities
8+
"I am an AI assistant that helps answer questions."
9+
10+
define flow
11+
user ask capabilities
12+
bot inform capabilities
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: gpt-3.5-turbo-instruct
5+
6+
- type: embeddings
7+
engine: google
8+
model: gemini-embedding-001

tests/test_embeddings_google.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
18+
import pytest
19+
20+
from nemoguardrails import LLMRails, RailsConfig
21+
22+
try:
23+
from nemoguardrails.embeddings.providers.google import GoogleEmbeddingModel
24+
except ImportError:
25+
GoogleEmbeddingModel = None
26+
27+
# Folder holding the test configurations, resolved relative to this module.
CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs")

# Value of the LIVE_TEST environment variable (None when it is not set).
LIVE_TEST_MODE = os.getenv("LIVE_TEST")
30+
31+
32+
@pytest.fixture
def app():
    """Return an ``LLMRails`` app built from the ``with_google_embeddings`` config."""
    config_dir = os.path.join(CONFIGS_FOLDER, "with_google_embeddings")
    rails_config = RailsConfig.from_path(config_dir)
    return LLMRails(rails_config)
40+
41+
42+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_custom_llm_registration(app):
    """Check that the app's flows index is backed by ``GoogleEmbeddingModel``."""
    flows_index_model = app.llm_generation_actions.flows_index._model
    assert isinstance(flows_index_model, GoogleEmbeddingModel)
47+
48+
49+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_live_query():
    """Check the async generation path returns the predefined capabilities reply."""
    config_dir = os.path.join(CONFIGS_FOLDER, "with_google_embeddings")
    rails = LLMRails(RailsConfig.from_path(config_dir))

    user_turn = {"role": "user", "content": "tell me what you can do"}
    expected_reply = {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }

    reply = await rails.generate_async(messages=[user_turn])

    assert reply == expected_reply
65+
66+
67+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_live_query_sync(app):
    """Check the sync generation path returns the predefined capabilities reply."""
    user_turn = {"role": "user", "content": "tell me what you can do"}
    expected_reply = {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }

    reply = app.generate(messages=[user_turn])

    assert reply == expected_reply
77+
78+
79+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_sync_embeddings():
    """Check ``encode`` returns a 3072-dimensional vector for gemini-embedding-001."""
    embedder = GoogleEmbeddingModel("gemini-embedding-001")

    vectors = embedder.encode(["test"])
    first_vector = vectors[0]

    assert len(first_vector) == 3072
86+
87+
88+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_async_embeddings():
    """Check ``encode_async`` returns a 3072-dimensional vector for gemini-embedding-001."""
    embedder = GoogleEmbeddingModel("gemini-embedding-001")

    vectors = await embedder.encode_async(["test"])
    first_vector = vectors[0]

    assert len(first_vector) == 3072

0 commit comments

Comments
 (0)