Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement alignment service #313

Merged
merged 12 commits into from
Oct 28, 2024
1 change: 1 addition & 0 deletions OCR/ocr/services/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .image_alignment import ImageAligner as ImageAligner
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@


class FourPointTransform:
def __init__(self, image: Path):
self.image = cv.imread(str(image), cv.IMREAD_GRAYSCALE)
def __init__(self, image: Path | np.ndarray):
if isinstance(image, np.ndarray):
self.image = image
else:
self.image = cv.imread(str(image))

@classmethod
def align(self, source_image, template_image):
return FourPointTransform(source_image).dewarp()

@staticmethod
def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
Expand All @@ -28,7 +35,9 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray:

def find_largest_contour(self):
"""Compute contours for an image and find the biggest one by area."""
_, contours, _ = cv.findContours(self.image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
contours, _ = cv.findContours(
cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE
)
return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours)

def simplify_polygon(self, contour):
Expand All @@ -40,8 +49,8 @@ def dewarp(self) -> np.ndarray:
biggest_contour = self.find_largest_contour()
simplified = self.simplify_polygon(biggest_contour)

height, width = self.image.shape
height, width, _ = self.image.shape
destination = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)

M = cv.getPerspectiveTransform(self.order_points(simplified), destination)
M = cv.getPerspectiveTransform(self._order_points(simplified), destination)
return cv.warpPerspective(self.image, M, (width, height))
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@


class ImageHomography:
def __init__(self, template: Path, match_ratio=0.3):
def __init__(self, template: Path | np.ndarray, match_ratio=0.3):
"""Initialize the image homography pipeline with a `template` image."""
if match_ratio >= 1 or match_ratio <= 0:
raise ValueError("`match_ratio` must be between 0 and 1")

self.template = cv.imread(template)
if isinstance(template, np.ndarray):
self.template = template
else:
self.template = cv.imread(template)
self.match_ratio = match_ratio
self._sift = cv.SIFT_create()

@classmethod
def align(self, source_image, template_image):
return ImageHomography(template_image).transform_homography(source_image)

def estimate_self_similarity(self):
"""Calibrate `match_ratio` using a self-similarity metric."""
raise NotImplementedError
Expand Down Expand Up @@ -48,9 +55,25 @@ def estimate_transform_matrix(self, other):
M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)
return M

def transform_homography(self, other, matrix=None):
"""Run the image homography pipeline against a query image."""
def transform_homography(self, other, min_axis=100, matrix=None):
"""
Run the image homography pipeline against a query image.

Parameters:
min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform.
If the input image is under the axis limits, return the original input image unchanged.
matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be
estimated with `estimate_transform_matrix`.
"""

if other.shape[0] < min_axis and other.shape[1] < min_axis:
return other

if matrix is None:
matrix = self.estimate_transform_matrix(other)
try:
matrix = self.estimate_transform_matrix(other)
except cv.error:
print("could not estimate transform matrix")
return other

return cv.warpPerspective(other, matrix, (self.template.shape[1], self.template.shape[0]))
19 changes: 19 additions & 0 deletions OCR/ocr/services/alignment/image_alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

from ocr.services.alignment.backends import ImageHomography


class ImageAligner:
    """Facade that forwards alignment requests to a pluggable backend."""

    def __init__(self, aligner=ImageHomography):
        """Remember the backend whose `align` will be called; defaults to `ImageHomography`."""
        self.aligner = aligner

    def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
        """
        Align an image with the configured image alignment backend.

        Parameters:
            source_image: the image to be aligned, as a numpy ndarray.
            template_image: the image that `source_image` will be aligned against, as a
                numpy ndarray. May not be used for all image alignment backends.

        Returns:
            The value returned by the backend's `align(source_image, template_image)`.
        """
        return self.aligner.align(source_image, template_image)
13 changes: 12 additions & 1 deletion OCR/tests/alignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import cv2 as cv
import numpy as np
import pytest

from alignment import ImageHomography, RandomPerspectiveTransform
from ocr.services.alignment.backends import FourPointTransform, ImageHomography, RandomPerspectiveTransform
from ocr.services.alignment import ImageAligner


path = os.path.dirname(__file__)
Expand All @@ -14,6 +16,15 @@


class TestAlignment:
@pytest.mark.parametrize("align_class", [ImageHomography, FourPointTransform])
def test_align_implementation(self, align_class):
    """Tests that the ImageAligner class backends implement the `align` method."""
    template_image = cv.imread(template_image_path)
    aligner = ImageAligner(aligner=align_class)
    result = aligner.align(filled_image, template_image)
    # Check the shape first: `cv.absdiff` below needs both images to have the same size.
    assert result.shape == template_image.shape, "Aligned image has wrong shape"
    # A median absolute pixel difference of at most 1 is the pass threshold.
    assert np.median(cv.absdiff(template_image, result)) <= 1, "Median difference too high"

def test_random_warp(self):
"""Test that a random warp generates an image different from the template."""
transformed = RandomPerspectiveTransform(filled_image_path).random_transform(distortion_scale=0.1)
Expand Down
Loading