diff --git a/app.py b/app.py
index 973cc156..0da9a6a4 100644
--- a/app.py
+++ b/app.py
@@ -148,8 +148,8 @@ async def inference_request():
     """
     Performs inference on an image, and returns the results.
    The image and inference results are uploaded to a folder in the user's container.
-    """ 
-    
+    """
+
     seconds = time.perf_counter() # transform into logging
     try:
         print("Entering inference request") # Transform into logging
@@ -164,16 +164,16 @@ async def inference_request():
         if not (folder_name and container_name and imageDims and image_base64):
             return jsonify(["missing request arguments"]), 400
-        
+
         if not pipelines_endpoints.get(pipeline_name):
             return jsonify([f"Model {pipeline_name} not found"]), 400
-        
+
         header, encoded_data = image_base64.split(",", 1)

         # Validate image header
         if not header.startswith("data:image/"):
             return jsonify(["Invalid image header"]), 400
-        
+
         image_bytes = base64.b64decode(encoded_data)

         container_client = await azure_storage_api.mount_container(
             connection_string, container_name, create_container=True
         )
@@ -205,7 +205,7 @@ async def inference_request():
         except urllib.error.HTTPError as error:
             print(error)
             return jsonify(["endpoint cannot be reached" + str(error.code)]), 400
-        
+
         # upload the inference results to the user's container as async task
         result_json_string = json.dumps(processed_result_json)
@@ -218,12 +218,12 @@ async def inference_request():
         )
         # return the inference results to the client
         print(f"Took: {'{:10.4f}'.format(time.perf_counter() - seconds)} seconds")
-        return jsonify(processed_result_json), 200 
+        return jsonify(processed_result_json), 200

     except InferenceRequestError as error:
         print(error)
         return jsonify(["InferenceRequestError: " + str(error)]), 400
-    
+
     except Exception as error:
         print(error)
         return jsonify(["Unexpected error occured"]), 500
@@ -234,11 +234,11 @@ async def get_seed_data(seed_name):
     """
     Returns JSON containing requested seed data
     """
-    if seed_name in CACHE['seeds']: 
+    if seed_name in CACHE['seeds']:
         return jsonify(CACHE['seeds'][seed_name]), 200
     else:
         return jsonify(f"No information found for {seed_name}."), 400
-    
+

 @app.get("/reload-seed-data")
 async def reload_seed_data():
@@ -257,7 +257,7 @@ async def get_model_endpoints_metadata():
     """
     Returns JSON containing the deployed endpoints' metadata
     """
-    if CACHE['endpoints']: 
+    if CACHE['endpoints']:
         return jsonify(CACHE['endpoints']), 200
     else:
         return jsonify("Error retrieving model endpoints metadata.", 400)
@@ -329,7 +329,7 @@ async def get_pipelines(mock:bool = False):
     cipher_suite = Fernet(FERNET_KEY)
     # Get all the api_call function and map them in a dictionary
     api_call_function = {func.split("from_")[1]: getattr(model_module, func) for func in dir(model_module) if "inference" in func.split("_")}
-    # Get all the inference functions and map them in a dictionary 
+    # Get all the inference functions and map them in a dictionary
     inference_functions = {func: getattr(inference, func) for func in dir(inference) if "process" in func.split("_")}
     models = ()
@@ -363,11 +363,11 @@ async def before_serving():
     # Check: do environment variables exist?
     if connection_string is None:
         raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")
-    
+
     if FERNET_KEY is None:
         raise ServerError("Missing environment variable: FERNET_KEY")

-    # Check: are environment variables correct? 
+    # Check: are environment variables correct?
     if not bool(re.match(connection_string_regex, connection_string)):
         raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")
@@ -385,4 +385,3 @@ async def before_serving():

 if __name__ == "__main__":
     app.run(debug=True, host="0.0.0.0", port=8080)
-    
\ No newline at end of file
diff --git a/azure_storage/azure_storage_api.py b/azure_storage/azure_storage_api.py
index d6dfa835..fa8b449c 100644
--- a/azure_storage/azure_storage_api.py
+++ b/azure_storage/azure_storage_api.py
@@ -265,7 +265,7 @@ async def get_pipeline_info(
         container_client = blob_service_client.get_container_client(
             pipeline_container_name
         )
-        
+
         blob_list = container_client.list_blobs()
         for blob in blob_list:
             if blob.name.split(".")[-1] != "json":
@@ -280,11 +280,11 @@ async def get_pipeline_info(
             raise PipelineNotFoundError(
                 "This version of the pipeline was not found."
             )
-        
+
     except PipelineNotFoundError as error:
         print(error)
         return False
-    
+
     except FolderListError as error:
         print(error)
         return False
@@ -292,7 +292,7 @@ def insert_new_version_pipeline(
         pipelines_json: dict,
         connection_string: str,
-        pipleine_container_name: str
+        pipeline_container_name: str
     ) -> bool:
     """
     Inserts a new version of a pipeline JSON into an Azure Blob Storage container.
@@ -312,7 +312,7 @@ def insert_new_version_pipeline(

         if blob_service_client:
             container_client = blob_service_client.get_container_client(
-                pipleine_container_name
+                pipeline_container_name
             )
             json_name = "{}/{}.json".format("pipelines", pipelines_json.get("version"))
diff --git a/docs/nachet-inference-documentation.md b/docs/nachet-inference-documentation.md
index e4e4519e..7c6111dd 100644
--- a/docs/nachet-inference-documentation.md
+++ b/docs/nachet-inference-documentation.md
@@ -79,7 +79,7 @@ sequenceDiagram
     Backend-)Frontend: error 500 Failed to retrieve data from the repository
     end
     Note over Backend,Blob storage: end of initialisation
-    
+
     Client->>+Frontend: applicationStart()
     Frontend-)Backend: HTTP POST req.
     Backend-)Backend: get_model_endpoints_metadata()
@@ -103,7 +103,7 @@ sequenceDiagram
     Backend-)Backend: mount_container(connection_string(Environnement Variable, container_name))
     Backend-)+Blob storage: HTTP POST req.
     Blob storage--)-Backend: container_client
-    
+
     Backend-)Backend: Generate Hash(image_bytes)
     Backend-)Backend: upload_image(container_client, folder_name, image_bytes, hash_value)
@@ -260,7 +260,7 @@ async def get_pipeline(mock:bool = False):
     cipher_suite = Fernet(FERNET_KEY)
     # Get all the api_call function and map them in a dictionary
     api_call_function = {func.split("from_")[1]: getattr(model_module, func) for func in dir(model_module) if "inference" in func.split("_")}
-    # Get all the inference functions and map them in a dictionary 
+    # Get all the inference functions and map them in a dictionary
     inference_functions = {func: getattr(inference, func) for func in dir(inference) if "process" in func.split("_")}
     models = ()
@@ -275,12 +275,12 @@ async def get_pipeline(mock:bool = False):
             model.get("deployment_platform")
         )
         models += (m,)
-    
+
     # Build the pipeline to call the models in order in the inference request
     for pipeline in result_json.get("pipelines"):
         CACHE["pipelines"][pipeline.get("pipeline_name")] = tuple([m for m in models if m.name in pipeline.get("models")])

-    return result_json.get("pipelines") 
+    return result_json.get("pipelines")
 ```
diff --git a/model_inference/inference.py b/model_inference/inference.py
index ba2f6fff..9230ba5c 100644
--- a/model_inference/inference.py
+++ b/model_inference/inference.py
@@ -45,7 +45,7 @@ async def process_image_slicing(image_bytes: bytes, result_json: dict) -> list:
         img.save(buffered, format)

         cropped_images[i] = base64.b64encode(buffered.getvalue()) #.decode("utf8")
-    
+
     return cropped_images

 async def process_swin_result(img_box:dict, results: dict) -> list:
@@ -61,7 +61,7 @@ async def process_swin_result(img_box:dict, results: dict) -> list:
         img_box[0]['boxes'][i]['label'] = result[0].get("label")
         img_box[0]['boxes'][i]['score'] = result[0].get("score")
         img_box[0]['boxes'][i]["topN"] = [d for d in result]
-    
+
     return img_box

 async def process_inference_results(data, imageDims):
diff --git a/model_inference/model_module.py b/model_inference/model_module.py
index 3cd9a0c5..fdb70161 100644
--- a/model_inference/model_module.py
+++ b/model_inference/model_module.py
@@ -51,7 +51,7 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
     Returns:
         dict: A dictionary containing the result JSON and the images generated from the inference.
-    
+
     Raises:
         InferenceRequestError: If an error occurs while processing the request.
     """
@@ -67,7 +67,7 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
         }
     except Exception as e:
         raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}")
-    
+

 async def request_inference_from_nachet_6seeds(model: namedtuple, previous_result: str):
     """
@@ -93,7 +93,7 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul
     except Exception as e:
         raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}")

-    
+
 async def request_inference_from_test(model: namedtuple, previous_result: str):
     """
     Requests a test case inference.
diff --git a/model_request/model_request.py b/model_request/model_request.py
index 6ecd6097..3d62708d 100644
--- a/model_request/model_request.py
+++ b/model_request/model_request.py
@@ -24,14 +24,14 @@ async def request_factory(img_bytes: str | bytes, model: namedtuple) -> Request:
     if deployment_platform in supported_deployment_platform:
         headers[model.deployment_platform[deployment_platform]] = model.name

-    if isinstance(img_bytes, str): 
+    if isinstance(img_bytes, str):
         data = {
             "input_data": {
                 "columns": ["image"],
                 "index": [0],
                 "data": [img_bytes],
             }
-        } 
+        }
         body = str.encode(json.dumps(data))
     elif isinstance(img_bytes, bytes):
         body = img_bytes
diff --git a/tests/test_azure_storage_api.py b/tests/test_azure_storage_api.py
index 58d416db..4f6d7277 100644
--- a/tests/test_azure_storage_api.py
+++ b/tests/test_azure_storage_api.py
@@ -170,7 +170,7 @@ def test_get_pipeline_info_successful(self, MockFromConnectionString,):
         mock_blob_service_client.get_container_client.return_value = (
             mock_container_client
         )
-        
+
         connection_string = "test_connection_string"
         mock_blob_name = "test_blob"
         mock_version = "v1"
diff --git a/tests/test_inference_request.py b/tests/test_inference_request.py
index c3acf12b..06e20d59 100644
--- a/tests/test_inference_request.py
+++ b/tests/test_inference_request.py
@@ -13,9 +13,8 @@ def setUp(self) -> None:
         Set up the test environment before running each test case.
         """
         # Start the test pipeline
-        self.loop = asyncio.get_event_loop()
         self.test = app.test_client()
-        response = self.loop.run_until_complete(
+        response = asyncio.run(
             self.test.get("/test")
         )
         self.pipeline = json.loads(asyncio.run(response.get_data()))[0]
@@ -38,7 +37,7 @@ def tearDown(self) -> None:

     @patch("azure.storage.blob.BlobServiceClient.from_connection_string")
     def test_inference_request_successful(self, MockFromConnectionString):
-        
+
         # Mock azure client services
         mock_blob = Mock()
         mock_blob.readall.return_value = bytes(self.image_src, encoding="utf-8")
@@ -73,7 +72,7 @@ def test_inference_request_successful(self, MockFromConnectionString):
         }

         # Test the answers from inference_request
-        response = self.loop.run_until_complete(
+        response = asyncio.run(
             self.test.post(
                 '/inf',
                 headers={
@@ -121,7 +120,7 @@ def test_inference_request_unsuccessfull(self, MockFromConnectionString):
         expected = 500

         # Test the answers from inference_request
-        response = self.loop.run_until_complete(
+        response = asyncio.run(
             self.test.post(
                 '/inf',
                 headers={
@@ -157,8 +156,8 @@ def test_inference_request_missing_argument(self):

         for k, v in data.items():
             if k != "model_name":
-                data[k] = "" 
-                response = self.loop.run_until_complete(
+                data[k] = ""
+                response = asyncio.run(
                     self.test.post(
                         '/inf',
                         headers={
@@ -177,7 +176,7 @@ def test_inference_request_missing_argument(self):

         if len(responses) > 1:
             raise ValueError(f"Different errors messages were given; expected only 'missing request arguments', {responses}")
-        
+
         print(expected == result_json[0])
         print(response.status_code == 400)
         self.assertEqual(result_json[0], expected)
@@ -188,7 +187,7 @@ def test_inference_request_wrong_pipeline_name(self):
         expected = ("Model wrong_pipeline_name not found")

         # Test the answers from inference_request
-        response = self.loop.run_until_complete(
+        response = asyncio.run(
             self.test.post(
                 '/inf',
                 headers={
@@ -217,7 +216,7 @@ def test_inference_request_wrong_header(self):
         expected = ("Invalid image header")

         # Test the answers from inference_request
-        response = self.loop.run_until_complete(
+        response = asyncio.run(
             self.test.post(
                 '/inf',
                 headers={
@@ -243,4 +242,3 @@ def test_inference_request_wrong_header(self):

 if __name__ == '__main__':
     unittest.main()
-    
\ No newline at end of file