Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue #57: Implement image validation in backend #58

Merged
merged 24 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 94 additions & 24 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
import base64
import re
import io
import magic
from PIL import Image, UnidentifiedImageError
from dotenv import load_dotenv
from quart import Quart, request, jsonify
from quart_cors import cors
Expand All @@ -14,6 +17,7 @@
InferenceRequestError,
CreateDirectoryRequestError,
ServerError,
ImageValidationError,
)

load_dotenv()
Expand All @@ -28,31 +32,19 @@
NACHET_DATA = os.getenv("NACHET_DATA")
NACHET_MODEL = os.getenv("NACHET_MODEL")

VALID_EXTENSION = {"jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"}
VALID_DIMENSION = [1920, 1080]
rngadam marked this conversation as resolved.
Show resolved Hide resolved

CACHE = {
'seeds': None,
'endpoints': None
'endpoints': None,
'validators': []
}

# Check: do environment variables exist?
if connection_string is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if endpoint_url is None:
raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

if endpoint_api_key is None:
raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, connection_string)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if not bool(re.match(endpoint_url_regex, endpoint_url)):
raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

app = Quart(__name__)
app = cors(app, allow_origin="*", allow_methods=["GET", "POST", "OPTIONS"])


@app.post("/del")
async def delete_directory():
"""
Expand Down Expand Up @@ -141,6 +133,66 @@ async def create_directory():
return jsonify(["CreateDirectoryRequestError: " + str(error)]), 400


@app.post("/image-validation")
async def image_validation():
    """
    Validate a base64-encoded image sent as a data URL in the JSON body.

    The request body must be a JSON object whose "image" field holds a data
    URL of the form "data:image/<format>;base64,<payload>". The checks run in
    this order: supported format, MIME type detected from the bytes, data-URL
    header, dimensions, and resizability.

    Returns:
        On success, a JSON list containing the validator hash, with status 200.
        On any validation or decoding failure, a JSON list containing the
        error message, with status 400.
    """
    try:
        data = await request.get_json()

        # A body without an "image" field is a client error like any other
        # malformed input, so report it as 400 instead of letting the KeyError
        # escape the handler.
        try:
            image_base64 = data["image"]
        except KeyError as error:
            raise ImageValidationError("missing image in request body") from error

        # Anything other than a string cannot be split into header/payload.
        if not isinstance(image_base64, str):
            raise ImageValidationError("invalid image: expected a base64 data URL")

        header, encoded_image = image_base64.split(",", 1)
        image_bytes = base64.b64decode(encoded_image)

        image = Image.open(io.BytesIO(image_bytes))
        image_extension = image.format.lower()

        # extension check
        if image_extension not in VALID_EXTENSION:
            raise ImageValidationError(f"invalid file extension: {image_extension}")

        expected_header = f"data:image/{image_extension};base64"
        expected_magic_header = f"image/{image_extension}"

        # magic header check: the MIME type sniffed from the raw bytes must
        # agree with the format Pillow detected
        magic_header = magic.from_buffer(image_bytes, mime=True)
        if magic_header != expected_magic_header:
            raise ImageValidationError(f"invalid file header: {magic_header}")

        # header check: the data-URL prefix sent by the client must agree too
        if header.lower() != expected_header:
            raise ImageValidationError(f"invalid file header: {header}")

        # size check
        # NOTE(review): this rejects an image only when BOTH dimensions exceed
        # the limits, so e.g. a very wide but short image passes - confirm
        # that this is the intended rule.
        width, height = image.size
        if width > VALID_DIMENSION[0] and height > VALID_DIMENSION[1]:
            raise ImageValidationError(f"invalid file size: {width}x{height}")

        # resizable check
        try:
            image.thumbnail((100, 150))
        except IOError as error:
            raise ImageValidationError("invalid file not resizable") from error

        validator = await azure_storage_api.generate_hash(image_bytes)
        # NOTE(review): entries are only ever appended here; confirm that
        # something else bounds or consumes this list.
        CACHE['validators'].append(validator)

        return jsonify([validator]), 200

    except (FileNotFoundError, ValueError, TypeError, UnidentifiedImageError, ImageValidationError) as error:
        print(error)
        return jsonify([error.args[0]]), 400


@app.post("/inf")
async def inference_request():
"""
Expand All @@ -154,6 +206,7 @@ async def inference_request():
container_name = data["container_name"]
imageDims = data["imageDims"]
image_base64 = data["image"]

if folder_name and container_name and imageDims and image_base64:
header, encoded_data = image_base64.split(",", 1)
image_bytes = base64.b64decode(encoded_data)
Expand Down Expand Up @@ -220,11 +273,11 @@ async def get_seed_data(seed_name):
"""
Returns JSON containing requested seed data
"""
if seed_name in CACHE['seeds']:
if seed_name in CACHE['seeds']:
return jsonify(CACHE['seeds'][seed_name]), 200
else:
return jsonify(f"No information found for {seed_name}."), 400


@app.get("/reload-seed-data")
async def reload_seed_data():
Expand All @@ -243,7 +296,7 @@ async def get_model_endpoints_metadata():
"""
Returns JSON containing the deployed endpoints' metadata
"""
if CACHE['endpoints']:
if CACHE['endpoints']:
return jsonify(CACHE['endpoints']), 200
else:
return jsonify("Error retrieving model endpoints metadata.", 400)
Expand All @@ -253,7 +306,7 @@ async def get_model_endpoints_metadata():
async def health():
return "ok", 200


async def fetch_json(repo_URL, key, file_path):
"""
Fetches JSON document from a GitHub repository and caches it
Expand All @@ -269,14 +322,31 @@ async def fetch_json(repo_URL, key, file_path):
HTTP Status Code: {error.code}"}), 400
except Exception as e:
return jsonify({"error": str(e)}), 500


@app.before_serving
async def before_serving():
    """
    Check the required environment variables, then load the cached JSON data.

    Runs the presence and format checks on the Azure storage connection
    string and the model endpoint settings before fetching the seed data and
    the model endpoints metadata into the cache.

    Raises:
        ServerError: If a required environment variable is missing or does
            not match its expected pattern.
    """
    # Check: do environment variables exist?
    if connection_string is None:
        raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

    if endpoint_url is None:
        raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

    if endpoint_api_key is None:
        raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

    # Check: are environment variables correct?
    if not re.match(connection_string_regex, connection_string):
        raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

    if not re.match(endpoint_url_regex, endpoint_url):
        # The value under test is the endpoint URL, so the message names the
        # URL variable (it previously pointed at the access key).
        raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

    await fetch_json(NACHET_DATA, 'seeds', "seeds/all.json")
    await fetch_json(NACHET_MODEL, 'endpoints', 'model_endpoints_metadata.json')


if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=8080)

3 changes: 3 additions & 0 deletions custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ class ValidateEnvVariablesError(Exception):

class ServerError(Exception):
    """Signal a server-side configuration problem, e.g. a missing or malformed environment variable."""

class ImageValidationError(Exception):
    """Signal that an uploaded image failed one of the validation checks."""
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ quart
quart-cors
python-dotenv
hypercorn
Pillow
python-magic
153 changes: 153 additions & 0 deletions tests/test_image_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import unittest
import asyncio

from app import app, json, base64, Image, io
from unittest.mock import patch, Mock


class TestImageValidation(unittest.TestCase):
    """Tests for the /image-validation endpoint."""

    def setUp(self):
        self.test_client = app.test_client()

        # In-memory 150x150 blue PNG used as the image payload in every test.
        self.img_byte_array = io.BytesIO()
        image = Image.new('RGB', (150, 150), 'blue')
        self.image_header = "data:image/PNG;base64,"
        image.save(self.img_byte_array, 'PNG')

    def _encoded_image(self):
        """Return the test image as a base64 text string."""
        return base64.b64encode(self.img_byte_array.getvalue()).decode('utf-8')

    def _post_image(self, header):
        """
        POST the test image, prefixed with *header*, to /image-validation.

        Returns the response object and its decoded JSON payload.
        """
        response = asyncio.run(
            self.test_client.post(
                '/image-validation',
                headers={
                    "Content-Type": "application/json",
                    "Access-Control-Allow-Origin": "*",
                },
                data=str.encode(json.dumps({'image': header + self._encoded_image()})),
            ))
        payload = json.loads(asyncio.run(response.get_data()))
        return response, payload

    def test_real_image_validation(self):
        response, data = self._post_image(self.image_header)

        self.assertEqual(response.status_code, 200)
        self.assertIsInstance(data[0], str)

    def test_invalid_header(self):
        response, data = self._post_image("data:image/,")

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file header: data:image/')

    @patch("magic.Magic.from_buffer")
    def test_invalid_magic_header(self, mock_magic_from_buffer):
        mock_magic_from_buffer.return_value = "text/plain"

        response, data = self._post_image(self.image_header)

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file header: text/plain')

    @patch("PIL.Image.open")
    def test_invalid_extension(self, mock_open):
        mock_image = Mock()
        mock_image.format = "md"
        mock_open.return_value = mock_image

        response, data = self._post_image(self.image_header)

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file extension: md')

    @patch("PIL.Image.open")
    def test_invalid_size(self, mock_open):
        mock_image = Mock()
        mock_image.size = [2000, 2000]
        mock_image.format = "PNG"
        mock_open.return_value = mock_image

        response, data = self._post_image(self.image_header)

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file size: 2000x2000')

    @patch("PIL.Image.open")
    def test_resizable_error(self, mock_open):
        mock_image = Mock()
        mock_image.size = [1080, 1080]
        mock_image.format = "PNG"
        # Force the thumbnail step to fail so the resizable check triggers.
        mock_image.thumbnail.side_effect = IOError("error can't resize")
        mock_open.return_value = mock_image

        response, data = self._post_image(self.image_header)

        self.assertEqual(response.status_code, 400)
        self.assertEqual(data[0], 'invalid file not resizable')

if __name__ == '__main__':
unittest.main()
Loading