Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue #57: Implement image validation in backend #58

Merged
merged 24 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 94 additions & 19 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
import base64
import re
import io
import magic
from PIL import Image, UnidentifiedImageError
from dotenv import load_dotenv
from quart import Quart, request, jsonify
from quart_cors import cors
Expand All @@ -14,6 +17,7 @@
InferenceRequestError,
CreateDirectoryRequestError,
ServerError,
ImageValidationError,
)

load_dotenv()
Expand All @@ -28,34 +32,31 @@
NACHET_DATA = os.getenv("NACHET_DATA")
NACHET_MODEL = os.getenv("NACHET_MODEL")

def _json_env(name, default):
    """Return the JSON-decoded value of environment variable *name*.

    Falls back to *default* only when the variable is unset: ``os.getenv``
    then returns ``None`` and ``json.loads(None)`` raises ``TypeError``.
    Malformed JSON is deliberately not swallowed, so a misconfigured
    deployment still fails loudly at import time.
    """
    try:
        return json.loads(os.getenv(name))
    except TypeError:
        return default


# Each setting falls back on its own.  Previously both shared one ``try``
# block, so configuring only one of the two variables silently reset BOTH
# to the defaults below.  The defaults are what the test-suite runs with.
VALID_EXTENSION = _json_env(
    "NACHET_VALID_EXTENSION",
    {"jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"},
)
VALID_DIMENSION = _json_env(
    "NACHET_VALID_DIMENSION",
    {"width": 1920, "height": 1080},
)

# Process-wide in-memory cache.
#   'seeds' / 'endpoints': the keys handed to fetch_json() at startup
#       (presumably filled in by it -- confirm against fetch_json).
#   'validators': hashes appended by the /image-validation endpoint for
#       every image that passed validation.
CACHE = {
    'seeds': None,
    'endpoints': None,
    'validators': [],
}

# Check: do environment variables exist?
if connection_string is None:
    raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if endpoint_url is None:
    raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

if endpoint_api_key is None:
    raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

# Check: are environment variables correct?
# re.match() returns None on failure, so its truthiness is enough.
if not re.match(connection_string_regex, connection_string):
    raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

# This check validates the endpoint URL, so the message must name the URL
# variable (it previously blamed NACHET_MODEL_ENDPOINT_ACCESS_KEY).
if not re.match(endpoint_url_regex, endpoint_url):
    raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

app = Quart(__name__)
app = cors(app, allow_origin="*", allow_methods=["GET", "POST", "OPTIONS"])

# Maximum accepted request body size, in megabytes.
try:
    mb = int(os.getenv("NACHET_MAX_CONTENT_LENGTH"))
except TypeError:
    # Variable unset: os.getenv() returned None and int(None) raises
    # TypeError.  Fall back to 16 MB instead of crashing at import time.
    mb = 16

app.config["MAX_CONTENT_LENGTH"] = mb * 1024 * 1024


@app.post("/del")
async def delete_directory():
"""
Expand Down Expand Up @@ -144,6 +145,61 @@ async def create_directory():
return jsonify(["CreateDirectoryRequestError: " + str(error)]), 400


@app.post("/image-validation")
async def image_validation():
    """
    Validate a base64 data-URL image and hand back a validator hash.

    The JSON request body must carry an ``image`` key whose value is a data
    URL of the form ``data:image/<type>;base64,<payload>``.  After decoding,
    the checks run in this order:

    1. extension - the MIME subtype sniffed from the bytes is in
       ``VALID_EXTENSION``;
    2. header    - the data-URL prefix matches the sniffed type
       (case-insensitive);
    3. size      - neither dimension exceeds ``VALID_DIMENSION``;
    4. resizable - Pillow can build a thumbnail from the image.

    Returns:
        ``([validator], 200)`` when every check passes; ``validator`` is the
        value returned by ``azure_storage_api.generate_hash`` for the image
        bytes and is also appended to ``CACHE['validators']``.
        ``([message], 400)`` when the request is malformed or a check fails.

    Raises:
        Nothing for the error types listed in the ``except`` clause below;
        they are all converted into a 400 response.
    """
    try:
        data = await request.get_json()

        # A body without an "image" key used to escape as an unhandled
        # KeyError (HTTP 500); report it as a validation failure instead.
        try:
            image_base64 = data["image"]
        except KeyError as error:
            raise ImageValidationError("missing image field") from error

        # "data:image/png;base64,<payload>" -> header / payload.  A value
        # without a comma raises ValueError on unpacking, handled below.
        header, encoded_image = image_base64.split(",", 1)

        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes))

        # Sniff the real content type from the bytes rather than trusting
        # the client-supplied header, e.g. "image/png" -> "png".
        magic_header = magic.from_buffer(image_bytes, mime=True)
        image_extension = magic_header.split("/")[1]

        # extension check
        if image_extension not in VALID_EXTENSION:
            raise ImageValidationError(f"invalid file extension: {image_extension}")

        expected_header = f"data:image/{image_extension};base64"

        # header check
        if header.lower() != expected_header:
            raise ImageValidationError(f"invalid file header: {header}")

        # size check: the image is too large as soon as EITHER side exceeds
        # its limit.  The previous `and` let an image through when only one
        # dimension was oversized (e.g. 10000x10).
        width, height = image.size
        if width > VALID_DIMENSION["width"] or height > VALID_DIMENSION["height"]:
            raise ImageValidationError(f"invalid file size: {width}x{height}")

        # resizable check
        try:
            image.thumbnail((100, 150))
        except IOError as error:
            raise ImageValidationError("invalid file not resizable") from error

        validator = await azure_storage_api.generate_hash(image_bytes)
        # NOTE(review): this list is only ever appended to in the code shown
        # here, so it grows with every validated image - confirm it is
        # pruned elsewhere.
        CACHE['validators'].append(validator)

        return jsonify([validator]), 200

    except (FileNotFoundError, ValueError, TypeError, UnidentifiedImageError, ImageValidationError) as error:
        print(error)
        return jsonify([error.args[0]]), 400


@app.post("/inf")
async def inference_request():
"""
Expand All @@ -157,6 +213,7 @@ async def inference_request():
container_name = data["container_name"]
imageDims = data["imageDims"]
image_base64 = data["image"]

if folder_name and container_name and imageDims and image_base64:
header, encoded_data = image_base64.split(",", 1)
image_bytes = base64.b64decode(encoded_data)
Expand Down Expand Up @@ -276,6 +333,24 @@ async def fetch_json(repo_URL, key, file_path):

@app.before_serving
async def before_serving():
    """
    Validate the required configuration, then fetch the startup JSON data.

    Registered with ``app.before_serving``.  The checks live here rather
    than at module level so that importing the module (e.g. from the
    tests) does not require the environment to be configured.

    Raises:
        ServerError: if a required environment variable is missing or does
            not match its expected pattern.
    """
    # Check: do environment variables exist?
    if connection_string is None:
        raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

    if endpoint_url is None:
        raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

    if endpoint_api_key is None:
        raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

    # Check: are environment variables correct?
    # re.match() returns None on failure, so its truthiness is enough.
    if not re.match(connection_string_regex, connection_string):
        raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

    # This check validates the endpoint URL, so the message must name the
    # URL variable (it previously blamed NACHET_MODEL_ENDPOINT_ACCESS_KEY).
    if not re.match(endpoint_url_regex, endpoint_url):
        raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

    # Fetch the seed list and model-endpoint metadata under the CACHE keys
    # 'seeds' and 'endpoints' (presumably stored there by fetch_json).
    await fetch_json(NACHET_DATA, 'seeds', "seeds/all.json")
    await fetch_json(NACHET_MODEL, 'endpoints', 'model_endpoints_metadata.json')

Expand Down
3 changes: 3 additions & 0 deletions custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ class ValidateEnvVariablesError(Exception):

class ServerError(Exception):
    """Server-level failure.

    In app.py this is raised when a required environment variable is
    missing or does not match its expected pattern.
    """

class ImageValidationError(Exception):
    """An uploaded image failed one of the /image-validation checks.

    The first argument is the message returned to the client.
    """
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ quart
quart-cors
python-dotenv
hypercorn
Pillow
python-magic
128 changes: 128 additions & 0 deletions tests/test_image_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import unittest
import asyncio

from app import app, json, base64, Image, io
from unittest.mock import patch, Mock


class TestImageValidation(unittest.TestCase):
    """Exercise the ``/image-validation`` endpoint through the app's test client."""

    def setUp(self):
        self.test_client = app.test_client()

        # In-memory 150x150 blue PNG used as the payload of every request.
        self.img_byte_array = io.BytesIO()
        image = Image.new('RGB', (150, 150), 'blue')
        # Upper-case on purpose: the endpoint compares the header
        # case-insensitively.
        self.image_header = "data:image/PNG;base64,"
        image.save(self.img_byte_array, 'PNG')

    def _encoded_image(self):
        """Return the fixture PNG as a base64 text string."""
        return base64.b64encode(self.img_byte_array.getvalue()).decode('utf-8')

    def _post_image(self, image_data_url):
        """POST *image_data_url* to the endpoint.

        Returns:
            A ``(response, payload)`` tuple where ``payload`` is the decoded
            JSON body of the response.
        """
        response = asyncio.run(
            self.test_client.post(
                '/image-validation',
                headers={
                    "Content-Type": "application/json",
                    "Access-Control-Allow-Origin": "*",
                },
                data=str.encode(json.dumps({'image': image_data_url})),
            ))
        payload = json.loads(asyncio.run(response.get_data()))
        return response, payload

    def test_real_image_validation(self):
        response, data = self._post_image(self.image_header + self._encoded_image())

        self.assertEqual(response.status_code, 200)
        self.assertIsInstance(data[0], str)

    def test_invalid_header(self):
        # Header lacks the "<type>;base64" part the endpoint expects.
        response, data = self._post_image("data:image/," + self._encoded_image())

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file header: data:image/')

    @patch("magic.Magic.from_buffer")
    def test_invalid_extension(self, mock_magic_from_buffer):
        # Make content sniffing report a non-image type.
        mock_magic_from_buffer.return_value = "text/plain"

        response, data = self._post_image(self.image_header + self._encoded_image())

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file extension: plain')

    @patch("PIL.Image.open")
    def test_invalid_size(self, mock_open):
        # Pretend the decoded image is larger than the allowed dimensions.
        mock_image = Mock()
        mock_image.size = [2000, 2000]
        mock_image.format = "PNG"

        mock_open.return_value = mock_image

        response, data = self._post_image(self.image_header + self._encoded_image())

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file size: 2000x2000')

    @patch("PIL.Image.open")
    def test_resizable_error(self, mock_open):
        # Image is within bounds but thumbnail() fails.
        mock_image = Mock()
        mock_image.size = [1080, 1080]
        mock_image.format = "PNG"
        mock_image.thumbnail.side_effect = IOError("error can't resize")

        mock_open.return_value = mock_image

        response, data = self._post_image(self.image_header + self._encoded_image())

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file not resizable')

# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading