
fixes #51: Add get_blob_client
fixes #51: fix EOF

Issue #51: Correct typo
Maxence Guindon committed Apr 8, 2024
1 parent 3783d9c commit 58b4d44
Showing 8 changed files with 101 additions and 113 deletions.
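The heart of the change: instead of rebuilding a client from the connection string inside every route, app.py now creates one BlobServiceClient at startup and passes it to the storage helpers. A minimal sketch of the new call pattern, used here outside the web app; the container uuid is a made-up placeholder and the environment variable is the one app.py already reads:

```python
import asyncio
import os

from azure_storage import azure_storage_api


async def main():
    # Same environment variable app.py reads for the storage account.
    connection_string = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"]

    # New in this commit: build the BlobServiceClient once and reuse it,
    # instead of reconstructing it from the connection string per request.
    blob_service_client = await azure_storage_api.get_blob_client(connection_string)

    # "1234-abcd" is a placeholder container uuid; real requests supply it
    # as "container_name" in the JSON body.
    container_client = await azure_storage_api.mount_container(
        blob_service_client, "1234-abcd", create_container=True
    )
    if container_client:
        print(container_client.container_name)  # -> "user-1234-abcd"


asyncio.run(main())
```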
31 changes: 13 additions & 18 deletions app.py
@@ -22,6 +22,7 @@
CreateDirectoryRequestError,
ServerError,
PipelineNotFoundError,
ConnectionStringError
)

load_dotenv()
@@ -65,12 +66,11 @@ async def delete_directory():
"""
try:
data = await request.get_json()
connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"]
container_name = data["container_name"]
folder_name = data["folder_name"]
if container_name and folder_name:
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=False
app.config["BLOB_CLIENT"], container_name, create_container=False
)
if container_client:
folder_uuid = await azure_storage_api.get_folder_uuid(
@@ -101,11 +101,10 @@ async def list_directories():
"""
try:
data = await request.get_json()
connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"]
container_name = data["container_name"]
if container_name:
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=True
app.config["BLOB_CLIENT"], container_name, create_container=True
)
response = await azure_storage_api.get_directories(container_client)
return jsonify(response), 200
@@ -124,12 +123,11 @@ async def create_directory():
"""
try:
data = await request.get_json()
connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"]
container_name = data["container_name"]
folder_name = data["folder_name"]
if container_name and folder_name:
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=False
app.config["BLOB_CLIENT"], container_name, create_container=False
)
response = await azure_storage_api.create_folder(
container_client, folder_name
@@ -164,6 +162,7 @@ async def inference_request():
image_base64 = data["image"]

pipelines_endpoints = CACHE.get("pipelines")
blob_service_client = app.config.get("BLOB_CLIENT")

if not (folder_name and container_name and imageDims and image_base64):
raise InferenceRequestError(
@@ -181,7 +180,7 @@
image_bytes = base64.b64decode(encoded_data)

container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=True
blob_service_client, container_name, create_container=True
)
hash_value = await azure_storage_api.generate_hash(image_bytes)
blob_name = await azure_storage_api.upload_image(
@@ -220,14 +219,10 @@ async def inference_request():
print(f"Took: {'{:10.4f}'.format(time.perf_counter() - seconds)} seconds")
return jsonify(processed_result_json), 200

except InferenceRequestError as error:
except (KeyError, InferenceRequestError) as error:
print(error)
return jsonify(["InferenceRequestError: " + error.args[0]]), 400

except Exception as error:
print(error)
return jsonify(["Unexpected error occured"]), 500


@app.get("/seed-data/<seed_name>")
async def get_seed_data(seed_name):
@@ -318,21 +313,21 @@ async def get_pipelines():
- list: A list of dictionaries representing the pipelines.
"""
try:
# TO DO instantiate Blob Service Client
result_json = await azure_storage_api.get_pipeline_info(connection_string, PIPELINE_BLOB_NAME, PIPELINE_VERSION)
app.config["BLOB_CLIENT"] = await azure_storage_api.get_blob_client(connection_string)
result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], PIPELINE_BLOB_NAME, PIPELINE_VERSION)
cipher_suite = Fernet(FERNET_KEY)
except PipelineNotFoundError as error:
except (ConnectionStringError, PipelineNotFoundError) as error:
print(error)
raise ServerError("Server errror: Pipelines were not found") from error
raise ServerError("server errror: could not retrieve the pipelines") from error


models = ()
for model in result_json.get("models"):
m = Model(
request_function.get(model.get("api_call_function")),
model.get("model_name"),
# To protect sensible data (API key and model endpoint), we crypt them when
# they are pushed into the blob storage. Once we retrieve the data here in the
# To protect sensitive data (API key and model endpoint), we encrypt it when
# it's pushed into the blob storage. Once we retrieve the data here in the
# backend, we need to decrypt the byte format to recover the original
# data.
cipher_suite.decrypt(model.get("endpoint").encode()).decode(),
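For context on the decryption call above, here is a minimal sketch of the Fernet round-trip the comment describes. The key is generated on the spot and stands in for the shared FERNET_KEY, and the endpoint URL is made up:

```python
from cryptography.fernet import Fernet

# Stand-in for FERNET_KEY; the tool that pushes the pipeline JSON to blob
# storage and this backend must share the same key.
key = Fernet.generate_key()
cipher_suite = Fernet(key)

# What the publishing side would do before writing the model entry.
token = cipher_suite.encrypt(b"https://example-endpoint.azure.com/score")

# What get_pipelines() does when rebuilding each Model namedtuple:
# decrypt the stored bytes, then decode them back to the original string.
endpoint = cipher_suite.decrypt(token).decode()
assert endpoint == "https://example-endpoint.azure.com/score"
```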
73 changes: 39 additions & 34 deletions azure_storage/azure_storage_api.py
@@ -39,37 +39,49 @@ async def generate_hash(image):
except GenerateHashError as error:
print(error)

async def get_blob_client(connection_string: str):
"""
Given a connection string, returns the corresponding BlobServiceClient object.
"""
try:
blob_service_client = BlobServiceClient.from_connection_string(
connection_string
)
if blob_service_client is None:
raise ValueError(f"the given connection string is invalid: {connection_string}")
return blob_service_client

except ValueError as error:
print(error)
raise ConnectionStringError(error.args[0]) from error

async def mount_container(connection_string, container_uuid, create_container=True):

async def mount_container(
blob_service_client: BlobServiceClient,
container_uuid: str,
create_container: bool = True):
"""
given a connection string and a container name, mounts the container and
given a blob service client and a container uuid, mounts the container and
returns the container client as an object that can be used in other
functions. If the specified container doesn't exist, it creates one with the
provided uuid, if create_container is True
"""
try:
blob_service_client = BlobServiceClient.from_connection_string(
connection_string
)
if blob_service_client:
container_name = "user-{}".format(container_uuid)
container_client = blob_service_client.get_container_client(container_name)
if container_client.exists():
container_name = "user-{}".format(container_uuid)
container_client = blob_service_client.get_container_client(container_name)
if container_client.exists():
return container_client
elif create_container and not container_client.exists():
container_client = blob_service_client.create_container(container_name)
# create general directory for new user container
response = await create_folder(container_client, "General")
if response:
return container_client
elif create_container and not container_client.exists():
container_client = blob_service_client.create_container(container_name)
# create general directory for new user container
response = await create_folder(container_client, "General")
if response:
return container_client
else:
return False
else:
raise ConnectionStringError("Invalid connection string")

else:
raise MountContainerError(f"could not create general directory: {container_name}")
except MountContainerError as error:
print(error)
return False
raise


async def get_blob(container_client: ContainerClient, blob_name: str):
@@ -237,7 +249,7 @@ async def get_directories(container_client):
return []

async def get_pipeline_info(
connection_string: str,
blob_service_client: BlobServiceClient,
pipeline_container_name: str,
pipeline_version: str
) -> json:
@@ -246,10 +258,10 @@
provided parameters.
Args:
connection_string (str): The connection string for the Azure Blob
Storage. pipeline_container_name (str): The name of the container where
the pipeline files are stored. pipeline_version (str): The version of
the pipeline to retrieve.
blob_service_client (BlobServiceClient): The BlobServiceClient object
pipeline_container_name (str): The name of the container where
the pipeline files are stored.
pipeline_version (str): The version of the pipeline to retrieve.
Returns:
json: The pipeline information in JSON format.
@@ -259,13 +271,6 @@
found.
"""
try:
blob_service_client = BlobServiceClient.from_connection_string(
connection_string
)

if blob_service_client is None:
raise PipelineNotFoundError("No Blob Service Client found with the connection string.")

container_client = blob_service_client.get_container_client(
pipeline_container_name
)
@@ -274,5 +279,5 @@
pipeline = json.loads(blob)
return pipeline

except (ValueError, GetBlobError, PipelineNotFoundError) as error:
except GetBlobError as error:
raise PipelineNotFoundError(f"This version {pipeline_version} was not found") from error
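A hedged usage sketch of the refactored get_pipeline_info, which now takes the BlobServiceClient instead of a raw connection string. The exception module path, container name, and version string below are assumptions for illustration; app.py uses its PIPELINE_BLOB_NAME and PIPELINE_VERSION constants:

```python
import asyncio
import os

from azure_storage import azure_storage_api
# Module path assumed; app.py imports these exception classes from the
# project's own exceptions module.
from custom_exceptions import ConnectionStringError, PipelineNotFoundError


async def load_pipelines():
    connection_string = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"]
    try:
        blob_service_client = await azure_storage_api.get_blob_client(connection_string)
        # "nachet-pipelines" and "0.1.0" stand in for the real
        # PIPELINE_BLOB_NAME and PIPELINE_VERSION values.
        return await azure_storage_api.get_pipeline_info(
            blob_service_client, "nachet-pipelines", "0.1.0"
        )
    except (ConnectionStringError, PipelineNotFoundError) as error:
        print(error)
        return None


pipelines = asyncio.run(load_pipelines())
```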
2 changes: 1 addition & 1 deletion model/inference.py
@@ -1,6 +1,6 @@

"""
This file contain the generic inference function that process the data at the end
This file contains the generic inference function that processes the data at the end
of a given pipeline.
"""

4 changes: 2 additions & 2 deletions model/seed_detector.py
@@ -1,5 +1,5 @@
"""
This file contain the function to request the inference and process the data from
This file contains the function that requests the inference and processes the data from
the seed detector model.
"""

@@ -59,7 +59,7 @@ def process_image_slicing(image_bytes: bytes, result_json: dict) -> list:

async def request_inference_from_seed_detector(model: namedtuple, previous_result: str):
"""
Requests inference from the seed detector model using the provided previous result.
Requests inference from the seed detector model using the previously provided result.
Args:
model (namedtuple): The seed detector model.
4 changes: 2 additions & 2 deletions model/six_seeds.py
@@ -1,5 +1,5 @@
"""
This file contain the function to request the inference and process the data from
This file contains the function that requests the inference and processes the data from
the nachet-6seeds model.
"""

@@ -48,4 +48,4 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul
return result_object

except HTTPError as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None
6 changes: 3 additions & 3 deletions model/swin.py
@@ -1,6 +1,6 @@
"""
This file contain the function to request the inference and process the data from
the nachet-6seeds model.
This file contains the function that requests the inference and processes the data from
the swin model.
"""

import json
@@ -59,4 +59,4 @@ async def request_inference_from_swin(model: namedtuple, previous_result: list[b

return process_swin_result(previous_result.get("result_json"), results)
except HTTPError as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None