Skip to content

Commit

Permalink
fixes #57: add unittest for image validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxence Guindon committed Mar 20, 2024
1 parent 34e294e commit 7f81fae
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 80 deletions.
58 changes: 30 additions & 28 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,26 @@

CACHE = {
'seeds': None,
'endpoints': None
'endpoints': None,
'validators': []
}

# Check: do environment variables exist?
if connection_string is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")
# if connection_string is None:
# raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if endpoint_url is None:
raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")
# if endpoint_url is None:
# raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

if endpoint_api_key is None:
raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")
# if endpoint_api_key is None:
# raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, connection_string)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")
# # Check: are environment variables correct?
# if not bool(re.match(connection_string_regex, connection_string)):
# raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if not bool(re.match(endpoint_url_regex, endpoint_url)):
raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_ACCESS_KEY")
# if not bool(re.match(endpoint_url_regex, endpoint_url)):
# raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

app = Quart(__name__)
app = cors(app, allow_origin="*", allow_methods=["GET", "POST", "OPTIONS"])
Expand Down Expand Up @@ -148,10 +149,10 @@ async def create_directory():
async def image_validation():
"""
Validates an image based on its extension, header, size, and resizability.
Returns:
A JSON response containing a boolean indicating whether the image is valid and a validator hash.
Raises:
ImageValidationError: If the image fails any of the validation checks.
"""
Expand All @@ -170,17 +171,17 @@ async def image_validation():
# extension check
if image_extension not in valide_extension:
raise ImageValidationError("Invalid file extension")

expected_header = f"data:image/{image_extension};base64"

# header check
if header.lower() != expected_header:
raise ImageValidationError("Invalid file header")

# size check
if image.size[0] > valide_dimension[0] and image.size[1] > valide_dimension[1]:
raise ImageValidationError("Invalid file size")

# resizable check
try:
size = (100,150)
Expand All @@ -189,14 +190,16 @@ async def image_validation():
raise ImageValidationError("Invalid file not resizable")

validator = await azure_storage_api.generate_hash(base64.b64decode(encoded_image))
return jsonify([True, validator]), 200

CACHE['validators'].append(validator)

return jsonify([validator]), 200

except ImageValidationError as error:
print(error)
return jsonify([False, error.message]), 400
return jsonify([error.args[0]]), 400
except Exception as error:
print(error)
return jsonify([False, "ImageValidationError: " + str(error)]), 400
return jsonify([f"ImageValidationError: {error.args[0]}"]), 400


@app.post("/inf")
Expand Down Expand Up @@ -279,11 +282,11 @@ async def get_seed_data(seed_name):
"""
Returns JSON containing requested seed data
"""
if seed_name in CACHE['seeds']:
if seed_name in CACHE['seeds']:
return jsonify(CACHE['seeds'][seed_name]), 200
else:
return jsonify(f"No information found for {seed_name}."), 400


@app.get("/reload-seed-data")
async def reload_seed_data():
Expand All @@ -302,7 +305,7 @@ async def get_model_endpoints_metadata():
"""
Returns JSON containing the deployed endpoints' metadata
"""
if CACHE['endpoints']:
if CACHE['endpoints']:
return jsonify(CACHE['endpoints']), 200
else:
return jsonify("Error retrieving model endpoints metadata.", 400)
Expand All @@ -312,7 +315,7 @@ async def get_model_endpoints_metadata():
async def health():
return "ok", 200


async def fetch_json(repo_URL, key, file_path):
"""
Fetches JSON document from a GitHub repository and caches it
Expand All @@ -328,7 +331,7 @@ async def fetch_json(repo_URL, key, file_path):
HTTP Status Code: {error.code}"}), 400
except Exception as e:
return jsonify({"error": str(e)}), 500


@app.before_serving
async def before_serving():
Expand All @@ -338,4 +341,3 @@ async def before_serving():

if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=8080)

4 changes: 1 addition & 3 deletions custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,4 @@ class ServerError(Exception):
pass

class ImageValidationError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
pass
157 changes: 108 additions & 49 deletions tests/test_image_validation.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,135 @@
import unittest
import base64
import json
import asyncio

from app import app
from io import BytesIO
from PIL import Image
from unittest import TestCase, main
import requests

"""
In order for the tests to run, the server must be running.
TO DO - Create a mock server to run the tests without the need for the server to be running.
TO DO - Implement test_image for every other checks (size, format, resizable)
TO DO - Implement test_image for every type of image (PNG, JPEG, GIF, BMP, TIFF, WEBP, SVG)
TO DO -
"""

class test_image_validation(TestCase):
# V1 with server running
def test_real_image_validation(self):
image = Image.new('RGB', (150, 150), 'blue')
from unittest.mock import patch, Mock

# Save the image to a byte array
img_byte_array = BytesIO()

image_header = "data:image/PNG;base64,"
class TestImageValidation(unittest.TestCase):
def setUp(self):
self.test_client = app.test_client()

image.save(img_byte_array, 'PNG')
self.img_byte_array = BytesIO()
image = Image.new('RGB', (150, 150), 'blue')
self.image_header = "data:image/PNG;base64,"
image.save(self.img_byte_array, 'PNG')

data = base64.b64encode(img_byte_array.getvalue()).decode('utf-8')
def test_real_image_validation(self):
data = base64.b64encode(self.img_byte_array.getvalue())
data = data.decode('utf-8')

response = requests.post(
url="http://0.0.0.0:8080/image-validation",
data= str.encode(json.dumps({'image': image_header + data})),
headers={
response = asyncio.run(
self.test_client.post(
'/image-validation',
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
}
)
data = json.loads(response.content)
},
data= str.encode(json.dumps({'image': self.image_header + data})),
))

if isinstance(data[1], str):
self.assertEqual(response.status_code, 200)
self.assertEqual(data[0], True)
else:
self.assertEqual(response.status_code, 200)
data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 200)
self.assertIsInstance(data[0], str)

# v2 with server not running
def test_invalid_header_image_validation(self):
image = Image.new('RGB', (150, 150), 'blue')
data = base64.b64encode(self.img_byte_array.getvalue()).decode('utf-8')

response = asyncio.run(
self.test_client.post(
'/image-validation',
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
},
data= str.encode(json.dumps({'image':"data:image/," + data})),
))

data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'Invalid file header')

@patch("PIL.Image.open")
def test_invalid_extension(self, mock_open):

mock_image = Mock()
mock_image.format = "md"

mock_open.return_value = mock_image

# Save the image to a byte array
img_byte_array = BytesIO()
data = base64.b64encode(self.img_byte_array.getvalue()).decode('utf-8')

response = asyncio.run(
self.test_client.post(
'/image-validation',
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
},
data= str.encode(json.dumps({'image': self.image_header + data})),
))

data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'Invalid file extension')

@patch("PIL.Image.open")
def test_invalid_size(self, mock_open):
mock_image = Mock()
mock_image.size = [2000, 2000]
mock_image.format = "PNG"

mock_open.return_value = mock_image

data = base64.b64encode(self.img_byte_array.getvalue()).decode('utf-8')

response = asyncio.run(
self.test_client.post(
'/image-validation',
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
},
data= str.encode(json.dumps({'image': self.image_header + data})),
))

data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], 'Invalid file size')

image_header = "data:image/,"
@patch("PIL.Image.open")
def test_rezisable_error(self, mock_open):
mock_image = Mock()
mock_image.size = [1080, 1080]
mock_image.format = "PNG"
mock_image.thumbnail.side_effect = IOError("error can't resize")

image.save(img_byte_array, 'PNG')
mock_open.return_value = mock_image

data = base64.b64encode(img_byte_array.getvalue()).decode('utf-8')
data = base64.b64encode(self.img_byte_array.getvalue()).decode('utf-8')

response = requests.post(
url="http://0.0.0.0:8080/image-validation",
data= str.encode(json.dumps({'image': image_header + data})),
headers={
response = asyncio.run(
self.test_client.post(
'/image-validation',
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
}
)
},
data= str.encode(json.dumps({'image': self.image_header + data})),
))

data = json.loads(response.content)
data = json.loads(asyncio.run(response.get_data()))

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], False)
self.assertEqual(data[1], 'Invalid file header')
self.assertEqual(data[0], 'Invalid file not resizable')

if __name__ == '__main__':
main()
unittest.main()

0 comments on commit 7f81fae

Please sign in to comment.