Skip to content

Commit

Permalink
fixes #51: Update model_utilitary_functions import
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxence Guindon committed Feb 12, 2024
1 parent 103e8c8 commit 85a62f4
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 54 deletions.
92 changes: 43 additions & 49 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import azure_storage.azure_storage_api as azure_storage_api
import model_inference.inference as inference
import model_utilitary_functions as utils
import model_utilitary_functions.model_UTILS as utils
from custom_exceptions import (
DeleteDirectoryRequestError,
ListDirectoriesRequestError,
Expand All @@ -34,13 +34,13 @@
NACHET_DATA = os.getenv("NACHET_DATA")
NACHET_MODEL = os.getenv("NACHET_MODEL")

# The following tuples will be used to store the endpoints and their respective utilitary functions
tuple_endpoints = (((endpoint_url, endpoint_api_key, ""),),((sd_endpoint, sd_api_key, utils.image_slicing),(swin_endpoint, swin_api_key, "")))

CACHE = {
"seeds": None,
"endpoints": None,
"pipelines": {
"Seed Classification": ((endpoint_url, endpoint_api_key),),
"Swin": ((sd_endpoint, sd_api_key),(swin_endpoint, swin_api_key)) # swin
}
"pipelines": {}
}

# Check: do environment variables exist?
Expand Down Expand Up @@ -172,6 +172,9 @@ async def inference_request():
if not (folder_name and container_name and imageDims and image_base64):
return jsonify(["missing request arguments"]), 400

if not pipelines_endpoints.get(pipeline_name):
return jsonify([f"Model {pipeline_name} not found"]), 400

_, encoded_data = image_base64.split(",", 1)
image_bytes = base64.b64decode(encoded_data)
container_client = await azure_storage_api.mount_container(
Expand All @@ -184,62 +187,40 @@ async def inference_request():
blob = await azure_storage_api.get_blob(container_client, blob_name)
image_bytes = base64.b64encode(blob).decode("utf8")

data = {
"input_data": {
"columns": ["image"],
"index": [0],
"data": [image_bytes],
}
}
try:
for model in pipelines_endpoints.get(pipeline_name):

if not pipelines_endpoints.get(pipeline_name):
return jsonify([f"Model {pipeline_name} not found"]), 400
endpoint_url, endpoint_api_key, utilitary_function = model
model_name = endpoint_url.split("/")[2].split(".")[0]

# encode the data as json to be sent to the model endpoint
body = str.encode(json.dumps(data))

#============================================================#
try:
endpoint_url, endpoint_api_key = pipelines_endpoints.get(pipeline_name)[0]
# Select good header for pipelines
headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + endpoint_api_key),
} if pipeline_name != "Swin" else await utils.seed_detector_header(endpoint_api_key)

# send the request to the model endpoint
req = urllib.request.Request(endpoint_url, body, headers)
# get the response from the model endpoint
response = urllib.request.urlopen(req)
result = response.read()
result_json = json.loads(result.decode("utf-8"))

print("Get seed detector result") # transform to logging
if model_name == "swin-endpoint":
break

if pipeline_name == "Swin":
sliced_images = utils.image_slicing(result_json[0]['boxes'])

# Second model call
second_endpoint, second_api_key = pipelines_endpoints.get(pipeline_name)[1]
req = await utils.request_factory(image_bytes, endpoint_url, endpoint_api_key)
response = urllib.request.urlopen(req)
result = response.read()
result_json = json.loads(result.decode("utf-8"))

if utilitary_function:
image_bytes = await utilitary_function(image_bytes, result_json[0]['boxes'])

if model_name == "swin-endpoint":
#second_endpoint, second_api_key, _ = pipelines_endpoints.get(pipeline_name)[1]
headers = await utils.swin_header(endpoint_api_key)

headers = utils.swin_request_constructor(second_endpoint, second_api_key, sliced_images)

# build a request to the endpoint sending cropped images
for idx, img_bytes in enumerate(sliced_images):
req = urllib.request.Request(second_endpoint, img_bytes, headers)
for idx, img_bytes in enumerate(image_bytes):
req = urllib.request.Request(endpoint_url, img_bytes, headers)
response = urllib.request.urlopen(req)
result = response.read()
classification = json.loads(result.decode("utf-8"))

with open("result.txt", "w+") as file:
file.write(str(classification) + "\n")

result_json[0]['boxes'][idx]['label'] = classification[0].get('label')
result_json[0]['boxes'][idx]['score'] = classification[0].get('score')

#=======================================================================#

# process the inference results
print("End of inference request") # Transform into logging
print("Process results") # Transform into logging

processed_result_json = await inference.process_inference_results(
result_json, imageDims
)
Expand Down Expand Up @@ -317,10 +298,23 @@ async def fetch_json(repo_URL, key, file_path):
result = response.read()
result_json = json.loads(result.decode("utf-8"))
CACHE[key] = result_json
# logic to build pipeline
if key == "endpoints":
endpoint_name = [v for k, v in result_json[0].items() if k == "endpoint_name"]
keys = [v for k, v in result_json[0].items() if k == "model_name"]

for i, t in enumerate(tuple_endpoints):
if i > len(endpoint_name):
break
if re.search(endpoint_name[i], t[i][0]):
CACHE["pipelines"][keys[i]] = t


except urllib.error.HTTPError as error:
return jsonify({"error": f"Failed to retrieve the JSON. \
HTTP Status Code: {error.code}"}), 400
except Exception as e:
print(str(e))
return jsonify({"error": str(e)}), 500

async def data_factory(**kwargs):
Expand Down
13 changes: 8 additions & 5 deletions model_utilitary_functions/model_UTILS.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async def image_slicing(image_bytes: bytes, boxes: list[dict]) -> list:

format = image.format

cropped_images = [_ for _ in boxes]
cropped_images = [bytes(0) for _ in boxes]

for i, box in enumerate(boxes):
topX = int(box['box']['topX'] * image.width)
Expand All @@ -24,12 +24,13 @@ async def image_slicing(image_bytes: bytes, boxes: list[dict]) -> list:
buffered = io.BytesIO()
img.save(buffered, format)

encoded_img = base64.b64encode(buffered.getvalue()).decode("utf8").split(",", 1)

cropped_images[i] = base64.b64decode(encoded_img)
cropped_images[i] = base64.b64encode(buffered.getvalue()) #.decode("utf8")

return cropped_images

async def swin_result_parser(result: dict) -> list:
pass

async def seed_detector_header(api_key: str) -> dict:
return {
"Content-Type": "application/json",
Expand All @@ -50,10 +51,12 @@ async def request_factory(img_bytes: bytes, endpoint_url: str, api_key: str) ->
Return a request for calling AzureML AI model
"""

model_name = endpoint_url.split("/")[2].split(".")[0]

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + api_key),
}
} if model_name != "seed-detector" else await seed_detector_header(api_key)

data = {
"input_data": {
Expand Down

0 comments on commit 85a62f4

Please sign in to comment.