Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue #57: Implement image validation in backend #58

Merged
merged 24 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 109 additions & 17 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import os
import base64
import re
import io
import magic
import time
import warnings

import model.inference as inference
from model import request_function

from PIL import Image, UnidentifiedImageError
from datetime import date
from dotenv import load_dotenv
from quart import Quart, request, jsonify
Expand All @@ -15,15 +20,36 @@
from cryptography.fernet import Fernet
import azure_storage.azure_storage_api as azure_storage_api

from custom_exceptions import (
DeleteDirectoryRequestError,
ListDirectoriesRequestError,
InferenceRequestError,
CreateDirectoryRequestError,
ServerError,
PipelineNotFoundError,
ConnectionStringError
)
# Root of the exception hierarchy declared in app.py: every request-level
# error class below derives from it, so it can be used to catch them all.
class APIErrors(Exception):
pass


# NOTE(review): presumably raised by the "/del" endpoint; its body is not
# visible in this view, confirm against the handler.
class DeleteDirectoryRequestError(APIErrors):
pass


# NOTE(review): presumably raised by the directory-listing endpoint; the
# handler is not visible in this view, confirm.
class ListDirectoriesRequestError(APIErrors):
pass


# Raised by inference_request when required arguments are missing or the
# requested model is not among the cached pipelines.
class InferenceRequestError(APIErrors):
pass


# Reported to the client by create_directory with an HTTP 400 response.
class CreateDirectoryRequestError(APIErrors):
pass


# Raised when the pipelines cannot be retrieved while the server is being
# set up.
class ServerError(APIErrors):
pass


# Raised by image_validation when an uploaded image fails one of its checks
# (extension, header, size, resizability).
class ImageValidationError(APIErrors):
pass


# Warning category (a UserWarning, deliberately not an APIErrors subclass)
# emitted by inference_request when the supplied validator is not in the
# cache of validated images.
class ImageWarning(UserWarning):
pass

load_dotenv()
connection_string_regex = r"^DefaultEndpointsProtocol=https?;.*;FileEndpoint=https://[a-zA-Z0-9]+\.file\.core\.windows\.net/;$"
Expand All @@ -38,9 +64,17 @@
NACHET_MODEL = os.getenv("NACHET_MODEL")

try:
MAX_CONTENT_LENGTH = int(os.getenv("NACHET_MAX_CONTENT_LENGTH"))
VALID_EXTENSION = json.loads(os.getenv("NACHET_VALID_EXTENSION"))
VALID_DIMENSION = json.loads(os.getenv("NACHET_VALID_DIMENSION"))
except TypeError:
# For testing
VALID_DIMENSION = {"width": 1920, "height": 1080}
VALID_EXTENSION = {"jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"}

try:
MAX_CONTENT_LENGTH_MEGABYTES = int(os.getenv("NACHET_MAX_CONTENT_LENGTH"))
except (TypeError, ValueError):
MAX_CONTENT_LENGTH = 16
MAX_CONTENT_LENGTH_MEGABYTES = 16


Model = namedtuple(
Expand All @@ -59,11 +93,12 @@
"seeds": None,
"endpoints": None,
"pipelines": {},
"validators": []
}

app = Quart(__name__)
app = cors(app, allow_origin="*", allow_methods=["GET", "POST", "OPTIONS"])
app.config["MAX_CONTENT_LENGTH"] = MAX_CONTENT_LENGTH * 1024 * 1024
app.config["MAX_CONTENT_LENGTH"] = MAX_CONTENT_LENGTH_MEGABYTES * 1024 * 1024


@app.post("/del")
Expand Down Expand Up @@ -151,6 +186,61 @@ async def create_directory():
return jsonify(["CreateDirectoryRequestError: " + str(error)]), 400


@app.post("/image-validation")
async def image_validation():
    """
    Validate a base64 data-URL image sent in the JSON request body.

    The body must contain an "image" key holding a data URL of the form
    "data:image/<type>;base64,<payload>". The checks run in this order:
    the payload must decode and open as an image, its detected type must be
    listed in VALID_EXTENSION, the declared header must match the detected
    type, neither dimension may exceed VALID_DIMENSION, and the image must be
    resizable.

    Returns:
        On success, a JSON list holding the validator hash and HTTP 200; the
        hash is also appended to CACHE["validators"].
        On failure, a JSON list holding the error message and HTTP 400.
        Validation failures are reported through the response, not raised
        to the caller.
    """
    try:
        data = await request.get_json()

        # A missing key would otherwise escape as an unhandled KeyError (500).
        try:
            image_base64 = data["image"]
        except KeyError as error:
            raise ImageValidationError("missing image in request") from error

        # A data URL without a comma fails to unpack and is reported as 400.
        header, encoded_image = image_base64.split(",", 1)

        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes))

        # Trust the detected content type, not the one the client declared.
        magic_header = magic.from_buffer(image_bytes, mime=True)
        image_extension = magic_header.split("/")[1]

        # extension check
        if image_extension not in VALID_EXTENSION:
            raise ImageValidationError(f"invalid file extension: {image_extension}")

        expected_header = f"data:image/{image_extension};base64"

        # header check
        if header.lower() != expected_header:
            raise ImageValidationError(f"invalid file header: {header}")

        # size check: reject as soon as EITHER dimension exceeds its limit
        width, height = image.size
        if width > VALID_DIMENSION["width"] or height > VALID_DIMENSION["height"]:
            raise ImageValidationError(f"invalid file size: {width}x{height}")

        # resizable check
        try:
            size = (100, 150)
            image.thumbnail(size)
        except IOError as error:
            raise ImageValidationError("invalid file not resizable") from error

        validator = await azure_storage_api.generate_hash(image_bytes)
        CACHE["validators"].append(validator)

        return jsonify([validator]), 200

    except (FileNotFoundError, ValueError, TypeError, UnidentifiedImageError, ImageValidationError) as error:
        print(error)
        # Fall back to str(error) so an exception with empty args cannot
        # turn the 400 response into an IndexError.
        message = error.args[0] if error.args else str(error)
        return jsonify([message]), 400


@app.post("/inf")
async def inference_request():
"""
Expand All @@ -163,6 +253,7 @@ async def inference_request():
print(f"{date.today()} Entering inference request") # TODO: Transform into logging
data = await request.get_json()
pipeline_name = data.get("model_name")
validator = data.get("validator")
folder_name = data["folder_name"]
container_name = data["container_name"]
imageDims = data["imageDims"]
Expand All @@ -174,6 +265,7 @@ async def inference_request():
print(f"Requested by user: {container_name}") # TODO: Transform into logging
pipelines_endpoints = CACHE.get("pipelines")
blob_service_client = app.config.get("BLOB_CLIENT")
validators = CACHE.get("validators")

if not (folder_name and container_name and imageDims and image_base64):
raise InferenceRequestError(
Expand All @@ -182,11 +274,11 @@ async def inference_request():
if not pipelines_endpoints.get(pipeline_name):
raise InferenceRequestError(f"model {pipeline_name} not found")

header, encoded_data = image_base64.split(",", 1)
_, encoded_data = image_base64.split(",", 1)

# Validate image header #TODO with magic header
if not header.startswith("data:image/"):
raise InferenceRequestError("invalid image header")
if validator not in validators:
warnings.warn("this picture was not validate", ImageWarning)
# TODO: implement logic when frontend start returning validators

# Keep track of every output given by the models
# TODO: add it to CACHE variable
Expand Down Expand Up @@ -331,7 +423,7 @@ async def get_pipelines():
app.config["BLOB_CLIENT"] = await azure_storage_api.get_blob_client(connection_string)
result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], PIPELINE_BLOB_NAME, PIPELINE_VERSION)
cipher_suite = Fernet(FERNET_KEY)
except (ConnectionStringError, PipelineNotFoundError) as error:
except (azure_storage_api.AzureAPIErrors) as error:
print(error)
raise ServerError("server errror: could not retrieve the pipelines") from error

Expand Down
69 changes: 50 additions & 19 deletions azure_storage/azure_storage_api.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,3 @@
import json
import uuid
import hashlib
import datetime
from azure.storage.blob import BlobServiceClient, ContainerClient
from azure.core.exceptions import ResourceNotFoundError
from custom_exceptions import (
ConnectionStringError,
MountContainerError,
GetBlobError,
UploadImageError,
UploadInferenceResultError,
GetFolderUUIDError,
FolderListError,
GenerateHashError,
CreateDirectoryError,
PipelineNotFoundError,
)

"""
---- user-container based structure ----- - container name is user id - whenever
a new user is created, a new container is created with the user uuid - inside
Expand All @@ -25,6 +6,56 @@
date, in the container - inside the project folder, there is an image file and a
json file with the image inference results
"""
import json
import uuid
import hashlib
import datetime
from azure.storage.blob import BlobServiceClient, ContainerClient
from azure.core.exceptions import ResourceNotFoundError


# Root of the exception hierarchy declared in azure_storage_api.py. app.py
# catches this base class around pipeline retrieval, so any subclass below
# is handled there as well.
# NOTE(review): the raise sites of the subclasses are not visible in this
# view; each name presumably matches the storage operation that fails.
class AzureAPIErrors(Exception):
pass


class ConnectionStringError(AzureAPIErrors):
pass


class MountContainerError(AzureAPIErrors):
pass


class GetBlobError(AzureAPIErrors):
pass


class UploadImageError(AzureAPIErrors):
pass


class UploadInferenceResultError(AzureAPIErrors):
pass


class GetFolderUUIDError(AzureAPIErrors):
pass


class FolderListError(AzureAPIErrors):
pass


class GenerateHashError(AzureAPIErrors):
pass


class CreateDirectoryError(AzureAPIErrors):
pass


class PipelineNotFoundError(AzureAPIErrors):
pass


async def generate_hash(image):
Expand Down
60 changes: 0 additions & 60 deletions custom_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,10 @@
class DeleteDirectoryRequestError(Exception):
pass


class ListDirectoriesRequestError(Exception):
pass


class InferenceRequestError(Exception):
pass


class CreateDirectoryRequestError(Exception):
pass


class GenerateHashError(Exception):
pass


class MountContainerError(Exception):
pass


class ContainerNameError(Exception):
pass


class CreateDirectoryError(Exception):
pass


class ConnectionStringError(Exception):
pass


class GetBlobError(Exception):
pass


class UploadImageError(Exception):
pass


class UploadInferenceResultError(Exception):
pass


class GetFolderUUIDError(Exception):
pass


class FolderListError(Exception):
pass


class ProcessInferenceResultError(Exception):
pass


class ValidateEnvVariablesError(Exception):
pass


class ServerError(Exception):
pass


class PipelineNotFoundError(Exception):
pass
4 changes: 2 additions & 2 deletions model/seed_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from PIL import Image
from collections import namedtuple
from urllib.request import Request, urlopen, HTTPError
from custom_exceptions import InferenceRequestError
from custom_exceptions import ProcessInferenceResultError

def process_image_slicing(image_bytes: bytes, result_json: dict) -> list:
"""
Expand Down Expand Up @@ -100,4 +100,4 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul
"images": process_image_slicing(previous_result, result_object)
}
except HTTPError as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None
raise ProcessInferenceResultError(f"An error occurred while processing the request:\n {str(e)}") from None
4 changes: 2 additions & 2 deletions model/six_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
from collections import namedtuple
from urllib.request import Request, urlopen, HTTPError
from custom_exceptions import InferenceRequestError
from custom_exceptions import ProcessInferenceResultError

async def request_inference_from_nachet_6seeds(model: namedtuple, previous_result: str):
"""
Expand Down Expand Up @@ -48,4 +48,4 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul
return result_object

except HTTPError as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None
raise ProcessInferenceResultError(f"An error occurred while processing the request:\n {str(e)}") from None
4 changes: 2 additions & 2 deletions model/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from collections import namedtuple
from urllib.request import Request, urlopen, HTTPError
from custom_exceptions import InferenceRequestError
from custom_exceptions import ProcessInferenceResultError


def process_swin_result(img_box:dict, results: dict) -> list:
Expand Down Expand Up @@ -59,4 +59,4 @@ async def request_inference_from_swin(model: namedtuple, previous_result: list[b

return process_swin_result(previous_result.get("result_json"), results)
except HTTPError as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}") from None
raise ProcessInferenceResultError(f"An error occurred while processing the request:\n {str(e)}") from None
Loading
Loading