diff --git a/app.py b/app.py index ddff3ea1..907313f1 100644 --- a/app.py +++ b/app.py @@ -22,6 +22,7 @@ CreateDirectoryRequestError, ServerError, PipelineNotFoundError, + ConnectionStringError ) load_dotenv() @@ -65,12 +66,11 @@ async def delete_directory(): """ try: data = await request.get_json() - connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"] container_name = data["container_name"] folder_name = data["folder_name"] if container_name and folder_name: container_client = await azure_storage_api.mount_container( - connection_string, container_name, create_container=False + app.config["BLOB_CLIENT"], container_name, create_container=False ) if container_client: folder_uuid = await azure_storage_api.get_folder_uuid( @@ -101,11 +101,10 @@ async def list_directories(): """ try: data = await request.get_json() - connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"] container_name = data["container_name"] if container_name: container_client = await azure_storage_api.mount_container( - connection_string, container_name, create_container=True + app.config["BLOB_CLIENT"], container_name, create_container=True ) response = await azure_storage_api.get_directories(container_client) return jsonify(response), 200 @@ -124,12 +123,11 @@ async def create_directory(): """ try: data = await request.get_json() - connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"] container_name = data["container_name"] folder_name = data["folder_name"] if container_name and folder_name: container_client = await azure_storage_api.mount_container( - connection_string, container_name, create_container=False + app.config["BLOB_CLIENT"], container_name, create_container=False ) response = await azure_storage_api.create_folder( container_client, folder_name @@ -164,6 +162,7 @@ async def inference_request(): image_base64 = data["image"] pipelines_endpoints = CACHE.get("pipelines") + blob_service_client = app.config.get("BLOB_CLIENT") if not (folder_name and container_name and imageDims and image_base64): raise InferenceRequestError( @@ -181,7 +180,7 @@ async def inference_request(): image_bytes = base64.b64decode(encoded_data) container_client = await azure_storage_api.mount_container( - connection_string, container_name, create_container=True + blob_service_client, container_name, create_container=True ) hash_value = await azure_storage_api.generate_hash(image_bytes) blob_name = await azure_storage_api.upload_image( @@ -220,14 +219,10 @@ async def inference_request(): print(f"Took: {'{:10.4f}'.format(time.perf_counter() - seconds)} seconds") return jsonify(processed_result_json), 200 - except InferenceRequestError as error: + except (KeyError, InferenceRequestError) as error: print(error) return jsonify(["InferenceRequestError: " + error.args[0]]), 400 - except Exception as error: - print(error) - return jsonify(["Unexpected error occured"]), 500 - @app.get("/seed-data/") async def get_seed_data(seed_name): @@ -318,12 +313,12 @@ async def get_pipelines(): - list: A list of dictionaries representing the pipelines. """ try: - # TO DO instantiate Blob Service Client - result_json = await azure_storage_api.get_pipeline_info(connection_string, PIPELINE_BLOB_NAME, PIPELINE_VERSION) + app.config["BLOB_CLIENT"] = await azure_storage_api.get_blob_client(connection_string) + result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], PIPELINE_BLOB_NAME, PIPELINE_VERSION) cipher_suite = Fernet(FERNET_KEY) - except PipelineNotFoundError as error: + except (ConnectionStringError, PipelineNotFoundError) as error: print(error) - raise ServerError("Server errror: Pipelines were not found") from error + raise ServerError("server errror: could not retrieve the pipelines") from error models = () @@ -331,8 +326,8 @@ async def get_pipelines(): m = Model( request_function.get(model.get("api_call_function")), model.get("model_name"), - # To protect sensible data (API key and model endpoint), we crypt them when - # they are pushed into the blob storage. Once we retrieve the data here in the + # To protect sensible data (API key and model endpoint), we encrypt it when + # it's pushed into the blob storage. Once we retrieve the data here in the # backend, we need to decrypt the byte format to recover the original # data. cipher_suite.decrypt(model.get("endpoint").encode()).decode(), diff --git a/azure_storage/azure_storage_api.py b/azure_storage/azure_storage_api.py index 7d2d599d..1a4895d9 100644 --- a/azure_storage/azure_storage_api.py +++ b/azure_storage/azure_storage_api.py @@ -39,37 +39,49 @@ async def generate_hash(image): except GenerateHashError as error: print(error) +async def get_blob_client(connection_string: str): + """ + given a connection string, returns the blob client object + """ + try: + blob_service_client = BlobServiceClient.from_connection_string( + connection_string + ) + if blob_service_client is None: + raise ValueError(f"the given connection string is invalid: {connection_string}") + return blob_service_client + + except ValueError as error: + print(error) + raise ConnectionStringError(error.args[0]) from error -async def mount_container(connection_string, container_uuid, create_container=True): + +async def mount_container( + blob_service_client: BlobServiceClient, + container_uuid: str, + create_container: bool =True): """ - given a connection string and a container name, mounts the container and + given a connection string and a container uuid, mounts the container and returns the container client as an object that can be used in other functions. if a specified container doesnt exist, it creates one with the provided uuid, if create_container is True """ try: - blob_service_client = BlobServiceClient.from_connection_string( - connection_string - ) - if blob_service_client: - container_name = "user-{}".format(container_uuid) - container_client = blob_service_client.get_container_client(container_name) - if container_client.exists(): + container_name = "user-{}".format(container_uuid) + container_client = blob_service_client.get_container_client(container_name) + if container_client.exists(): + return container_client + elif create_container and not container_client.exists(): + container_client = blob_service_client.create_container(container_name) + # create general directory for new user container + response = await create_folder(container_client, "General") + if response: return container_client - elif create_container and not container_client.exists(): - container_client = blob_service_client.create_container(container_name) - # create general directory for new user container - response = await create_folder(container_client, "General") - if response: - return container_client - else: - return False - else: - raise ConnectionStringError("Invalid connection string") - + else: + raise MountContainerError(f"could not create general directory: {container_name}") except MountContainerError as error: print(error) - return False + raise async def get_blob(container_client: ContainerClient, blob_name: str): @@ -237,7 +249,7 @@ async def get_directories(container_client): return [] async def get_pipeline_info( - connection_string: str, + blob_service_client: BlobServiceClient, pipeline_container_name: str, pipeline_version: str ) -> json: @@ -246,10 +258,10 @@ async def get_pipeline_info( provided parameters. Args: - connection_string (str): The connection string for the Azure Blob - Storage. pipeline_container_name (str): The name of the container where - the pipeline files are stored. pipeline_version (str): The version of - the pipeline to retrieve. + blob_service_client (BlobServiceClient): The BlobServiceClient object + pipeline_container_name (str): The name of the container where + the pipeline files are stored. + pipeline_version (str): The version of the pipeline to retrieve. Returns: json: The pipeline information in JSON format. @@ -259,13 +271,6 @@ async def get_pipeline_info( found. """ try: - blob_service_client = BlobServiceClient.from_connection_string( - connection_string - ) - - if blob_service_client is None: - raise PipelineNotFoundError("No Blob Service Client found with the connection string.") - container_client = blob_service_client.get_container_client( pipeline_container_name ) @@ -274,5 +279,5 @@ async def get_pipeline_info( pipeline = json.loads(blob) return pipeline - except (ValueError, GetBlobError, PipelineNotFoundError) as error: + except GetBlobError as error: raise PipelineNotFoundError(f"This version {pipeline_version} was not found") from error diff --git a/model/inference.py b/model/inference.py index 3d1640d2..434c998e 100644 --- a/model/inference.py +++ b/model/inference.py @@ -1,6 +1,6 @@ """ -This file contain the generic inference function that process the data at the end +This file contains the generic inference function that processes the data at the end of a given pipeline. """ diff --git a/model/seed_detector.py b/model/seed_detector.py index d1d87417..4229ca3f 100644 --- a/model/seed_detector.py +++ b/model/seed_detector.py @@ -1,5 +1,5 @@ """ -This file contain the function to request the inference and process the data from +This file contains the function that requests the inference and processes the data from the seed detector model. """ @@ -59,7 +59,7 @@ def process_image_slicing(image_bytes: bytes, result_json: dict) -> list: async def request_inference_from_seed_detector(model: namedtuple, previous_result: str): """ - Requests inference from the seed detector model using the provided previous result. + Requests inference from the seed detector model using the previously provided result. Args: model (namedtuple): The seed detector model. diff --git a/model/six_seeds.py b/model/six_seeds.py index c6436b6b..38e0e755 100644 --- a/model/six_seeds.py +++ b/model/six_seeds.py @@ -1,5 +1,5 @@ """ -This file contain the function to request the inference and process the data from +This file contains the function that requests the inference and processes the data from the nachet-6seeds model. """ @@ -48,4 +48,4 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul return result_object except HTTPError as e: - raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None \ No newline at end of file + raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None diff --git a/model/swin.py b/model/swin.py index 07614c0b..78630d48 100644 --- a/model/swin.py +++ b/model/swin.py @@ -1,6 +1,6 @@ """ -This file contain the function to request the inference and process the data from -the nachet-6seeds model. +This file contains the function that requests the inference and processes the data from +the swin model. """ import json @@ -59,4 +59,4 @@ async def request_inference_from_swin(model: namedtuple, previous_result: list[b return process_swin_result(previous_result.get("result_json"), results) except HTTPError as e: - raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None \ No newline at end of file + raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None diff --git a/tests/test_azure_storage_api.py b/tests/test_azure_storage_api.py index c2c1d894..a42d01d1 100644 --- a/tests/test_azure_storage_api.py +++ b/tests/test_azure_storage_api.py @@ -5,16 +5,39 @@ from azure_storage.azure_storage_api import ( mount_container, get_blob, - get_pipeline_info + get_pipeline_info, + get_blob_client ) from azure.core.exceptions import ResourceNotFoundError from custom_exceptions import ( GetBlobError, - PipelineNotFoundError + PipelineNotFoundError, + ConnectionStringError ) +class TestGetBlobServiceClient(unittest.TestCase): + @patch("azure.storage.blob.BlobServiceClient.from_connection_string") + def test_get_blob_service_successful(self, MockFromConnectionString): + mock_blob_service_client = MockFromConnectionString.return_value + + result = asyncio.run( + get_blob_client("connection_string") + ) + + print(result == mock_blob_service_client) + + self.assertEqual(result, mock_blob_service_client) + + @patch("azure.storage.blob.BlobServiceClient.from_connection_string") + def test_get_blob_service_unsuccessful(self, MockFromConnectionString): + MockFromConnectionString.return_value = None + + with self.assertRaises(ConnectionStringError) as context: + asyncio.run(get_blob_client("invalid_connection_string")) + + print(context.exception == "the given connection string is invalid: invalid_connection_string") class TestMountContainerFunction(unittest.TestCase): @patch("azure.storage.blob.BlobServiceClient.from_connection_string") @@ -27,14 +50,10 @@ def test_mount_existing_container(self, MockFromConnectionString): mock_blob_service_client.get_container_client.return_value = ( mock_container_client ) - - connection_string = "test_connection_string" container_name = "testcontainer" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete( - mount_container(connection_string, container_name) + result = asyncio.run( + mount_container(mock_blob_service_client, container_name) ) print(result == mock_container_client) @@ -62,14 +81,11 @@ def test_mount_nonexisting_container_create(self, MockFromConnectionString): mock_new_container_client ) - connection_string = "test_connection_string" container_name = "testcontainer" expected_container_name = "user-{}".format(container_name) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete( - mount_container(connection_string, container_name, create_container=True) + result = asyncio.run( + mount_container(mock_blob_service_client, container_name, create_container=True) ) mock_blob_service_client.create_container.assert_called_once_with( @@ -88,13 +104,10 @@ def test_mount_nonexisting_container_no_create(self, MockFromConnectionString): mock_container_client ) - connection_string = "test_connection_string" container_name = "testcontainer" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete( - mount_container(connection_string, container_name, create_container=False) + result = asyncio.run( + mount_container(mock_blob_service_client, container_name, create_container=False) ) mock_blob_service_client.create_container.assert_not_called() @@ -103,8 +116,7 @@ def test_mount_nonexisting_container_no_create(self, MockFromConnectionString): class TestGetBlob(unittest.TestCase): - @patch("azure.storage.blob.BlobServiceClient.from_connection_string") - def test_get_blob_successful(self, MockFromConnectionString): + def test_get_blob_successful(self): mock_blob_name = "test_blob" mock_blob_content = b"blob content" @@ -117,9 +129,7 @@ def test_get_blob_successful(self, MockFromConnectionString): mock_container_client = Mock() mock_container_client.get_blob_client.return_value = mock_blob_client - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete( + result = asyncio.run( get_blob(mock_container_client, mock_blob_name) ) @@ -166,26 +176,12 @@ def test_get_pipeline_info_successful(self, MockFromConnectionString,): mock_container_client ) - result = asyncio.run(get_pipeline_info("test_connection_string", "test_blob", "v1")) + result = asyncio.run(get_pipeline_info(mock_blob_service_client, "test_blob", "v1")) print(result == json.loads(mock_blob_content)) self.assertEqual(result, json.loads(mock_blob_content)) - @patch("azure.storage.blob.BlobServiceClient.from_connection_string") - def test_get_pipeline_info_wrong_connection_string(self, MockFromConnectionString): - - pipeline_version = "v1" - - MockFromConnectionString.side_effect = ( - ValueError("connection string is empty or not conform") - ) - - with self.assertRaises(PipelineNotFoundError) as context: - asyncio.run(get_pipeline_info("wrong_connection_string", "test_blob", pipeline_version)) - - print(str(context.exception) == f"This version {pipeline_version} was not found") - @patch("azure.storage.blob.BlobServiceClient.from_connection_string") def test_get_pipeline_info_unsuccessful(self, MockFromConnectionString): pipeline_version = "v1" @@ -202,7 +198,7 @@ def test_get_pipeline_info_unsuccessful(self, MockFromConnectionString): ) with self.assertRaises(PipelineNotFoundError) as context: - asyncio.run(get_pipeline_info("test_connection_string", "test_blob", pipeline_version)) + asyncio.run(get_pipeline_info(mock_blob_service_client, "test_blob", pipeline_version)) print(str(context.exception) == f"This version {pipeline_version} was not found") diff --git a/tests/test_inference_request.py b/tests/test_inference_request.py index 3b27f39a..fc2bf7e2 100644 --- a/tests/test_inference_request.py +++ b/tests/test_inference_request.py @@ -35,9 +35,8 @@ def tearDown(self) -> None: self.image_src = None self.test = None - @patch("azure.storage.blob.BlobServiceClient.from_connection_string") - def test_inference_request_successful(self, MockFromConnectionString): - + @patch("azure_storage.azure_storage_api.mount_container") + def test_inference_request_successful(self, mock_container): # Mock azure client services mock_blob = Mock() mock_blob.readall.return_value = bytes(self.image_src, encoding="utf-8") @@ -51,11 +50,7 @@ def test_inference_request_successful(self, MockFromConnectionString): mock_container_client.get_blob_client.return_value = mock_blob_client mock_container_client.exists.return_value = True - mock_blob_service_client = MockFromConnectionString.return_value - mock_blob_service_client.get_container_client.return_value = ( - mock_container_client - ) - + mock_container.return_value = mock_container_client # Build expected response keys responses = set() expected_keys = { @@ -96,8 +91,8 @@ def test_inference_request_successful(self, MockFromConnectionString): print(expected_keys == responses) self.assertEqual(responses, expected_keys) - @patch("azure.storage.blob.BlobServiceClient.from_connection_string") - def test_inference_request_unsuccessfull(self, MockFromConnectionString): + @patch("azure_storage.azure_storage_api.mount_container") + def test_inference_request_unsuccessfull(self, mock_container): # Mock azure client services mock_blob = Mock() mock_blob.readall.return_value = b"" @@ -111,10 +106,7 @@ def test_inference_request_unsuccessfull(self, MockFromConnectionString): mock_container_client.get_blob_client.return_value = mock_blob_client mock_container_client.exists.return_value = True - mock_blob_service_client = MockFromConnectionString.return_value - mock_blob_service_client.get_container_client.return_value = ( - mock_container_client - ) + mock_container.return_value = mock_container_client # Build expected response expected = 400