Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Dec 5, 2023
1 parent dfc5679 commit 38ad0cf
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 47 deletions.
13 changes: 9 additions & 4 deletions integrations/cohere/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ select = [
"E",
"EM",
"F",
"FBT",
"I",
"ICN",
"ISC",
Expand All @@ -118,8 +117,6 @@ select = [
ignore = [
# Allow non-abstract empty methods in abstract base classes
"B027",
# Allow boolean positional values in function calls, like `dict.get(... True)`
"FBT003",
# Ignore checks for possible passwords
"S105", "S106", "S107",
# Ignore complexity
Expand Down Expand Up @@ -172,4 +169,12 @@ addopts = "--strict-markers"
markers = [
"integration: integration tests",
]
log_cli = true
log_cli = true

[[tool.mypy.overrides]]
module = [
"haystack.*",
"pytest.*",
"numpy.*",
]
ignore_missing_imports = true
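39 changes: 24 additions & 15 deletions integrations/cohere/src/cohere_haystack/embedders/document_embedder.py
Original file line number Diff line number Diff line change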
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Dict, Any
import os
import asyncio
import os
from typing import Any, Dict, List, Optional

from haystack import component, Document, default_to_dict
from cohere import Client, AsyncClient, CohereError
from cohere import AsyncClient, Client
from haystack import Document, component, default_to_dict

from .utils import get_async_response, get_response
from cohere_haystack.embedders.utils import get_async_response, get_response

API_BASE_URL = "https://api.cohere.ai/v1/embed"

Expand Down Expand Up @@ -37,11 +37,18 @@ def __init__(
"""
Create a CohereDocumentEmbedder component.
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`,
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param max_retries: maximal number of retries for requests, defaults to `3`.
:param timeout: request timeout in seconds, defaults to `120`.
:param batch_size: Number of Documents to encode at once.
Expand All @@ -55,10 +62,11 @@ def __init__(
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
raise ValueError(
"CohereDocumentEmbedder expects an Cohere API key. "
"Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly."
) from error_msg
msg = (
"CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg

self.api_key = api_key
self.model_name = model_name
Expand Down Expand Up @@ -100,7 +108,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None
]

text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005
texts_to_embed.append(text_to_embed)
return texts_to_embed

Expand All @@ -114,10 +122,11 @@ def run(self, documents: List[Document]):
"""

if not isinstance(documents, list) or not isinstance(documents[0], Document):
raise TypeError(
msg = (
"CohereDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a string, please use the CohereTextEmbedder."
)
raise TypeError(msg)

texts_to_embed = self._prepare_texts_to_embed(documents)

Expand Down
36 changes: 22 additions & 14 deletions integrations/cohere/src/cohere_haystack/embedders/text_embedder.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Dict, Any
import asyncio
import os
from typing import Any, Dict, List, Optional

from haystack import component, default_to_dict, default_from_dict
from cohere import AsyncClient, Client
from haystack import component, default_to_dict

from cohere import Client, AsyncClient, CohereError

from .utils import get_async_response, get_response
from cohere_haystack.embedders.utils import get_async_response, get_response

API_BASE_URL = "https://api.cohere.ai/v1/embed"

Expand All @@ -33,11 +32,18 @@ def __init__(
"""
Create a CohereTextEmbedder component.
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`,
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param max_retries: Maximum number of retries for requests, defaults to `3`.
:param timeout: Request timeout in seconds, defaults to `120`.
"""
Expand All @@ -46,10 +52,11 @@ def __init__(
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
raise ValueError(
"CohereTextEmbedder expects an Cohere API key. "
"Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly."
) from error_msg
msg = (
"CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg

self.api_key = api_key
self.model_name = model_name
Expand Down Expand Up @@ -77,10 +84,11 @@ def to_dict(self) -> Dict[str, Any]:
def run(self, text: str):
"""Embed a string."""
if not isinstance(text, str):
raise TypeError(
msg = (
"CohereTextEmbedder expects a string as input."
"In case you want to embed a list of Documents, please use the CohereDocumentEmbedder."
)
raise TypeError(msg)

# Establish connection to API

Expand Down
7 changes: 3 additions & 4 deletions integrations/cohere/src/cohere_haystack/embedders/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple, Dict, Any
from typing import Any, Dict, List, Tuple

from cohere import AsyncClient, Client, CohereError
from tqdm import tqdm
from cohere import Client, AsyncClient, CohereError


API_BASE_URL = "https://api.cohere.ai/v1/embed"

Expand Down Expand Up @@ -43,7 +42,7 @@ def get_response(
desc="Calculating embeddings",
):
batch = texts[i : i + batch_size]
response = cohere_client.embed(batch)
response = cohere_client.embed(batch, model=model_name, truncate=truncate)
for emb in response.embeddings:
all_embeddings.append(emb)
embeddings = [list(map(float, emb)) for emb in response.embeddings]
Expand Down
14 changes: 7 additions & 7 deletions integrations/cohere/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock

import numpy as np
import pytest

from haystack import Document

from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder


Expand All @@ -18,11 +18,11 @@ def test_init_default(self):
assert embedder.model_name == "embed-english-v2.0"
assert embedder.api_base_url == "https://api.cohere.ai/v1/embed"
assert embedder.truncate == "END"
assert embedder.use_async_client == False
assert embedder.use_async_client is False
assert embedder.max_retries == 3
assert embedder.timeout == 120
assert embedder.batch_size == 32
assert embedder.progress_bar == True
assert embedder.progress_bar is True
assert embedder.metadata_fields_to_embed == []
assert embedder.embedding_separator == "\n"

Expand All @@ -44,11 +44,11 @@ def test_init_with_parameters(self):
assert embedder.model_name == "embed-multilingual-v2.0"
assert embedder.api_base_url == "https://custom-api-base-url.com"
assert embedder.truncate == "START"
assert embedder.use_async_client == True
assert embedder.use_async_client is True
assert embedder.max_retries == 5
assert embedder.timeout == 60
assert embedder.batch_size == 64
assert embedder.progress_bar == False
assert embedder.progress_bar is False
assert embedder.metadata_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == "-"

Expand Down Expand Up @@ -109,7 +109,7 @@ def test_to_dict_with_custom_init_parameters(self):
@pytest.mark.integration
def test_run(self):
embedder = MagicMock()
embedder.run = lambda x, **kwargs: np.random.rand(len(x), 2).tolist()
embedder.run = lambda x, **_: np.random.rand(len(x), 2).tolist()

docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Expand Down
6 changes: 3 additions & 3 deletions integrations/cohere/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock

import pytest

Expand All @@ -20,7 +20,7 @@ def test_init_default(self):
assert embedder.model_name == "embed-english-v2.0"
assert embedder.api_base_url == "https://api.cohere.ai/v1/embed"
assert embedder.truncate == "END"
assert embedder.use_async_client == False
assert embedder.use_async_client is False
assert embedder.max_retries == 3
assert embedder.timeout == 120

Expand All @@ -41,7 +41,7 @@ def test_init_with_parameters(self):
assert embedder.model_name == "embed-multilingual-v2.0"
assert embedder.api_base_url == "https://custom-api-base-url.com"
assert embedder.truncate == "START"
assert embedder.use_async_client == True
assert embedder.use_async_client is True
assert embedder.max_retries == 5
assert embedder.timeout == 60

Expand Down

0 comments on commit 38ad0cf

Please sign in to comment.