Skip to content

Commit

Permalink
fixes #57: Implement image-validation endpoint and test
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxence Guindon committed Feb 20, 2024
1 parent 045831b commit 34e294e
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 0 deletions.
59 changes: 59 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import base64
import re
import io
from PIL import Image
from dotenv import load_dotenv
from quart import Quart, request, jsonify
from quart_cors import cors
Expand All @@ -14,6 +16,7 @@
InferenceRequestError,
CreateDirectoryRequestError,
ServerError,
ImageValidationError,
)

load_dotenv()
Expand Down Expand Up @@ -141,6 +144,61 @@ async def create_directory():
return jsonify(["CreateDirectoryRequestError: " + str(error)]), 400


@app.post("/image-validation")
async def image_validation():
"""
Validates an image based on its extension, header, size, and resizability.
Returns:
A JSON response containing a boolean indicating whether the image is valid and a validator hash.
Raises:
ImageValidationError: If the image fails any of the validation checks.
"""
try:
valide_extension = {"jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"}
valide_dimension = [1920, 1080]

data = await request.get_json()
image_base64 = data["image"]

header, encoded_image = image_base64.split(",", 1)

image = Image.open(io.BytesIO(base64.b64decode(encoded_image)))
image_extension = image.format.lower()

# extension check
if image_extension not in valide_extension:
raise ImageValidationError("Invalid file extension")

expected_header = f"data:image/{image_extension};base64"

# header check
if header.lower() != expected_header:
raise ImageValidationError("Invalid file header")

# size check
if image.size[0] > valide_dimension[0] and image.size[1] > valide_dimension[1]:
raise ImageValidationError("Invalid file size")

# resizable check
try:
size = (100,150)
image.thumbnail(size)
except IOError:
raise ImageValidationError("Invalid file not resizable")

validator = await azure_storage_api.generate_hash(base64.b64decode(encoded_image))
return jsonify([True, validator]), 200

except ImageValidationError as error:
print(error)
return jsonify([False, error.message]), 400
except Exception as error:
print(error)
return jsonify([False, "ImageValidationError: " + str(error)]), 400


@app.post("/inf")
async def inference_request():
"""
Expand All @@ -154,6 +212,7 @@ async def inference_request():
container_name = data["container_name"]
imageDims = data["imageDims"]
image_base64 = data["image"]

if folder_name and container_name and imageDims and image_base64:
header, encoded_data = image_base64.split(",", 1)
image_bytes = base64.b64decode(encoded_data)
Expand Down
5 changes: 5 additions & 0 deletions custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,8 @@ class ValidateEnvVariablesError(Exception):

class ServerError(Exception):
pass

class ImageValidationError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ quart
quart-cors
python-dotenv
hypercorn
Pillow
76 changes: 76 additions & 0 deletions tests/test_image_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import base64
import json

from io import BytesIO
from PIL import Image
from unittest import TestCase, main
import requests

"""
In order for the tests to run, the server must be running.
TO DO - Create a mock server to run the tests without the need for the server to be running.
TO DO - Implement test_image for every other checks (size, format, resizable)
TO DO - Implement test_image for every type of image (PNG, JPEG, GIF, BMP, TIFF, WEBP, SVG)
TO DO -
"""

class test_image_validation(TestCase):
# V1 with server running
def test_real_image_validation(self):
image = Image.new('RGB', (150, 150), 'blue')

# Save the image to a byte array
img_byte_array = BytesIO()

image_header = "data:image/PNG;base64,"

image.save(img_byte_array, 'PNG')

data = base64.b64encode(img_byte_array.getvalue()).decode('utf-8')

response = requests.post(
url="http://0.0.0.0:8080/image-validation",
data= str.encode(json.dumps({'image': image_header + data})),
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
}
)
data = json.loads(response.content)

if isinstance(data[1], str):
self.assertEqual(response.status_code, 200)
self.assertEqual(data[0], True)
else:
self.assertEqual(response.status_code, 200)

# v2 with server not running
def test_invalid_header_image_validation(self):
image = Image.new('RGB', (150, 150), 'blue')

# Save the image to a byte array
img_byte_array = BytesIO()

image_header = "data:image/,"

image.save(img_byte_array, 'PNG')

data = base64.b64encode(img_byte_array.getvalue()).decode('utf-8')

response = requests.post(
url="http://0.0.0.0:8080/image-validation",
data= str.encode(json.dumps({'image': image_header + data})),
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
}
)

data = json.loads(response.content)

self.assertEqual(response.status_code, 400)
self.assertEqual(data[0], False)
self.assertEqual(data[1], 'Invalid file header')

if __name__ == '__main__':
main()

0 comments on commit 34e294e

Please sign in to comment.