
Commit

fixes #51: change from loop to asyncio.run
fixes #51: Correct trailing whitespace and EOF.
MaxenceGui committed Apr 8, 2024
1 parent e340108 commit 471bcc5
Showing 8 changed files with 41 additions and 44 deletions.
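The functional change lands in tests/test_inference_request.py: instead of acquiring an event loop in setUp() and driving every request through loop.run_until_complete(), each test now calls asyncio.run(). A minimal before/after sketch of that pattern (illustrative code, not taken from this repository):

```python
import asyncio

async def fetch() -> str:
    # Stand-in for an awaited call such as the app test client's get/post.
    await asyncio.sleep(0)
    return "done"

# Before: acquire a loop once and reuse it. asyncio.get_event_loop() is
# deprecated when no loop is running, and the loop stays open between calls.
loop = asyncio.get_event_loop()
result = loop.run_until_complete(fetch())

# After: asyncio.run() creates a fresh loop, runs the coroutine to
# completion, and closes the loop again -- no shared loop state.
result = asyncio.run(fetch())
```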
29 changes: 14 additions & 15 deletions app.py
@@ -148,8 +148,8 @@ async def inference_request():
"""
Performs inference on an image, and returns the results.
The image and inference results are uploaded to a folder in the user's container.
"""
"""

seconds = time.perf_counter() # transform into logging
try:
print("Entering inference request") # Transform into logging
@@ -164,16 +164,16 @@ async def inference_request():

if not (folder_name and container_name and imageDims and image_base64):
return jsonify(["missing request arguments"]), 400

if not pipelines_endpoints.get(pipeline_name):
return jsonify([f"Model {pipeline_name} not found"]), 400

header, encoded_data = image_base64.split(",", 1)

# Validate image header
if not header.startswith("data:image/"):
return jsonify(["Invalid image header"]), 400

image_bytes = base64.b64decode(encoded_data)
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=True
@@ -205,7 +205,7 @@ async def inference_request():
except urllib.error.HTTPError as error:
print(error)
return jsonify(["endpoint cannot be reached" + str(error.code)]), 400

# upload the inference results to the user's container as async task
result_json_string = json.dumps(processed_result_json)

@@ -218,12 +218,12 @@
)
# return the inference results to the client
print(f"Took: {'{:10.4f}'.format(time.perf_counter() - seconds)} seconds")
-return jsonify(processed_result_json), 200
+return jsonify(processed_result_json), 200

except InferenceRequestError as error:
print(error)
return jsonify(["InferenceRequestError: " + str(error)]), 400

except Exception as error:
print(error)
return jsonify(["Unexpected error occured"]), 500
@@ -234,11 +234,11 @@ async def get_seed_data(seed_name):
"""
Returns JSON containing requested seed data
"""
-if seed_name in CACHE['seeds']:
+if seed_name in CACHE['seeds']:
return jsonify(CACHE['seeds'][seed_name]), 200
else:
return jsonify(f"No information found for {seed_name}."), 400


@app.get("/reload-seed-data")
async def reload_seed_data():
@@ -257,7 +257,7 @@ async def get_model_endpoints_metadata():
"""
Returns JSON containing the deployed endpoints' metadata
"""
-if CACHE['endpoints']:
+if CACHE['endpoints']:
return jsonify(CACHE['endpoints']), 200
else:
return jsonify("Error retrieving model endpoints metadata.", 400)
@@ -329,7 +329,7 @@ async def get_pipelines(mock:bool = False):
cipher_suite = Fernet(FERNET_KEY)
# Get all the api_call function and map them in a dictionary
api_call_function = {func.split("from_")[1]: getattr(model_module, func) for func in dir(model_module) if "inference" in func.split("_")}
-# Get all the inference functions and map them in a dictionary
+# Get all the inference functions and map them in a dictionary
inference_functions = {func: getattr(inference, func) for func in dir(inference) if "process" in func.split("_")}

models = ()
@@ -363,11 +363,11 @@ async def before_serving():
# Check: do environment variables exist?
if connection_string is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if FERNET_KEY is None:
raise ServerError("Missing environment variable: FERNET_KEY")

-# Check: are environment variables correct?
+# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, connection_string)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

@@ -385,4 +385,3 @@

if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=8080)
-
10 changes: 5 additions & 5 deletions azure_storage/azure_storage_api.py
@@ -265,7 +265,7 @@ async def get_pipeline_info(
container_client = blob_service_client.get_container_client(
pipeline_container_name
)

blob_list = container_client.list_blobs()
for blob in blob_list:
if blob.name.split(".")[-1] != "json":
@@ -280,19 +280,19 @@ async def get_pipeline_info(
raise PipelineNotFoundError(
"This version of the pipeline was not found."
)

except PipelineNotFoundError as error:
print(error)
return False

except FolderListError as error:
print(error)
return False

def insert_new_version_pipeline(
pipelines_json: dict,
connection_string: str,
-pipleine_container_name: str
+pipeline_container_name: str
) -> bool:
"""
Inserts a new version of a pipeline JSON into an Azure Blob Storage container.
@@ -312,7 +312,7 @@ def insert_new_version_pipeline(

if blob_service_client:
container_client = blob_service_client.get_container_client(
-pipleine_container_name
+pipeline_container_name
)

json_name = "{}/{}.json".format("pipelines", pipelines_json.get("version"))
10 changes: 5 additions & 5 deletions docs/nachet-inference-documentation.md
@@ -79,7 +79,7 @@ sequenceDiagram
Backend-)Frontend: error 500 Failed to retrieve data from the repository
end
Note over Backend,Blob storage: end of initialisation
Client->>+Frontend: applicationStart()
Frontend-)Backend: HTTP POST req.
Backend-)Backend: get_model_endpoints_metadata()
@@ -103,7 +103,7 @@ sequenceDiagram
Backend-)Backend: mount_container(connection_string(Environnement Variable, container_name))
Backend-)+Blob storage: HTTP POST req.
Blob storage--)-Backend: container_client
Backend-)Backend: Generate Hash(image_bytes)
Backend-)Backend: upload_image(container_client, folder_name, image_bytes, hash_value)
@@ -260,7 +260,7 @@ async def get_pipeline(mock:bool = False):
cipher_suite = Fernet(FERNET_KEY)
# Get all the api_call function and map them in a dictionary
api_call_function = {func.split("from_")[1]: getattr(model_module, func) for func in dir(model_module) if "inference" in func.split("_")}
-# Get all the inference functions and map them in a dictionary
+# Get all the inference functions and map them in a dictionary
inference_functions = {func: getattr(inference, func) for func in dir(inference) if "process" in func.split("_")}

models = ()
@@ -275,12 +275,12 @@ async def get_pipeline(mock:bool = False):
model.get("deployment_platform")
)
models += (m,)

# Build the pipeline to call the models in order in the inference request
for pipeline in result_json.get("pipelines"):
CACHE["pipelines"][pipeline.get("pipeline_name")] = tuple([m for m in models if m.name in pipeline.get("models")])

return result_json.get("pipelines")
return result_json.get("pipelines")

```

4 changes: 2 additions & 2 deletions model_inference/inference.py
@@ -45,7 +45,7 @@ async def process_image_slicing(image_bytes: bytes, result_json: dict) -> list:
img.save(buffered, format)

cropped_images[i] = base64.b64encode(buffered.getvalue()) #.decode("utf8")

return cropped_images

async def process_swin_result(img_box:dict, results: dict) -> list:
@@ -61,7 +61,7 @@ async def process_swin_result(img_box:dict, results: dict) -> list:
img_box[0]['boxes'][i]['label'] = result[0].get("label")
img_box[0]['boxes'][i]['score'] = result[0].get("score")
img_box[0]['boxes'][i]["topN"] = [d for d in result]

return img_box

async def process_inference_results(data, imageDims):
6 changes: 3 additions & 3 deletions model_inference/model_module.py
@@ -51,7 +51,7 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
Returns:
dict: A dictionary containing the result JSON and the images generated from the inference.
Raises:
InferenceRequestError: If an error occurs while processing the request.
"""
@@ -67,7 +67,7 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
}
except Exception as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}")


async def request_inference_from_nachet_6seeds(model: namedtuple, previous_result: str):
"""
@@ -93,7 +93,7 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul

except Exception as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}")

async def request_inference_from_test(model: namedtuple, previous_result: str):
"""
Requests a test case inference.
4 changes: 2 additions & 2 deletions model_request/model_request.py
@@ -24,14 +24,14 @@ async def request_factory(img_bytes: str | bytes, model: namedtuple) -> Request:
if deployment_platform in supported_deployment_platform:
headers[model.deployment_platform[deployment_platform]] = model.name

-if isinstance(img_bytes, str):
+if isinstance(img_bytes, str):
data = {
"input_data": {
"columns": ["image"],
"index": [0],
"data": [img_bytes],
}
-}
+}
body = str.encode(json.dumps(data))
elif isinstance(img_bytes, bytes):
body = img_bytes
2 changes: 1 addition & 1 deletion tests/test_azure_storage_api.py
@@ -170,7 +170,7 @@ def test_get_pipeline_info_successful(self, MockFromConnectionString,):
mock_blob_service_client.get_container_client.return_value = (
mock_container_client
)

connection_string = "test_connection_string"
mock_blob_name = "test_blob"
mock_version = "v1"
20 changes: 9 additions & 11 deletions tests/test_inference_request.py
@@ -13,9 +13,8 @@ def setUp(self) -> None:
Set up the test environment before running each test case.
"""
# Start the test pipeline
-self.loop = asyncio.get_event_loop()
self.test = app.test_client()
-response = self.loop.run_until_complete(
+response = asyncio.run(
self.test.get("/test")
)
self.pipeline = json.loads(asyncio.run(response.get_data()))[0]
@@ -38,7 +38,7 @@ def tearDown(self) -> None:

@patch("azure.storage.blob.BlobServiceClient.from_connection_string")
def test_inference_request_successful(self, MockFromConnectionString):

# Mock azure client services
mock_blob = Mock()
mock_blob.readall.return_value = bytes(self.image_src, encoding="utf-8")
@@ -73,7 +72,7 @@ def test_inference_request_successful(self, MockFromConnectionString):
}

# Test the answers from inference_request
-response = self.loop.run_until_complete(
+response = asyncio.run(
self.test.post(
'/inf',
headers={
@@ -121,7 +120,7 @@ def test_inference_request_unsuccessfull(self, MockFromConnectionString):
expected = 500

# Test the answers from inference_request
-response = self.loop.run_until_complete(
+response = asyncio.run(
self.test.post(
'/inf',
headers={
@@ -157,8 +156,8 @@ def test_inference_request_missing_argument(self):

for k, v in data.items():
if k != "model_name":
data[k] = ""
response = self.loop.run_until_complete(
data[k] = ""
response = asyncio.run(
self.test.post(
'/inf',
headers={
@@ -177,7 +176,7 @@ def test_inference_request_missing_argument(self):

if len(responses) > 1:
raise ValueError(f"Different errors messages were given; expected only 'missing request arguments', {responses}")

print(expected == result_json[0])
print(response.status_code == 400)
self.assertEqual(result_json[0], expected)
@@ -188,7 +187,7 @@ def test_inference_request_wrong_pipeline_name(self):
expected = ("Model wrong_pipeline_name not found")

# Test the answers from inference_request
-response = self.loop.run_until_complete(
+response = asyncio.run(
self.test.post(
'/inf',
headers={
@@ -217,7 +216,7 @@ def test_inference_request_wrong_header(self):
expected = ("Invalid image header")

# Test the answers from inference_request
-response = self.loop.run_until_complete(
+response = asyncio.run(
self.test.post(
'/inf',
headers={
@@ -243,4 +242,3 @@

if __name__ == '__main__':
unittest.main()
-
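Apart from the asyncio.run migration and the pipeline_container_name spelling fix, every hunk above is mechanical cleanup: trailing whitespace removed and end-of-file blank lines normalized, per the second line of the commit message. A hypothetical standalone checker for those two issues (illustrative only, not part of this commit) could look like:

```python
import sys
from pathlib import Path

def check(path: Path) -> list[str]:
    """Flag trailing whitespace and end-of-file problems in one file."""
    problems = []
    data = path.read_bytes()
    for lineno, line in enumerate(data.splitlines(), start=1):
        if line != line.rstrip():
            problems.append(f"{path}:{lineno}: trailing whitespace")
    if data and not data.endswith(b"\n"):
        problems.append(f"{path}: no newline at end of file")
    elif data.endswith(b"\n\n"):
        problems.append(f"{path}: extra blank line at end of file")
    return problems

if __name__ == "__main__":
    issues = [msg for name in sys.argv[1:] for msg in check(Path(name))]
    if issues:
        print("\n".join(issues))
    sys.exit(1 if issues else 0)
```

Run as, say, python check_whitespace.py app.py tests/*.py: it lists offending lines and exits non-zero, the same contract most lint hooks follow.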