diff --git a/model_utilitary_functions/model_UTILS.py b/model_utilitary_functions/model_UTILS.py index e49f8cff..420f8b4c 100644 --- a/model_utilitary_functions/model_UTILS.py +++ b/model_utilitary_functions/model_UTILS.py @@ -5,6 +5,25 @@ from urllib.request import Request async def image_slicing(image_bytes: bytes, result_json: dict) -> list: + """ + This function takes the image bytes and the result_json from the model and + returns a list of cropped images. + The result_json is expected to be in the following format: + { + "boxes": [ + { + "box": { + "topX": 0.0, + "topY": 0.0, + "bottomX": 0.0, + "bottomY": 0.0 + }, + "label": "string", + "score": 0.0 + } + ], + } + """ boxes = result_json[0]['boxes'] image_io_byte = io.BytesIO(base64.b64decode(image_bytes)) image_io_byte.seek(0) @@ -30,6 +49,14 @@ async def image_slicing(image_bytes: bytes, result_json: dict) -> list: return cropped_images async def swin_result_parser(img_box:dict, results: dict) -> list: + """ + Args: + img_box (dict): The image box containing the bounding boxes and labels. + results (dict): The results from the model containing the detected seeds. + + Returns: + list: The updated image box with modified labels and scores. + """ for i, result in enumerate(results): img_box[0]['boxes'][i]['label'] = result[0].get('label') img_box[0]['boxes'][i]['score'] = result[0].get('score') @@ -43,17 +70,17 @@ async def seed_detector_header(api_key: str) -> dict: "azureml-model-deployment": "seed-detector-1", } -async def swin_header(api_key: str) -> dict: - return { - "Content-Type": "application/json", - "Authorization": ("Bearer " + api_key), - } - # Eventually the goals would be to have a request factory that would return # a request for the specified models such as the following: async def request_factory(img_bytes: str | bytes, endpoint_url: str, api_key: str) -> Request: """ - Return a request for calling AzureML AI model + Args: + img_bytes (str | bytes): The image data as either a string or bytes. + endpoint_url (str): The URL of the AI model endpoint. + api_key (str): The API key for accessing the AI model. + + Returns: + Request: The request object for calling the AI model. """ model_name = endpoint_url.split("/")[2].split(".")[0]