-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
68 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import List, Optional, Dict, Any | ||
import os | ||
import asyncio | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from haystack import component, Document, default_to_dict | ||
from cohere import Client, AsyncClient, CohereError | ||
from cohere import AsyncClient, Client | ||
from haystack import Document, component, default_to_dict | ||
|
||
from .utils import get_async_response, get_response | ||
from cohere_haystack.embedders.utils import get_async_response, get_response | ||
|
||
API_BASE_URL = "https://api.cohere.ai/v1/embed" | ||
|
||
|
@@ -37,11 +37,18 @@ def __init__( | |
""" | ||
Create a CohereDocumentEmbedder component. | ||
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). | ||
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. | ||
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment | ||
variable COHERE_API_KEY (recommended). | ||
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are | ||
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, | ||
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. | ||
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. | ||
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned. | ||
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls. | ||
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to | ||
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both | ||
cases, input is discarded until the remaining input is exactly the maximum input token length for the model. | ||
If NONE is selected, when the input exceeds the maximum input token length an error will be returned. | ||
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use | ||
AsyncClient for applications with many concurrent calls. | ||
:param max_retries: maximal number of retries for requests, defaults to `3`. | ||
:param timeout: request timeout in seconds, defaults to `120`. | ||
:param batch_size: Number of Documents to encode at once. | ||
|
@@ -55,10 +62,11 @@ def __init__( | |
try: | ||
api_key = os.environ["COHERE_API_KEY"] | ||
except KeyError as error_msg: | ||
raise ValueError( | ||
"CohereDocumentEmbedder expects an Cohere API key. " | ||
"Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly." | ||
) from error_msg | ||
msg = ( | ||
"CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " | ||
"variable COHERE_API_KEY (recommended) or by passing it explicitly." | ||
) | ||
raise ValueError(msg) from error_msg | ||
|
||
self.api_key = api_key | ||
self.model_name = model_name | ||
|
@@ -100,7 +108,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: | |
str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None | ||
] | ||
|
||
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) | ||
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005 | ||
texts_to_embed.append(text_to_embed) | ||
return texts_to_embed | ||
|
||
|
@@ -114,10 +122,11 @@ def run(self, documents: List[Document]): | |
""" | ||
|
||
if not isinstance(documents, list) or not isinstance(documents[0], Document): | ||
raise TypeError( | ||
msg = ( | ||
"CohereDocumentEmbedder expects a list of Documents as input." | ||
"In case you want to embed a string, please use the CohereTextEmbedder." | ||
) | ||
raise TypeError(msg) | ||
|
||
texts_to_embed = self._prepare_texts_to_embed(documents) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,14 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import List, Optional, Dict, Any | ||
import asyncio | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from haystack import component, default_to_dict, default_from_dict | ||
from cohere import AsyncClient, Client | ||
from haystack import component, default_to_dict | ||
|
||
from cohere import Client, AsyncClient, CohereError | ||
|
||
from .utils import get_async_response, get_response | ||
from cohere_haystack.embedders.utils import get_async_response, get_response | ||
|
||
API_BASE_URL = "https://api.cohere.ai/v1/embed" | ||
|
||
|
@@ -33,11 +32,18 @@ def __init__( | |
""" | ||
Create a CohereTextEmbedder component. | ||
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). | ||
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. | ||
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment | ||
variable COHERE_API_KEY (recommended). | ||
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are | ||
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, | ||
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. | ||
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. | ||
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned. | ||
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls. | ||
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to | ||
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both | ||
cases, input is discarded until the remaining input is exactly the maximum input token length for the model. | ||
If NONE is selected, when the input exceeds the maximum input token length an error will be returned. | ||
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use | ||
AsyncClient for applications with many concurrent calls. | ||
:param max_retries: Maximum number of retries for requests, defaults to `3`. | ||
:param timeout: Request timeout in seconds, defaults to `120`. | ||
""" | ||
|
@@ -46,10 +52,11 @@ def __init__( | |
try: | ||
api_key = os.environ["COHERE_API_KEY"] | ||
except KeyError as error_msg: | ||
raise ValueError( | ||
"CohereTextEmbedder expects an Cohere API key. " | ||
"Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly." | ||
) from error_msg | ||
msg = ( | ||
"CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " | ||
"variable COHERE_API_KEY (recommended) or by passing it explicitly." | ||
) | ||
raise ValueError(msg) from error_msg | ||
|
||
self.api_key = api_key | ||
self.model_name = model_name | ||
|
@@ -77,10 +84,11 @@ def to_dict(self) -> Dict[str, Any]: | |
def run(self, text: str): | ||
"""Embed a string.""" | ||
if not isinstance(text, str): | ||
raise TypeError( | ||
msg = ( | ||
"CohereTextEmbedder expects a string as input." | ||
"In case you want to embed a list of Documents, please use the CohereDocumentEmbedder." | ||
) | ||
raise TypeError(msg) | ||
|
||
# Establish connection to API | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import List, Tuple, Dict, Any | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from cohere import AsyncClient, Client, CohereError | ||
from tqdm import tqdm | ||
from cohere import Client, AsyncClient, CohereError | ||
|
||
|
||
API_BASE_URL = "https://api.cohere.ai/v1/embed" | ||
|
||
|
@@ -43,7 +42,7 @@ def get_response( | |
desc="Calculating embeddings", | ||
): | ||
batch = texts[i : i + batch_size] | ||
response = cohere_client.embed(batch) | ||
response = cohere_client.embed(batch, model=model_name, truncate=truncate) | ||
for emb in response.embeddings: | ||
all_embeddings.append(emb) | ||
embeddings = [list(map(float, emb)) for emb in response.embeddings] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters