Skip to content

Commit

Permalink
Added Error Handling to image_to_text API (#385)
Browse files Browse the repository at this point in the history
* added error handling logic

* added tests for error handling logic
  • Loading branch information
arinkulshi-skylight authored Nov 14, 2024
1 parent fddc049 commit f9eb032
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 10 deletions.
53 changes: 43 additions & 10 deletions OCR/ocr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import json
import cv2 as cv
import numpy as np
import asyncio

from fastapi import FastAPI, UploadFile, Form

from fastapi import FastAPI, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from ocr.services.image_ocr import ImageOCR
Expand Down Expand Up @@ -60,15 +62,46 @@ async def image_alignment(source_image: str = Form(), segmentation_template: str

@app.post("/image_file_to_text/")
async def image_file_to_text(source_image: UploadFile, segmentation_template: UploadFile, labels: str = Form()):
source_image_np = np.frombuffer(await source_image.read(), np.uint8)
source_image_img = cv.imdecode(source_image_np, cv.IMREAD_COLOR)

segmentation_template_np = np.frombuffer(await segmentation_template.read(), np.uint8)
segmentation_template_img = cv.imdecode(segmentation_template_np, cv.IMREAD_COLOR)

loaded_json = json.loads(labels)
segments = segmenter.segment(source_image_img, segmentation_template_img, loaded_json)
results = ocr.image_to_text(segments)
try:
source_image_np = np.frombuffer(await source_image.read(), np.uint8)
source_image_img = cv.imdecode(source_image_np, cv.IMREAD_COLOR)

if source_image_img is None:
raise HTTPException(
status_code=422, detail="Failed to decode source image. Ensure the file is a valid image format."
)

segmentation_template_np = np.frombuffer(await segmentation_template.read(), np.uint8)
segmentation_template_img = cv.imdecode(segmentation_template_np, cv.IMREAD_COLOR)

if segmentation_template_img is None:
raise HTTPException(
status_code=422,
detail="Failed to decode segmentation template. Ensure the file is a valid image format.",
)

if source_image_img.shape[:2] != segmentation_template_img.shape[:2]:
raise HTTPException(
status_code=400,
detail="Dimension mismatch between source image and segmentation template. Both images must have the same width and height.",
)

loaded_json = json.loads(labels)

segments = segmenter.segment(source_image_img, segmentation_template_img, loaded_json)
results = ocr.image_to_text(segments)

except json.JSONDecodeError:
raise HTTPException(
status_code=422, detail="Failed to parse labels JSON. Ensure the labels are in valid JSON format."
)
except asyncio.TimeoutError:
raise HTTPException(status_code=504, detail="The request timed out. Please try again.")
except HTTPException as e:
raise e
except Exception as e:
print(f"Unexpected error occurred: {str(e)}")
raise HTTPException(status_code=500, detail="An unexpected server error occurred.")

return results

Expand Down
114 changes: 114 additions & 0 deletions OCR/tests/api_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import base64
import os
import json
from unittest import mock
import asyncio


from fastapi.testclient import TestClient

Expand All @@ -13,6 +16,8 @@
segmentation_template_path = os.path.join(path, "./assets/form_segmention_template.png")
source_image_path = os.path.join(path, "./assets/form_filled.png")
labels_path = os.path.join(path, "./assets/labels.json")
invalid_dimension_path = os.path.join(path, "./assets/invalid_dimension_template.png")
invalid_image_file_path = os.path.join(path, "./assets/invalid_image_file.png")


class TestAPI:
Expand Down Expand Up @@ -104,3 +109,112 @@ def test_image_to_text_with_padding(self):
response_json = response.json()
assert response_json["nbs_patient_id"][0] == "SIENNA HAMPTON"
assert response_json["nbs_cas_id"][0] == "123555"

def test_invalid_source_image_format(self):
with (
open(segmentation_template_path, "rb") as segmentation_template_file,
open(invalid_image_file_path, "rb") as source_image_file, # using invalid image
open(labels_path, "r") as labels,
):
label_data = json.load(labels)
files_to_send = [
("source_image", source_image_file),
("segmentation_template", segmentation_template_file),
]

response = client.post(
url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)}
)

assert response.status_code == 422
response_json = response.json()
assert (
response_json["detail"] == "Failed to decode source image. Ensure the file is a valid image format."
)

def test_invalid_segmentation_template_format(self):
with (
open(invalid_image_file_path, "rb") as segmentation_template_file, # using a invalid image
open(source_image_path, "rb") as source_image_file,
open(labels_path, "r") as labels,
):
label_data = json.load(labels)
files_to_send = [
("source_image", source_image_file),
("segmentation_template", segmentation_template_file),
]

response = client.post(
url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)}
)

assert response.status_code == 422
assert (
response.json()["detail"]
== "Failed to decode segmentation template. Ensure the file is a valid image format."
)

def test_dimension_mismatch(self):
with (
open(source_image_path, "rb") as source_image_file,
open(invalid_dimension_path, "rb") as invalid_dimension_file, # using a file with separate dimensions
open(labels_path, "r") as labels,
):
label_data = json.load(labels)
files_to_send = [
("source_image", source_image_file),
("segmentation_template", invalid_dimension_file),
]

response = client.post(
url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)}
)

assert response.status_code == 400
assert (
response.json()["detail"]
== "Dimension mismatch between source image and segmentation template. Both images must have the same width and height."
)

def test_invalid_json_labels(self):
with (
open(source_image_path, "rb") as source_image_file,
open(segmentation_template_path, "rb") as segmentation_template_file,
):
invalid_label_data = "{invalid: json}" # file with invalid json format
files_to_send = [
("source_image", source_image_file),
("segmentation_template", segmentation_template_file),
]

response = client.post(
url="/image_file_to_text", files=files_to_send, data={"labels": invalid_label_data}
)

assert response.status_code == 422
assert (
response.json()["detail"]
== "Failed to parse labels JSON. Ensure the labels are in valid JSON format."
)

def test_timeout_error_simulation(self):
with (
open(source_image_path, "rb") as source_image_file,
open(segmentation_template_path, "rb") as segmentation_template_file,
open(labels_path, "r") as labels,
):
label_data = json.load(labels)
files_to_send = [
("source_image", source_image_file),
("segmentation_template", segmentation_template_file),
]

with mock.patch(
"ocr.services.image_segmenter.ImageSegmenter.segment", side_effect=asyncio.TimeoutError
): # mocks a invoked segment call
response = client.post(
url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)}
)

assert response.status_code == 504
assert response.json()["detail"] == "The request timed out. Please try again."
Binary file added OCR/tests/assets/invalid_dimension_template.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added OCR/tests/assets/invalid_image_file.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit f9eb032

Please sign in to comment.