Skip to content

Commit

Permalink
fixes #51: refactor exception
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxence Guindon committed Feb 8, 2024
1 parent a7d524f commit ba85544
Showing 1 changed file with 108 additions and 95 deletions.
203 changes: 108 additions & 95 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
sd_api_key = os.getenv("NACHET_SEED_DETECTOR_ACCESS_KEY")
swin_api_key = os.getenv("NACHET_SWIN_ACCESS_KEY")

endpoints = [[endpoint_url, endpoint_api_key],[sd_endpoint, sd_api_key],[swin_endpoint, swin_api_key]]
pipelines_endpoints = {
"legacy": (endpoint_url, endpoint_api_key),
"swin": ((sd_endpoint, sd_api_key), (swin_endpoint, swin_api_key))
}


NACHET_DATA = os.getenv("NACHET_DATA")
NACHET_MODEL = os.getenv("NACHET_MODEL")
Expand Down Expand Up @@ -152,123 +156,132 @@ async def create_directory():
@app.post("/inf")
async def inference_request():
"""
performs inference on an image, and returns the results.
The image and inference results uploaded to a folder in the user's container.
Performs inference on an image, and returns the results.
The image and inference results are uploaded to a folder in the user's container.
"""
try:
data = await request.get_json()
# connection_string: str = os.environ["NACHET_AZURE_STORAGE_CONNECTION_STRING"]
pipeline_name = data.get("model_name", "defaul_model")
folder_name = data["folder_name"]
container_name = data["container_name"]
imageDims = data["imageDims"]
image_base64 = data["image"]
if folder_name and container_name and imageDims and image_base64:
_, encoded_data = image_base64.split(",", 1)
image_bytes = base64.b64decode(encoded_data)
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=True
)
hash_value = await azure_storage_api.generate_hash(image_bytes)
blob_name = await azure_storage_api.upload_image(
container_client, folder_name, image_bytes, hash_value
)
blob = await azure_storage_api.get_blob(container_client, blob_name)
image_bytes = base64.b64encode(blob).decode("utf8")

data = {
"input_data": {
"columns": ["image"],
"index": [0],
"data": [image_bytes],
}
if not (folder_name and container_name and imageDims and image_base64):
return jsonify(["missing request arguments"]), 400

_, encoded_data = image_base64.split(",", 1)
image_bytes = base64.b64decode(encoded_data)
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=True
)
hash_value = await azure_storage_api.generate_hash(image_bytes)
blob_name = await azure_storage_api.upload_image(
container_client, folder_name, image_bytes, hash_value
)
blob = await azure_storage_api.get_blob(container_client, blob_name)
image_bytes = base64.b64encode(blob).decode("utf8")

data = {
"input_data": {
"columns": ["image"],
"index": [0],
"data": [image_bytes],
}
#============================================================#
# encode the data as json to be sent to the model endpoint
body = str.encode(json.dumps(data))

try:
endpoint_url, endpoint_api_key = endpoints[1]
headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + endpoint_api_key),
'azureml-model-deployment': 'seed-detector-1'
}
# send the request to the model endpoint
req = urllib.request.Request(endpoint_url, body, headers)
# get the response from the model endpoint
response = urllib.request.urlopen(req)
result = response.read()
result_json = json.loads(result.decode("utf-8"))
}

if not pipelines_endpoints.get(pipeline_name):
return jsonify([f"Model {pipeline_name} not found"]), 400

#============================================================#
# encode the data as json to be sent to the model endpoint
body = str.encode(json.dumps(data))

try:
endpoint_url, endpoint_api_key = pipelines_endpoints.get("swin")[0]
headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + endpoint_api_key),
'azureml-model-deployment': 'seed-detector-1'
}
# send the request to the model endpoint
req = urllib.request.Request(endpoint_url, body, headers)
# get the response from the model endpoint
response = urllib.request.urlopen(req)
result = response.read()
result_json = json.loads(result.decode("utf-8"))

# Cropping image and feed them to the next model
# Cropping image and feed them to the next model

image_io_byte = io.BytesIO(base64.b64decode(image_bytes))
image_io_byte.seek(0)
image = Image.open(image_io_byte)
image_io_byte = io.BytesIO(base64.b64decode(image_bytes))
image_io_byte.seek(0)
image = Image.open(image_io_byte)

format = image.format
format = image.format

boxes = result_json[0]['boxes']
boxes = result_json[0]['boxes']

cropped_images = [bytes(0) for _ in boxes]
cropped_images = [bytes(0) for _ in boxes]

for i, box in enumerate(boxes):
topX = int(box['box']['topX'] * image.width)
topY = int(box['box']['topY'] * image.height)
bottomX = int(box['box']['bottomX'] * image.width)
bottomY = int(box['box']['bottomY'] * image.height)
for i, box in enumerate(boxes):
topX = int(box['box']['topX'] * image.width)
topY = int(box['box']['topY'] * image.height)
bottomX = int(box['box']['bottomX'] * image.width)
bottomY = int(box['box']['bottomY'] * image.height)

buffered = io.BytesIO()
img = image.crop((topX, topY, bottomX, bottomY))

img.save(buffered, format)
cropped_images[i] = base64.b64encode(buffered.getvalue())

# Second model call

endpoint, api_key = endpoints[2]
buffered = io.BytesIO()
img = image.crop((topX, topY, bottomX, bottomY))

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + api_key),
}

for idx, img_bytes in enumerate(cropped_images):
req = urllib.request.Request(endpoint, img_bytes, headers)

response = urllib.request.urlopen(req)
result = response.read()
classification = json.loads(result.decode("utf-8"))
result_json[0]['boxes'][idx]['label'] = classification[0].get('label')
result_json[0]['boxes'][idx]['score'] = classification[0].get('score')
img.save(buffered, format)
cropped_images[i] = base64.b64encode(buffered.getvalue())

# Second model call

#=======================================================================#
endpoint, api_key = pipelines_endpoints.get("swin")[1]

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + api_key),
}

# process the inference results
processed_result_json = await inference.process_inference_results(
result_json, imageDims
)
# upload the inference results to the user's container as async task
result_json_string = json.dumps(processed_result_json)
app.add_background_task(
azure_storage_api.upload_inference_result,
container_client,
folder_name,
result_json_string,
hash_value,
)
# return the inference results to the client
return jsonify(processed_result_json), 200
for idx, img_bytes in enumerate(cropped_images):
req = urllib.request.Request(endpoint, img_bytes, headers)

response = urllib.request.urlopen(req)
result = response.read()
classification = json.loads(result.decode("utf-8"))
result_json[0]['boxes'][idx]['label'] = classification[0].get('label')
result_json[0]['boxes'][idx]['score'] = classification[0].get('score')

#=======================================================================#

# process the inference results
processed_result_json = await inference.process_inference_results(
result_json, imageDims
)
except urllib.error.HTTPError as error:
print(error)
return jsonify(["endpoint cannot be reached" + str(error.code)]), 400

# upload the inference results to the user's container as async task
result_json_string = json.dumps(processed_result_json)
app.add_background_task(
azure_storage_api.upload_inference_result,
container_client,
folder_name,
result_json_string,
hash_value,
)
# return the inference results to the client
return jsonify(processed_result_json), 200

except urllib.error.HTTPError as error:
print(error)
return jsonify(["endpoint cannot be reached" + str(error.code)]), 400
else:
return jsonify(["missing request arguments"]), 400

except InferenceRequestError as error:
print(error)
return jsonify(["InferenceRequestError: " + str(error)]), 400

except Exception as error:
print(error)
return jsonify(["Unexpected error occured"]), 500


@app.get("/seed-data/<seed_name>")
Expand Down

0 comments on commit ba85544

Please sign in to comment.