From ec89260328b9d5e12160249eb826ddf61671e8c7 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 2 Jan 2024 12:34:32 +0100 Subject: [PATCH 1/2] optimize api key reading --- .../cohere_haystack/chat/chat_generator.py | 11 ++++++---- .../embedders/document_embedder.py | 17 +++++++-------- .../embedders/text_embedder.py | 17 +++++++-------- .../cohere/src/cohere_haystack/generator.py | 8 +++---- .../src/jina_haystack/document_embedder.py | 21 ++++++++----------- .../jina/src/jina_haystack/text_embedder.py | 19 ++++++++--------- .../src/pinecone_haystack/document_store.py | 2 +- .../fileconverter.py | 19 +++++++++-------- 8 files changed, 56 insertions(+), 58 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py index f3178d567..be236f6ca 100644 --- a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py +++ b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py @@ -68,11 +68,14 @@ def __init__( """ cohere_import.check() + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string if not api_key: - api_key = os.environ.get("COHERE_API_KEY") - if not api_key: - error = "CohereChatGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." # noqa: E501 - raise ValueError(error) + msg = ( + "CohereChatGenerator expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) if not api_base_url: api_base_url = cohere.COHERE_API_URL diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index 6c87f3537..bc0b9381d 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -78,15 +78,14 @@ def __init__( :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - if api_key is None: - try: - api_key = os.environ["COHERE_API_KEY"] - except KeyError as error_msg: - msg = ( - "CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " - "variable COHERE_API_KEY (recommended) or by passing it explicitly." - ) - raise ValueError(msg) from error_msg + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "CohereDocumentEmbedder expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.api_key = api_key self.model_name = model_name diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 25822223e..4ba8acd47 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -67,15 +67,14 @@ def __init__( :param timeout: Request timeout in seconds, defaults to `120`. """ - if api_key is None: - try: - api_key = os.environ["COHERE_API_KEY"] - except KeyError as error_msg: - msg = ( - "CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " - "variable COHERE_API_KEY (recommended) or by passing it explicitly." - ) - raise ValueError(msg) from error_msg + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "CohereTextEmbedder expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.api_key = api_key self.model_name = model_name diff --git a/integrations/cohere/src/cohere_haystack/generator.py b/integrations/cohere/src/cohere_haystack/generator.py index 571464c0c..66c80afa4 100644 --- a/integrations/cohere/src/cohere_haystack/generator.py +++ b/integrations/cohere/src/cohere_haystack/generator.py @@ -73,12 +73,12 @@ def __init__( - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. """ - if not api_key: - api_key = os.environ.get("COHERE_API_KEY") + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string if not api_key: msg = ( - "CohereGenerator needs an API key to run." - "Either provide it as init parameter or set the env var COHERE_API_KEY." + "CohereGenerator expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) diff --git a/integrations/jina/src/jina_haystack/document_embedder.py b/integrations/jina/src/jina_haystack/document_embedder.py index 03f64462f..9f51a9e26 100644 --- a/integrations/jina/src/jina_haystack/document_embedder.py +++ b/integrations/jina/src/jina_haystack/document_embedder.py @@ -57,22 +57,19 @@ def __init__( :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - # if the user does not provide the API key, check if it is set in the module client - if api_key is None: - try: - api_key = os.environ["JINA_API_KEY"] - except KeyError as e: - msg = ( - "JinaDocumentEmbedder expects a Jina API key. " - "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." - ) - raise ValueError(msg) from e + + api_key = api_key or os.environ.get("JINA_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "JinaDocumentEmbedder expects an API key. " + "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.model_name = model_name self.prefix = prefix self.suffix = suffix - self.prefix = prefix - self.suffix = suffix self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] diff --git a/integrations/jina/src/jina_haystack/text_embedder.py b/integrations/jina/src/jina_haystack/text_embedder.py index 5b29bef6d..7a21e61a6 100644 --- a/integrations/jina/src/jina_haystack/text_embedder.py +++ b/integrations/jina/src/jina_haystack/text_embedder.py @@ -47,16 +47,15 @@ def __init__( :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. """ - # if the user does not provide the API key, check if it is set in the module client - if api_key is None: - try: - api_key = os.environ["JINA_API_KEY"] - except KeyError as e: - msg = ( - "JinaTextEmbedder expects a Jina API key. " - "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." - ) - raise ValueError(msg) from e + + api_key = api_key or os.environ.get("JINA_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "JinaTextEmbedder expects an API key. " + "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.model_name = model_name self.prefix = prefix diff --git a/integrations/pinecone/src/pinecone_haystack/document_store.py b/integrations/pinecone/src/pinecone_haystack/document_store.py index d6296e030..252437a9a 100644 --- a/integrations/pinecone/src/pinecone_haystack/document_store.py +++ b/integrations/pinecone/src/pinecone_haystack/document_store.py @@ -60,7 +60,7 @@ def __init__( api_key = api_key or os.environ.get("PINECONE_API_KEY") if not api_key: msg = ( - "PineconeDocumentStore expects a Pinecone API key. " + "PineconeDocumentStore expects an API key. " "Set the PINECONE_API_KEY environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) diff --git a/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py b/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py index 5a565d00b..d94cb49c4 100644 --- a/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py +++ b/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py @@ -60,15 +60,16 @@ def __init__( self.progress_bar = progress_bar is_hosted_api = api_url == UNSTRUCTURED_HOSTED_API_URL - if api_key is None and is_hosted_api: - try: - api_key = os.environ["UNSTRUCTURED_API_KEY"] - except KeyError as e: - msg = ( - "To use the hosted version of Unstructured, you need to set the environment variable " - "UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key." - ) - raise ValueError(msg) from e + + api_key = api_key or os.environ.get("UNSTRUCTURED_API_KEY") + # we check whether api_key is None or an empty string + if is_hosted_api and not api_key: + msg = ( + "To use the hosted version of Unstructured, you need to set the environment variable " + "UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key." + ) + raise ValueError(msg) + self.api_key = api_key def to_dict(self) -> Dict[str, Any]: From 5dbacb6bbf47e4e34ecd6d0d7f6070a0d9c6bbd8 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 2 Jan 2024 15:21:45 +0100 Subject: [PATCH 2/2] fmt and rm warning match --- integrations/cohere/tests/test_cohere_chat_generator.py | 4 ++-- integrations/jina/src/jina_haystack/text_embedder.py | 2 +- integrations/jina/tests/test_document_embedder.py | 2 +- integrations/jina/tests/test_text_embedder.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 92954df8b..f9ac7b2c6 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -65,7 +65,7 @@ def test_init_default(self): @pytest.mark.unit def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) - with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"): + with pytest.raises(ValueError): CohereChatGenerator() @pytest.mark.unit @@ -167,7 +167,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } - with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"): + with pytest.raises(ValueError): CohereChatGenerator.from_dict(data) @pytest.mark.unit diff --git a/integrations/jina/src/jina_haystack/text_embedder.py b/integrations/jina/src/jina_haystack/text_embedder.py index 7a21e61a6..f717f4748 100644 --- a/integrations/jina/src/jina_haystack/text_embedder.py +++ b/integrations/jina/src/jina_haystack/text_embedder.py @@ -47,7 +47,7 @@ def __init__( :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. """ - + api_key = api_key or os.environ.get("JINA_API_KEY") # we check whether api_key is None or an empty string if not api_key: diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index ac8bb6975..2ebc5d358 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -58,7 +58,7 @@ def test_init_with_parameters(self): def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("JINA_API_KEY", raising=False) - with pytest.raises(ValueError, match="JinaDocumentEmbedder expects a Jina API key"): + with pytest.raises(ValueError): JinaDocumentEmbedder() def test_to_dict(self): diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py index c8a730c2f..7dfd64a05 100644 --- a/integrations/jina/tests/test_text_embedder.py +++ b/integrations/jina/tests/test_text_embedder.py @@ -32,7 +32,7 @@ def test_init_with_parameters(self): def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("JINA_API_KEY", raising=False) - with pytest.raises(ValueError, match="JinaTextEmbedder expects a Jina API key"): + with pytest.raises(ValueError): JinaTextEmbedder() def test_to_dict(self):