From 34e294e0cdb2e73b62f5de46888eb43fae9334ec Mon Sep 17 00:00:00 2001 From: Maxence Guindon Date: Tue, 20 Feb 2024 19:59:45 +0000 Subject: [PATCH] fixes #57: Implement image-validation endpoint and test --- app.py | 59 ++++++++++++++++++++++++++ custom_exceptions.py | 5 +++ requirements.txt | 1 + tests/test_image_validation.py | 76 ++++++++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+) create mode 100644 tests/test_image_validation.py diff --git a/app.py b/app.py index 3e486f49..b05c216a 100644 --- a/app.py +++ b/app.py @@ -3,6 +3,8 @@ import os import base64 import re +import io +from PIL import Image from dotenv import load_dotenv from quart import Quart, request, jsonify from quart_cors import cors @@ -14,6 +16,7 @@ InferenceRequestError, CreateDirectoryRequestError, ServerError, + ImageValidationError, ) load_dotenv() @@ -141,6 +144,61 @@ async def create_directory(): return jsonify(["CreateDirectoryRequestError: " + str(error)]), 400 +@app.post("/image-validation") +async def image_validation(): + """ + Validates an image based on its extension, header, size, and resizability. + + Returns: + A JSON response containing a boolean indicating whether the image is valid and a validator hash. + + Raises: + ImageValidationError: If the image fails any of the validation checks. + """ + try: + valide_extension = {"jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"} + valide_dimension = [1920, 1080] + + data = await request.get_json() + image_base64 = data["image"] + + header, encoded_image = image_base64.split(",", 1) + + image = Image.open(io.BytesIO(base64.b64decode(encoded_image))) + image_extension = image.format.lower() + + # extension check + if image_extension not in valide_extension: + raise ImageValidationError("Invalid file extension") + + expected_header = f"data:image/{image_extension};base64" + + # header check + if header.lower() != expected_header: + raise ImageValidationError("Invalid file header") + + # size check + if image.size[0] > valide_dimension[0] and image.size[1] > valide_dimension[1]: + raise ImageValidationError("Invalid file size") + + # resizable check + try: + size = (100,150) + image.thumbnail(size) + except IOError: + raise ImageValidationError("Invalid file not resizable") + + validator = await azure_storage_api.generate_hash(base64.b64decode(encoded_image)) + return jsonify([True, validator]), 200 + + except ImageValidationError as error: + print(error) + return jsonify([False, error.message]), 400 + except Exception as error: + print(error) + return jsonify([False, "ImageValidationError: " + str(error)]), 400 + + @app.post("/inf") async def inference_request(): """ @@ -154,6 +212,7 @@ async def inference_request(): container_name = data["container_name"] imageDims = data["imageDims"] image_base64 = data["image"] + if folder_name and container_name and imageDims and image_base64: header, encoded_data = image_base64.split(",", 1) image_bytes = base64.b64decode(encoded_data) diff --git a/custom_exceptions.py b/custom_exceptions.py index 7de693e7..eae0a436 100644 --- a/custom_exceptions.py +++ b/custom_exceptions.py @@ -64,3 +64,8 @@ class ValidateEnvVariablesError(Exception): class ServerError(Exception): pass + +class ImageValidationError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) diff --git a/requirements.txt b/requirements.txt index 5ef34973..53eaad67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ quart quart-cors python-dotenv hypercorn +Pillow diff --git a/tests/test_image_validation.py b/tests/test_image_validation.py new file mode 100644 index 00000000..ebda12cd --- /dev/null +++ b/tests/test_image_validation.py @@ -0,0 +1,76 @@ +import base64 +import json + +from io import BytesIO +from PIL import Image +from unittest import TestCase, main +import requests + +""" +In order for the tests to run, the server must be running. +TO DO - Create a mock server to run the tests without the need for the server to be running. +TO DO - Implement test_image for every other checks (size, format, resizable) +TO DO - Implement test_image for every type of image (PNG, JPEG, GIF, BMP, TIFF, WEBP, SVG) +TO DO - +""" + +class test_image_validation(TestCase): +# V1 with server running + def test_real_image_validation(self): + image = Image.new('RGB', (150, 150), 'blue') + + # Save the image to a byte array + img_byte_array = BytesIO() + + image_header = "data:image/PNG;base64," + + image.save(img_byte_array, 'PNG') + + data = base64.b64encode(img_byte_array.getvalue()).decode('utf-8') + + response = requests.post( + url="http://0.0.0.0:8080/image-validation", + data= str.encode(json.dumps({'image': image_header + data})), + headers={ + "Content-Type": "application/json", + "Access-Control-Allow-Origin": "*", + } + ) + data = json.loads(response.content) + + if isinstance(data[1], str): + self.assertEqual(response.status_code, 200) + self.assertEqual(data[0], True) + else: + self.assertEqual(response.status_code, 200) + +# v2 with server not running + def test_invalid_header_image_validation(self): + image = Image.new('RGB', (150, 150), 'blue') + + # Save the image to a byte array + img_byte_array = BytesIO() + + image_header = "data:image/," + + image.save(img_byte_array, 'PNG') + + data = base64.b64encode(img_byte_array.getvalue()).decode('utf-8') + + response = requests.post( + url="http://0.0.0.0:8080/image-validation", + data= str.encode(json.dumps({'image': image_header + data})), + headers={ + "Content-Type": "application/json", + "Access-Control-Allow-Origin": "*", + } + ) + + data = json.loads(response.content) + + self.assertEqual(response.status_code, 400) + self.assertEqual(data[0], False) + self.assertEqual(data[1], 'Invalid file header') + +if __name__ == '__main__': + main() \ No newline at end of file