Skip to content

Commit

Permalink
Optimize API key reading (#162)
Browse files Browse the repository at this point in the history
* optimize api key reading

* fmt and rm warning match
  • Loading branch information
anakin87 authored Jan 5, 2024
1 parent 03eca65 commit fae1e36
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 62 deletions.
11 changes: 7 additions & 4 deletions integrations/cohere/src/cohere_haystack/chat/chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ def __init__(
"""
cohere_import.check()

api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
api_key = os.environ.get("COHERE_API_KEY")
if not api_key:
error = "CohereChatGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." # noqa: E501
raise ValueError(error)
msg = (
"CohereChatGenerator expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

if not api_base_url:
api_base_url = cohere.COHERE_API_URL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,14 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""

if api_key is None:
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
msg = (
"CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg
api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereDocumentEmbedder expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,14 @@ def __init__(
:param timeout: Request timeout in seconds, defaults to `120`.
"""

if api_key is None:
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
msg = (
"CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg
api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereTextEmbedder expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
Expand Down
8 changes: 4 additions & 4 deletions integrations/cohere/src/cohere_haystack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def __init__(
- 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include
desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10.
"""
if not api_key:
api_key = os.environ.get("COHERE_API_KEY")
api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereGenerator needs an API key to run."
"Either provide it as init parameter or set the env var COHERE_API_KEY."
"CohereGenerator expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

Expand Down
4 changes: 2 additions & 2 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_init_default(self):
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"):
with pytest.raises(ValueError):
CohereChatGenerator()

@pytest.mark.unit
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"):
with pytest.raises(ValueError):
CohereChatGenerator.from_dict(data)

@pytest.mark.unit
Expand Down
21 changes: 9 additions & 12 deletions integrations/jina/src/jina_haystack/document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,19 @@ def __init__(
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
# if the user does not provide the API key, check if it is set in the module client
if api_key is None:
try:
api_key = os.environ["JINA_API_KEY"]
except KeyError as e:
msg = (
"JinaDocumentEmbedder expects a Jina API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg) from e

api_key = api_key or os.environ.get("JINA_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"JinaDocumentEmbedder expects an API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.model_name = model_name
self.prefix = prefix
self.suffix = suffix
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
Expand Down
19 changes: 9 additions & 10 deletions integrations/jina/src/jina_haystack/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,15 @@ def __init__(
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
# if the user does not provide the API key, check if it is set in the module client
if api_key is None:
try:
api_key = os.environ["JINA_API_KEY"]
except KeyError as e:
msg = (
"JinaTextEmbedder expects a Jina API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg) from e

api_key = api_key or os.environ.get("JINA_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"JinaTextEmbedder expects an API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.model_name = model_name
self.prefix = prefix
Expand Down
2 changes: 1 addition & 1 deletion integrations/jina/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_init_with_parameters(self):

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
with pytest.raises(ValueError, match="JinaDocumentEmbedder expects a Jina API key"):
with pytest.raises(ValueError):
JinaDocumentEmbedder()

def test_to_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion integrations/jina/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_init_with_parameters(self):

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
with pytest.raises(ValueError, match="JinaTextEmbedder expects a Jina API key"):
with pytest.raises(ValueError):
JinaTextEmbedder()

def test_to_dict(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
api_key = api_key or os.environ.get("PINECONE_API_KEY")
if not api_key:
msg = (
"PineconeDocumentStore expects a Pinecone API key. "
"PineconeDocumentStore expects an API key. "
"Set the PINECONE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ def __init__(
self.progress_bar = progress_bar

is_hosted_api = api_url == UNSTRUCTURED_HOSTED_API_URL
if api_key is None and is_hosted_api:
try:
api_key = os.environ["UNSTRUCTURED_API_KEY"]
except KeyError as e:
msg = (
"To use the hosted version of Unstructured, you need to set the environment variable "
"UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key."
)
raise ValueError(msg) from e

api_key = api_key or os.environ.get("UNSTRUCTURED_API_KEY")
# we check whether api_key is None or an empty string
if is_hosted_api and not api_key:
msg = (
"To use the hosted version of Unstructured, you need to set the environment variable "
"UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key."
)
raise ValueError(msg)

self.api_key = api_key

def to_dict(self) -> Dict[str, Any]:
Expand Down

0 comments on commit fae1e36

Please sign in to comment.