From 38ad0cfea1837df06e3377339de223c8f1ea546f Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 5 Dec 2023 19:38:50 +0100 Subject: [PATCH] linting --- integrations/cohere/pyproject.toml | 13 +++++-- .../embedders/document_embedder.py | 39 ++++++++++++------- .../embedders/text_embedder.py | 36 ++++++++++------- .../src/cohere_haystack/embedders/utils.py | 7 ++-- .../cohere/tests/test_document_embedder.py | 14 +++---- .../cohere/tests/test_text_embedder.py | 6 +-- 6 files changed, 68 insertions(+), 47 deletions(-) diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index a12b85992..fd7663eb6 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -97,7 +97,6 @@ select = [ "E", "EM", "F", - "FBT", "I", "ICN", "ISC", @@ -118,8 +117,6 @@ select = [ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", - # Allow boolean positional values in function calls, like `dict.get(... True)` - "FBT003", # Ignore checks for possible passwords "S105", "S106", "S107", # Ignore complexity @@ -172,4 +169,12 @@ addopts = "--strict-markers" markers = [ "integration: integration tests", ] -log_cli = true \ No newline at end of file +log_cli = true + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "pytest.*", + "numpy.*", +] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index e98c5358d..4f13b59a8 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Dict, Any -import os import asyncio +import os +from typing import Any, Dict, List, Optional -from haystack import component, Document, default_to_dict -from cohere import Client, AsyncClient, CohereError +from cohere import AsyncClient, Client +from haystack import Document, component, default_to_dict -from .utils import get_async_response, get_response +from cohere_haystack.embedders.utils import get_async_response, get_response API_BASE_URL = "https://api.cohere.ai/v1/embed" @@ -37,11 +37,18 @@ def __init__( """ Create a CohereDocumentEmbedder component. - :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment + variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are + `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, + `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. - :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned. - :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + cases, input is discarded until the remaining input is exactly the maximum input token length for the model. + If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + AsyncClient for applications with many concurrent calls. :param max_retries: maximal number of retries for requests, defaults to `3`. :param timeout: request timeout in seconds, defaults to `120`. :param batch_size: Number of Documents to encode at once. @@ -55,10 +62,11 @@ def __init__( try: api_key = os.environ["COHERE_API_KEY"] except KeyError as error_msg: - raise ValueError( - "CohereDocumentEmbedder expects an Cohere API key. " - "Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly." - ) from error_msg + msg = ( + "CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " + "variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) + raise ValueError(msg) from error_msg self.api_key = api_key self.model_name = model_name @@ -100,7 +108,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None ] - text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005 texts_to_embed.append(text_to_embed) return texts_to_embed @@ -114,10 +122,11 @@ def run(self, documents: List[Document]): """ if not isinstance(documents, list) or not isinstance(documents[0], Document): - raise TypeError( + msg = ( "CohereDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the CohereTextEmbedder." ) + raise TypeError(msg) texts_to_embed = self._prepare_texts_to_embed(documents) diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 0727a0823..7deb970df 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -1,15 +1,14 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Dict, Any import asyncio import os +from typing import Any, Dict, List, Optional -from haystack import component, default_to_dict, default_from_dict +from cohere import AsyncClient, Client +from haystack import component, default_to_dict -from cohere import Client, AsyncClient, CohereError - -from .utils import get_async_response, get_response +from cohere_haystack.embedders.utils import get_async_response, get_response API_BASE_URL = "https://api.cohere.ai/v1/embed" @@ -33,11 +32,18 @@ def __init__( """ Create a CohereTextEmbedder component. - :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment + variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are + `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, + `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. - :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned. - :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + cases, input is discarded until the remaining input is exactly the maximum input token length for the model. + If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + AsyncClient for applications with many concurrent calls. :param max_retries: Maximum number of retries for requests, defaults to `3`. :param timeout: Request timeout in seconds, defaults to `120`. """ @@ -46,10 +52,11 @@ def __init__( try: api_key = os.environ["COHERE_API_KEY"] except KeyError as error_msg: - raise ValueError( - "CohereTextEmbedder expects an Cohere API key. " - "Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly." - ) from error_msg + msg = ( + "CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " + "variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) + raise ValueError(msg) from error_msg self.api_key = api_key self.model_name = model_name @@ -77,10 +84,11 @@ def to_dict(self) -> Dict[str, Any]: def run(self, text: str): """Embed a string.""" if not isinstance(text, str): - raise TypeError( + msg = ( "CohereTextEmbedder expects a string as input." "In case you want to embed a list of Documents, please use the CohereDocumentEmbedder." ) + raise TypeError(msg) # Establish connection to API diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/cohere_haystack/embedders/utils.py index a3dc10af1..a3511008b 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/utils.py +++ b/integrations/cohere/src/cohere_haystack/embedders/utils.py @@ -1,11 +1,10 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Tuple, Dict, Any +from typing import Any, Dict, List, Tuple +from cohere import AsyncClient, Client, CohereError from tqdm import tqdm -from cohere import Client, AsyncClient, CohereError - API_BASE_URL = "https://api.cohere.ai/v1/embed" @@ -43,7 +42,7 @@ def get_response( desc="Calculating embeddings", ): batch = texts[i : i + batch_size] - response = cohere_client.embed(batch) + response = cohere_client.embed(batch, model=model_name, truncate=truncate) for emb in response.embeddings: all_embeddings.append(emb) embeddings = [list(map(float, emb)) for emb in response.embeddings] diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index bf68786cd..ae3b877da 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -2,12 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock import numpy as np import pytest - from haystack import Document + from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder @@ -18,11 +18,11 @@ def test_init_default(self): assert embedder.model_name == "embed-english-v2.0" assert embedder.api_base_url == "https://api.cohere.ai/v1/embed" assert embedder.truncate == "END" - assert embedder.use_async_client == False + assert embedder.use_async_client is False assert embedder.max_retries == 3 assert embedder.timeout == 120 assert embedder.batch_size == 32 - assert embedder.progress_bar == True + assert embedder.progress_bar is True assert embedder.metadata_fields_to_embed == [] assert embedder.embedding_separator == "\n" @@ -44,11 +44,11 @@ def test_init_with_parameters(self): assert embedder.model_name == "embed-multilingual-v2.0" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" - assert embedder.use_async_client == True + assert embedder.use_async_client is True assert embedder.max_retries == 5 assert embedder.timeout == 60 assert embedder.batch_size == 64 - assert embedder.progress_bar == False + assert embedder.progress_bar is False assert embedder.metadata_fields_to_embed == ["test_field"] assert embedder.embedding_separator == "-" @@ -109,7 +109,7 @@ def test_to_dict_with_custom_init_parameters(self): @pytest.mark.integration def test_run(self): embedder = MagicMock() - embedder.run = lambda x, **kwargs: np.random.rand(len(x), 2).tolist() + embedder.run = lambda x, **_: np.random.rand(len(x), 2).tolist() docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index a4c4804be..05ee0c343 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock import pytest @@ -20,7 +20,7 @@ def test_init_default(self): assert embedder.model_name == "embed-english-v2.0" assert embedder.api_base_url == "https://api.cohere.ai/v1/embed" assert embedder.truncate == "END" - assert embedder.use_async_client == False + assert embedder.use_async_client is False assert embedder.max_retries == 3 assert embedder.timeout == 120 @@ -41,7 +41,7 @@ def test_init_with_parameters(self): assert embedder.model_name == "embed-multilingual-v2.0" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" - assert embedder.use_async_client == True + assert embedder.use_async_client is True assert embedder.max_retries == 5 assert embedder.timeout == 60