From f7599a07343bbf8254cb92487180fb3a6b36ab0e Mon Sep 17 00:00:00 2001 From: shadeMe Date: Wed, 6 Mar 2024 11:20:04 +0100 Subject: [PATCH] fix: `nvidia-haystack`- Handle non-strict env var secrets correctly --- .../components/embedders/nvidia/document_embedder.py | 7 ++----- .../components/embedders/nvidia/text_embedder.py | 7 ++----- .../src/haystack_integrations/utils/nvidia/client.py | 11 ++++++++--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 139a184b7..bbc68b492 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -65,9 +65,6 @@ def __init__( if isinstance(model, str): model = NvidiaEmbeddingModel.from_str(model) - resolved_api_key = api_key.resolve_value() - assert resolved_api_key is not None - # Upper-limit for the endpoint. if batch_size > MAX_INPUTS: msg = f"NVIDIA Cloud Functions currently support a maximum batch size of {MAX_INPUTS}." @@ -83,7 +80,7 @@ def __init__( self.embedding_separator = embedding_separator self.client = NvidiaCloudFunctionsClient( - api_key=resolved_api_key, + api_key=api_key, headers={ "Content-Type": "application/json", "Accept": "application/json", @@ -193,7 +190,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + elif not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "NvidiaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the NvidiaTextEmbedder." diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 43d62ed92..a2636b4b8 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -53,15 +53,12 @@ def __init__( if isinstance(model, str): model = NvidiaEmbeddingModel.from_str(model) - resolved_api_key = api_key.resolve_value() - assert resolved_api_key is not None - self.api_key = api_key self.model = model self.prefix = prefix self.suffix = suffix self.client = NvidiaCloudFunctionsClient( - api_key=resolved_api_key, + api_key=api_key, headers={ "Content-Type": "application/json", "Accept": "application/json", @@ -128,7 +125,7 @@ def run(self, text: str): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - if not isinstance(text, str): + elif not isinstance(text, str): msg = ( "NvidiaTextEmbedder expects a string as an input." "In case you want to embed a list of Documents, please use the NvidiaDocumentEmbedder." diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py index 5227e8c45..e582b09ba 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py @@ -3,6 +3,7 @@ from typing import Dict, Optional import requests +from haystack.utils import Secret FUNCTIONS_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/functions" INVOKE_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions" @@ -19,13 +20,17 @@ class AvailableNvidiaCloudFunctions: class NvidiaCloudFunctionsClient: - def __init__(self, *, api_key: str, headers: Dict[str, str], timeout: int = 60): - self.api_key = api_key + def __init__(self, *, api_key: Secret, headers: Dict[str, str], timeout: int = 60): + self.api_key = api_key.resolve_value() + if self.api_key is None: + msg = "Nvidia Cloud Functions API key is not set." + raise ValueError(msg) + self.fetch_url_format = STATUS_ENDPOINT self.headers = copy.deepcopy(headers) self.headers.update( { - "Authorization": f"Bearer {api_key}", + "Authorization": f"Bearer {self.api_key}", } ) self.timeout = timeout