Skip to content

Commit

Permalink
merge refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvanie85 committed May 22, 2024
2 parents 3a2ba7f + 4d641f9 commit c30f7e1
Show file tree
Hide file tree
Showing 13 changed files with 80 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "pip3 install --user -r requirements.txt",
"postCreateCommand": "pip3 install --user -r requirements.txt && pip install --upgrade pydantic",

// Configure tool-specific properties.
"customizations": {
Expand Down
35 changes: 18 additions & 17 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
import tempfile

from PIL import Image, UnidentifiedImageError
from PIL import Image
from datetime import date
from dotenv import load_dotenv
from quart import Quart, request, jsonify
Expand All @@ -19,6 +19,7 @@

load_dotenv()

from azure.core.exceptions import ResourceNotFoundError, ServiceResponseError

Check failure on line 22 in app.py

View workflow job for this annotation

GitHub Actions / lint-test / lint-test

Ruff (E402)

app.py:22:1: E402 Module level import not at top of file
import model.inference as inference

Check failure on line 23 in app.py

View workflow job for this annotation

GitHub Actions / lint-test / lint-test

Ruff (E402)

app.py:23:1: E402 Module level import not at top of file
from model import request_function

Check failure on line 24 in app.py

View workflow job for this annotation

GitHub Actions / lint-test / lint-test

Ruff (E402)

app.py:24:1: E402 Module level import not at top of file
from datastore import azure_storage_api

Check failure on line 25 in app.py

View workflow job for this annotation

GitHub Actions / lint-test / lint-test

Ruff (E402)

app.py:25:1: E402 Module level import not at top of file
Expand Down Expand Up @@ -178,7 +179,7 @@ async def before_serving():
"""
) #TODO Transform into logging

except ServerError as e:
except (ServerError, inference.ModelAPIErrors) as e:
print(e)
raise

Expand Down Expand Up @@ -207,15 +208,15 @@ async def delete_directory():
container_client.delete_blob(blob.name)
return jsonify([True]), 200
else:
return jsonify(["directory does not exist"]), 400
raise DeleteDirectoryRequestError("directory does not exist")
else:
return jsonify(["failed to mount container"]), 400
raise DeleteDirectoryRequestError("failed to mount container")
else:
return jsonify(["missing container or directory name"]), 400
raise DeleteDirectoryRequestError("missing container or directory name")

except DeleteDirectoryRequestError as error:
except (KeyError, TypeError, azure_storage_api.MountContainerError, ResourceNotFoundError, DeleteDirectoryRequestError, ServiceResponseError) as error:
print(error)
return jsonify(["DeleteDirectoryRequestError: " + str(error)]), 400
return jsonify([f"DeleteDirectoryRequestError: {str(error)}"]), 400


@app.post("/dir")
Expand All @@ -233,11 +234,11 @@ async def list_directories():
response = await azure_storage_api.get_directories(container_client)
return jsonify(response), 200
else:
return jsonify(["Missing container name"]), 400
raise ListDirectoriesRequestError("Missing container name")

except ListDirectoriesRequestError as error:
except (KeyError, TypeError, ListDirectoriesRequestError, azure_storage_api.MountContainerError) as error:
print(error)
return jsonify(["ListDirectoriesRequestError: " + str(error)]), 400
return jsonify([f"ListDirectoriesRequestError: {str(error)}"]), 400


@app.post("/create-dir")
Expand All @@ -259,13 +260,13 @@ async def create_directory():
if response:
return jsonify([True]), 200
else:
return jsonify(["directory already exists"]), 400
raise CreateDirectoryRequestError("directory already exists")
else:
return jsonify(["missing container or directory name"]), 400
raise CreateDirectoryRequestError("missing container or directory name")

except CreateDirectoryRequestError as error:
except (KeyError, TypeError, CreateDirectoryRequestError, azure_storage_api.MountContainerError) as error:
print(error)
return jsonify(["CreateDirectoryRequestError: " + str(error)]), 400
return jsonify([f"CreateDirectoryRequestError: {str(error)}"]), 400


@app.post("/image-validation")
Expand Down Expand Up @@ -318,9 +319,9 @@ async def image_validation():

return jsonify([validator]), 200

except (UnidentifiedImageError, ImageValidationError) as error:
except (KeyError, TypeError, ValueError, ImageValidationError) as error:
print(error)
return jsonify([error.args[0]]), 400
return jsonify([f"ImageValidationError: {str(error)}"]), 400


@app.post("/inf")
Expand Down Expand Up @@ -403,7 +404,7 @@ async def inference_request():
print(f"Took: {'{:10.4f}'.format(time.perf_counter() - seconds)} seconds") # TODO: Transform into logging
return jsonify(processed_result_json), 200

except (KeyError, InferenceRequestError) as error:
except (inference.ModelAPIErrors, KeyError, TypeError, ValueError, InferenceRequestError, azure_storage_api.MountContainerError) as error:
print(error)
return jsonify(["InferenceRequestError: " + error.args[0]]), 400

Expand Down
6 changes: 0 additions & 6 deletions custom_exceptions.py

This file was deleted.

4 changes: 3 additions & 1 deletion model/color_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Contains the color palettes used to color the boxes.
"""
import numpy as np
from typing import Union


# Find color by name or hex code: https://www.color-name.com

Expand Down Expand Up @@ -74,7 +76,7 @@ def mixing_palettes(dict1: dict, dict2: dict) -> dict:

return {key: dict1[key] + dict2[key] for key in dict1.keys()}

def shades_colors(base_color: str | tuple, num_shades = 5, lighten_factor = 0.15, darken_factor = 0.1) -> tuple:
def shades_colors(base_color: Union[str, tuple], num_shades = 5, lighten_factor = 0.15, darken_factor = 0.1) -> tuple:
"""
Generate shades of a color based on the base color.
Expand Down
10 changes: 6 additions & 4 deletions model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

import numpy as np

from custom_exceptions import ProcessInferenceResultError
from model.color_palette import primary_colors, light_colors, mixing_palettes, shades_colors
from model.model_exceptions import ModelAPIErrors

class ProcessInferenceResultsModelAPIError(ModelAPIErrors) :
    """Raised when post-processing of model inference results fails.

    Wraps lower-level errors (KeyError, TypeError, IndexError, ValueError,
    ZeroDivisionError) caught while processing inference output, so callers
    can handle all model-API failures via the ModelAPIErrors base class.
    """
    pass

def generator(list_length):
for i in range(list_length):
Expand All @@ -21,7 +23,7 @@ def generator(list_length):

async def process_inference_results(
data: dict,
imageDims: list[int, int],
imageDims: 'list[int, int]',
area_ratio: float = 0.5,
color_format: str = "hex"
) -> dict:
Expand Down Expand Up @@ -138,6 +140,6 @@ async def process_inference_results(

return data

except (KeyError, TypeError, IndexError) as error:
except (KeyError, TypeError, IndexError, ValueError, ZeroDivisionError) as error:
print(error)
raise ProcessInferenceResultError("Error processing inference results") from error
raise ProcessInferenceResultsModelAPIError(f"Error while processing inference results :\n {str(error)}") from error
2 changes: 2 additions & 0 deletions model/model_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ModelAPIErrors(Exception):
    """Base class for every exception raised by the model API layer.

    Each model module (seed detector, swin, six-seeds, test, inference
    post-processing) defines its own subclass, letting callers catch all
    model-related failures with a single ``except ModelAPIErrors`` clause.
    """
15 changes: 10 additions & 5 deletions model/seed_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
import io
import base64
import json
from urllib.error import URLError

from PIL import Image
from collections import namedtuple
from urllib.request import Request, urlopen, HTTPError
from custom_exceptions import ProcessInferenceResultError
from urllib.request import Request, urlopen
from model.model_exceptions import ModelAPIErrors

class SeedDetectorModelAPIError(ModelAPIErrors) :
    """Raised when a request to the seed-detector model fails.

    Wraps lower-level errors (KeyError, TypeError, IndexError, ValueError,
    URLError, json.JSONDecodeError) caught while calling the model endpoint
    or slicing the returned images.
    """
    pass

def process_image_slicing(image_bytes: bytes, result_json: dict) -> list:
"""
Expand Down Expand Up @@ -69,7 +73,7 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
dict: A dictionary containing the result JSON and the images generated from the inference.
Raises:
InferenceRequestError: If an error occurs while processing the request.
SeedDetectorModelAPIError: If an error occurs while processing the request.
"""
try:

Expand Down Expand Up @@ -99,5 +103,6 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
"result_json": result_object,
"images": process_image_slicing(previous_result, result_object)
}
except HTTPError as e:
raise ProcessInferenceResultError(f"An error occurred while processing the request:\n {str(e)}") from None
except (KeyError, TypeError, IndexError, ValueError, URLError, json.JSONDecodeError) as error:
print(error)
raise SeedDetectorModelAPIError(f"Error while processing inference results :\n {str(error)}") from error
15 changes: 10 additions & 5 deletions model/six_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@

import json
from collections import namedtuple
from urllib.request import Request, urlopen, HTTPError
from custom_exceptions import ProcessInferenceResultError
from urllib.error import URLError
from urllib.request import Request, urlopen
from model.model_exceptions import ModelAPIErrors

class SixSeedModelAPIError(ModelAPIErrors) :
    """Raised when a request to the Nachet six-seeds model fails.

    Wraps lower-level errors (KeyError, TypeError, IndexError, URLError,
    json.JSONDecodeError) caught while calling the model endpoint or
    decoding its response.
    """
    pass

async def request_inference_from_nachet_6seeds(model: namedtuple, previous_result: str):
"""
Expand All @@ -20,7 +24,7 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul
dict: The result of the inference as a JSON object.
Raises:
InferenceRequestError: If an error occurs while processing the request.
SixSeedModelAPIError: If an error occurs while processing the request.
"""
try:
headers = {
Expand All @@ -47,5 +51,6 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul

return result_object

except HTTPError as e:
raise ProcessInferenceResultError(f"An error occurred while processing the request:\n {str(e)}") from None
except (KeyError, TypeError, IndexError, URLError, json.JSONDecodeError) as error:
print(error)
raise SixSeedModelAPIError(f"Error while processing inference results :\n {str(error)}") from error
16 changes: 10 additions & 6 deletions model/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import json

from collections import namedtuple
from urllib.request import Request, urlopen, HTTPError
from custom_exceptions import ProcessInferenceResultError
from urllib.error import URLError
from urllib.request import Request, urlopen
from model.model_exceptions import ModelAPIErrors

class SwinModelAPIError(ModelAPIErrors) :
    """Raised when a request to the SWIN model fails.

    Wraps lower-level errors (TypeError, IndexError, AttributeError,
    URLError, json.JSONDecodeError) caught while calling the model
    endpoint or merging its results into the image boxes.
    """
    pass

def process_swin_result(img_box:dict, results: dict) -> list:
"""
Expand All @@ -27,7 +30,7 @@ def process_swin_result(img_box:dict, results: dict) -> list:
return img_box


async def request_inference_from_swin(model: namedtuple, previous_result: list[bytes]):
async def request_inference_from_swin(model: namedtuple, previous_result: 'list[bytes]'):
"""
Perform inference using the SWIN model on a list of images.
Expand All @@ -39,7 +42,7 @@ async def request_inference_from_swin(model: namedtuple, previous_result: list[b
The result of the inference.
Raises:
InferenceRequestError: If an error occurs while processing the request.
SwinModelAPIError: If an error occurs while processing the request.
"""
try:
results = []
Expand All @@ -58,5 +61,6 @@ async def request_inference_from_swin(model: namedtuple, previous_result: list[b
print(json.dumps(results, indent=4)) #TODO Transform into logging

return process_swin_result(previous_result.get("result_json"), results)
except HTTPError as e:
raise ProcessInferenceResultError(f"An error occurred while processing the request:\n {str(e)}") from None
except (TypeError, IndexError, AttributeError, URLError, json.JSONDecodeError) as error:
print(error)
raise SwinModelAPIError(f"An error occurred while processing the request:\n {str(error)}") from error
9 changes: 6 additions & 3 deletions model/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
request_inference_from_nachet_six_seed: Requests inference from the Nachet Six Seed model.
"""
from collections import namedtuple
from custom_exceptions import ProcessInferenceResultError
from model.model_exceptions import ModelAPIErrors

class TestModelAPIError(ModelAPIErrors) :
    """Raised when the test (mock) model pipeline fails.

    Wraps ValueError caught while building the mock inference result.
    """
    pass

async def request_inference_from_test(model: namedtuple, previous_result: str):
"""
Expand All @@ -22,7 +24,7 @@ async def request_inference_from_test(model: namedtuple, previous_result: str):
dict: The result of the inference as a JSON object.
Raises:
InferenceRequestError: If an error occurs while processing the request.
TestModelAPIError: If an error occurs while processing the request.
"""
try:
if previous_result == '':
Expand Down Expand Up @@ -53,4 +55,5 @@ async def request_inference_from_test(model: namedtuple, previous_result: str):
]

except ValueError as error:
raise ProcessInferenceResultError("An error occurred while processing the request") from error
print(error)
raise TestModelAPIError(f"An error occurred while processing the requests :\n {str(error)}") from error
6 changes: 6 additions & 0 deletions renovate.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended"
]
}
8 changes: 4 additions & 4 deletions tests/test_image_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_invalid_header(self):
data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'invalid file header: data:image/')
self.assertEqual(data[0], 'ImageValidationError: invalid file header: data:image/')

@patch("magic.Magic.from_buffer")
def test_invalid_extension(self, mock_magic_from_buffer):
Expand All @@ -71,7 +71,7 @@ def test_invalid_extension(self, mock_magic_from_buffer):
data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'invalid file extension: plain')
self.assertEqual(data[0], 'ImageValidationError: invalid file extension: plain')

@patch("PIL.Image.open")
def test_invalid_size(self, mock_open):
Expand All @@ -96,7 +96,7 @@ def test_invalid_size(self, mock_open):
data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'invalid file size: 2000x2000')
self.assertEqual(data[0], 'ImageValidationError: invalid file size: 2000x2000')

@patch("PIL.Image.open")
def test_resizable_error(self, mock_open):
Expand All @@ -122,7 +122,7 @@ def test_resizable_error(self, mock_open):
data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'invalid file not resizable')
self.assertEqual(data[0], 'ImageValidationError: invalid file not resizable')

if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions tests/test_overlapping_in_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
light_colors,
mixing_palettes,
shades_colors,
ProcessInferenceResultError
ProcessInferenceResultsModelAPIError,
)

class TestInferenceProcessFunction(unittest.TestCase):
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_process_inference_error(self):
"totalBoxes": 2
}

with self.assertRaises(ProcessInferenceResultError):
with self.assertRaises(ProcessInferenceResultsModelAPIError):
asyncio.run(
process_inference_results(data=[data], imageDims=[100, 100]))

Expand All @@ -116,14 +116,14 @@ def test_process_inference_error(self):
"totalBoxes": 2
}

with self.assertRaises(ProcessInferenceResultError):
with self.assertRaises(ProcessInferenceResultsModelAPIError):
asyncio.run(process_inference_results(data=[data], imageDims=100))

data ={
"boxes": None,
"totalBoxes": 2
}

with self.assertRaises(ProcessInferenceResultError):
with self.assertRaises(ProcessInferenceResultsModelAPIError):
asyncio.run(
process_inference_results(data=[data], imageDims=[100, 100]))

0 comments on commit c30f7e1

Please sign in to comment.