Skip to content

Commit 4de5b2f

Browse files
authored
feat: add Cohere embedding integration (#1305)
1 parent 6cf9f82 commit 4de5b2f

File tree

7 files changed

+245
-11
lines changed

7 files changed

+245
-11
lines changed

docs/user-guides/configuration-guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ The following tables lists the supported embedding providers:
539539
| SentenceTransformers | `SentenceTransformers` | `all-MiniLM-L6-v2`, etc. |
540540
| NVIDIA AI Endpoints | `nvidia_ai_endpoints` | `nv-embed-v1`, etc. |
541541
| AzureOpenAI | `AzureOpenAI` | `text-embedding-ada-002`, etc. |
542+
| Cohere | `cohere` | `embed-multilingual-v3.0`, etc. |
542543

543544
```{note}
544545
You can use any of the supported models for any of the supported embedding providers.

nemoguardrails/embeddings/providers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing import Optional, Type
2020

21-
from . import azureopenai, fastembed, nim, openai, sentence_transformers
21+
from . import azureopenai, cohere, fastembed, nim, openai, sentence_transformers
2222
from .base import EmbeddingModel
2323
from .registry import EmbeddingProviderRegistry
2424

@@ -69,6 +69,7 @@ def register_embedding_provider(
6969
register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel)
7070
register_embedding_provider(nim.NIMEmbeddingModel)
7171
register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel)
72+
register_embedding_provider(cohere.CohereEmbeddingModel)
7273

7374

7475
def init_embedding_model(
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import asyncio
16+
from contextvars import ContextVar
17+
from typing import List
18+
19+
from .base import EmbeddingModel
20+
21+
# We set the Cohere async client in an asyncio context variable because we need it
22+
# to be scoped at the asyncio loop level. The client caches it somewhere, and if the loop
23+
# is changed, it will fail.
24+
async_client_var: ContextVar = ContextVar("async_client", default=None)
25+
26+
27+
class CohereEmbeddingModel(EmbeddingModel):
    """Embedding model using the Cohere API.

    To use, you must have either:
        1. The ``COHERE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the ``api_key`` kwarg, which is forwarded to
           the Cohere client constructor.

    Args:
        embedding_model (str): The name of the embedding model.
        input_type (str): The type of input for the embedding model. One of
            "search_document", "search_query", "classification", "clustering",
            "image". Defaults to "search_document".
        **kwargs: Forwarded unchanged to ``cohere.Client``.

    Attributes:
        model (str): The name of the embedding model.
        input_type (str): The input type sent with every embedding request.
        client: The synchronous Cohere client used for all requests.
        embedding_size_dict (dict): Known embedding sizes, keyed by model name.
        embedding_size (int): The size of the embeddings.

    Methods:
        encode: Encode a list of documents into embeddings.
        encode_async: Awaitable variant of ``encode``.
    """

    engine_name = "cohere"

    def __init__(
        self,
        embedding_model: str,
        input_type: str = "search_document",
        **kwargs,
    ):
        # Imported lazily so the provider module can be loaded (and registered)
        # even when the optional `cohere` dependency is not installed.
        try:
            import cohere
        except ImportError as err:
            raise ImportError(
                "Could not import cohere, please install it with "
                "`pip install cohere`."
            ) from err

        self.model = embedding_model
        self.input_type = input_type
        self.client = cohere.Client(**kwargs)

        self.embedding_size_dict = {
            "embed-v4.0": 1536,
            "embed-english-v3.0": 1024,
            "embed-english-light-v3.0": 384,
            "embed-multilingual-v3.0": 1024,
            "embed-multilingual-light-v3.0": 384,
        }

        if self.model in self.embedding_size_dict:
            self.embedding_size = self.embedding_size_dict[self.model]
        else:
            # Unknown model: perform a first encoding to discover the embedding
            # size. This issues a real request through the client.
            self.embedding_size = len(self.encode(["test"])[0])

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.

        """
        # NOTE: A native implementation based on Cohere's async client has some
        # edge cases because of httpx and async and returns
        # "Event loop is closed." errors. Falling back to a thread-based
        # implementation for now: the blocking `encode` call is run in the
        # loop's default executor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.

        """
        # Make embedding request to Cohere API
        response = self.client.embed(
            texts=documents, model=self.model, input_type=self.input_type
        )
        return response.embeddings
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
define user ask capabilities
2+
"What can you do?"
3+
"What can you help me with?"
4+
"tell me what you can do"
5+
"tell me about you"
6+
7+
define bot inform capabilities
8+
"I am an AI assistant that helps answer questions."
9+
10+
define flow
11+
user ask capabilities
12+
bot inform capabilities
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: gpt-3.5-turbo-instruct
5+
6+
- type: embeddings
7+
engine: cohere
8+
model: embed-multilingual-v3.0

tests/test_embeddings_cohere.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
18+
import pytest
19+
20+
from nemoguardrails import LLMRails, RailsConfig
21+
22+
try:
23+
from nemoguardrails.embeddings.providers.cohere import CohereEmbeddingModel
24+
except ImportError:
25+
# Ignore this if running in test environment when cohere not installed.
26+
CohereEmbeddingModel = None
27+
28+
CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs")
29+
30+
LIVE_TEST_MODE = os.environ.get("LIVE_TEST")
31+
32+
33+
@pytest.fixture
def app():
    """Build an ``LLMRails`` app from the ``with_cohere_embeddings`` test config."""
    config_dir = os.path.join(CONFIGS_FOLDER, "with_cohere_embeddings")
    return LLMRails(RailsConfig.from_path(config_dir))
41+
42+
43+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_custom_llm_registration(app):
    """The app's flows index must be backed by the Cohere embedding model."""
    index_model = app.llm_generation_actions.flows_index._model
    assert isinstance(index_model, CohereEmbeddingModel)
48+
49+
50+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_live_query():
    """Exercise the async generation path with the Cohere embeddings config."""
    config_dir = os.path.join(CONFIGS_FOLDER, "with_cohere_embeddings")
    rails = LLMRails(RailsConfig.from_path(config_dir))

    user_turn = {"role": "user", "content": "tell me what you can do"}
    expected_reply = {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }

    reply = await rails.generate_async(messages=[user_turn])

    assert reply == expected_reply
66+
67+
68+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_live_query_sync(app):
    """Exercise the synchronous generation path with Cohere embeddings.

    This test has its own name on purpose: defining a second module-level
    ``test_live_query`` would rebind the async test of that name, so pytest
    would collect only one of the two. It is also a plain function, so it
    carries no ``asyncio`` marker.
    """
    result = app.generate(
        messages=[{"role": "user", "content": "tell me what you can do"}]
    )

    assert result == {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }
79+
80+
81+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_sync_embeddings():
    """A synchronous encode of one document yields a 1024-dimensional vector."""
    embedder = CohereEmbeddingModel("embed-multilingual-v3.0")
    vectors = embedder.encode(["test"])
    first_vector = vectors[0]

    assert len(first_vector) == 1024
88+
89+
90+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_async_embeddings():
    """An awaited encode of one document yields a 1024-dimensional vector."""
    embedder = CohereEmbeddingModel("embed-multilingual-v3.0")
    vectors = await embedder.encode_async(["test"])
    first_vector = vectors[0]

    assert len(first_vector) == 1024

tests/test_embeddings_providers_mock.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,7 @@
1818

1919
import pytest
2020

21-
try:
22-
import nemoguardrails.embeddings.providers.cohere
2321

24-
COHERE_AVAILABLE = True
25-
except (ImportError, ModuleNotFoundError):
26-
COHERE_AVAILABLE = False
27-
28-
29-
@pytest.mark.skipif(
30-
not COHERE_AVAILABLE, reason="Cohere provider not available in this branch"
31-
)
3222
class TestCohereEmbeddingModelMocked:
3323
def test_init_with_known_model(self):
3424
mock_cohere = MagicMock()

0 commit comments

Comments
 (0)