Skip to content

Commit

Permalink
fixes #51: update inference_request to call swin in loop
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxence Guindon committed Feb 13, 2024
1 parent b361f31 commit ffc6cb8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 36 deletions.
46 changes: 19 additions & 27 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
NACHET_MODEL = os.getenv("NACHET_MODEL")

# The following tuples will be used to store the endpoints and their respective utilitary functions
tuple_endpoints = (((endpoint_url, endpoint_api_key, ""),),((sd_endpoint, sd_api_key, utils.image_slicing),(swin_endpoint, swin_api_key, "")))
tuple_endpoints = (((endpoint_url, endpoint_api_key, ""),),((sd_endpoint, sd_api_key, utils.image_slicing),(swin_endpoint, swin_api_key, utils.swin_result_parser)))

CACHE = {
"seeds": None,
Expand Down Expand Up @@ -188,36 +188,31 @@ async def inference_request():
image_bytes = base64.b64encode(blob).decode("utf8")

try:
cache_json_result = None
for model in pipelines_endpoints.get(pipeline_name):

endpoint_url, endpoint_api_key, utilitary_function = model
model_name = endpoint_url.split("/")[2].split(".")[0]

if model_name == "swin-endpoint":
break
if isinstance(image_bytes, list):
result_json = []
for img in image_bytes:
req = await utils.request_factory(img, endpoint_url, endpoint_api_key)
response = urllib.request.urlopen(req)
result = response.read()
result_json.append(json.loads(result.decode("utf-8")))

req = await utils.request_factory(image_bytes, endpoint_url, endpoint_api_key)
response = urllib.request.urlopen(req)
result = response.read()
result_json = json.loads(result.decode("utf-8"))

if utilitary_function:
image_bytes = await utilitary_function(image_bytes, result_json)
# cache_json_result = result_json

if model_name == "swin-endpoint":
#second_endpoint, second_api_key, _ = pipelines_endpoints.get(pipeline_name)[1]
headers = await utils.swin_header(endpoint_api_key)

# build a request to the endpoint sending cropped images
all_classification = []
for img_bytes in image_bytes:
req = urllib.request.Request(endpoint_url, img_bytes, headers)
elif isinstance(image_bytes, str):
req = await utils.request_factory(image_bytes, endpoint_url, endpoint_api_key)
response = urllib.request.urlopen(req)
result = response.read()
all_classification.append(json.loads(result.decode("utf-8")))
result_json = json.loads(result.decode("utf-8"))

result_json = await utils.swin_result_parser(result_json, all_classification)
if utilitary_function:
if isinstance(image_bytes, str):
image_bytes = await utilitary_function(image_bytes, result_json)
elif isinstance(image_bytes, list):
result_json = await utilitary_function(cache_json_result, result_json)
cache_json_result = result_json

print("End of inference request") # Transform into logging
print("Process results") # Transform into logging
Expand Down Expand Up @@ -323,7 +318,6 @@ async def data_factory(**kwargs):
"input_data": kwargs,
}


@app.before_serving
async def before_serving():
await fetch_json(NACHET_DATA, 'seeds', "seeds/all.json")
Expand All @@ -342,8 +336,6 @@ async def before_serving():
"identifiable": []
})



if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=8080)


21 changes: 12 additions & 9 deletions model_utilitary_functions/model_UTILS.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def swin_header(api_key: str) -> dict:

# Eventually the goals would be to have a request factory that would return
# a request for the specified models such as the following:
async def request_factory(img_bytes: bytes, endpoint_url: str, api_key: str) -> Request:
async def request_factory(img_bytes: str | bytes, endpoint_url: str, api_key: str) -> Request:
"""
Return a request for calling AzureML AI model
"""
Expand All @@ -63,14 +63,17 @@ async def request_factory(img_bytes: bytes, endpoint_url: str, api_key: str) ->
"Authorization": ("Bearer " + api_key),
} if model_name != "seed-detector" else await seed_detector_header(api_key)

data = {
"input_data": {
"columns": ["image"],
"index": [0],
"data": [img_bytes],
}
}

body = str.encode(json.dumps(data))
if isinstance(img_bytes, str):
data = {
"input_data": {
"columns": ["image"],
"index": [0],
"data": [img_bytes],
}
}
body = str.encode(json.dumps(data))
elif isinstance(img_bytes, bytes):
body = img_bytes

return Request(endpoint_url, body, headers)

0 comments on commit ffc6cb8

Please sign in to comment.