Skip to content

Commit

Permalink
added tests and examples for checkbox detector
Browse files Browse the repository at this point in the history
  • Loading branch information
arinkulshi committed May 23, 2024
1 parent 3513db8 commit df8096d
Show file tree
Hide file tree
Showing 16 changed files with 1,362 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ocr.services.checkbox_detector import CheckboxDetector
from ocr.services.checkbox_detector_subtraction import CheckboxDetector

import os

Expand Down
45 changes: 45 additions & 0 deletions OCR/ocr/checkbox_detector_vision_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import cv2
import argparse
from ocr.services.checkbox_image_align import SIFTAligner
from ocr.services.checkbox_detector_image_extractor import MorphDataExtractor


def extract_data(im, aligner, extractor):
    """Load a query image, align it and run the extractor on the result.

    Args:
        im: Path of the query image to read with ``cv2.imread``.
        aligner: Object whose ``align(image)`` returns the aligned image.
        extractor: Object whose ``extract(aligned_image)`` returns the entities.

    Returns:
        Whatever ``extractor.extract`` returns for the aligned image.

    Raises:
        OSError: If the query image cannot be read.
    """
    query = cv2.imread(im)
    # cv2.imread reports failure by returning None instead of raising, so
    # fail here with the offending path rather than deep inside alignment.
    if query is None:
        raise OSError(f"Could not read query image: {im!r}")

    aligned = aligner.align(query)
    return extractor.extract(aligned)


def main(args):
    """Extract checkbox entities for the parsed command-line arguments.

    Reads ``args.query``, ``args.template`` and ``args.roi``, builds a
    ``MorphDataExtractor`` from the ROI file and a ``SIFTAligner`` from the
    template, then prints the entities returned by ``extract_data``.
    """
    query_path, template_path, roi_path = args.query, args.template, args.roi

    extractor = MorphDataExtractor(roi_path)
    aligner = SIFTAligner(template_path)

    print(extract_data(im=query_path, aligner=aligner, extractor=extractor))


if __name__ == "__main__":
    # ArgumentParser's first positional parameter is ``prog``; the text below
    # is a description, so it is passed by keyword to keep usage output sane.
    parser = argparse.ArgumentParser(description="Parsing arguments for information extraction")

    # All three paths are needed by main(); requiring them makes argparse
    # report a missing option instead of passing None down to OpenCV/open().
    parser.add_argument(
        "--query", type=str, required=True, help="Image from where the data needs to be extracted"
    )
    parser.add_argument(
        "--template", type=str, required=True, help="The template document which is used for alignment"
    )
    parser.add_argument("--roi", type=str, required=True, help="json file that contains the roi data")

    args = parser.parse_args()

    main(args)
122 changes: 122 additions & 0 deletions OCR/ocr/services/checkbox_detector_image_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import cv2
from abc import ABC, abstractmethod
import json
import numpy as np

AREA_THRESHOLD = 40


class TRFDataExtractor(ABC):
    """Base class for extractors that inspect fixed regions of an aligned image.

    The regions are loaded once from a JSON file and stored in ``self._roi``
    as a list of ``(shape_attributes, region_attributes)`` pairs.
    """

    def __init__(self, roi_data) -> None:
        """Load the regions of interest.

        Args:
            roi_data: Path of a JSON file whose ``["data"]["regions"]`` entry
                is a list of items, each carrying ``"shape_attributes"`` and
                ``"region_attributes"``.
        """
        super().__init__()
        with open(roi_data, "rb") as roi_data_file:
            roi = json.load(roi_data_file)

        self._roi = [(item["shape_attributes"], item["region_attributes"]) for item in roi["data"]["regions"]]

    def _get_edges(self, im: cv2.Mat) -> cv2.Mat:
        """Return the Canny edge map of ``im``.

        The image is converted from BGR to grayscale and smoothed with a 5x5
        Gaussian blur before edge detection.
        """
        gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(gray, (5, 5), 0)

        # Canny runs on the blurred grayscale image with fixed hysteresis
        # thresholds; no binarisation step feeds into it.
        return cv2.Canny(blur, 220, 250)

    @abstractmethod
    def extract(self, aligned_image: cv2.Mat) -> list:
        """Return the list of entities detected in ``aligned_image``."""


class BoxDataExtractor(TRFDataExtractor):
    """Extractor that counts edge pixels inside each region of interest."""

    def __init__(self, roi_data, non_black_count_threshold=90) -> None:
        """Load the regions and remember the edge-pixel count threshold."""
        super().__init__(roi_data)
        self._non_black_count_threshold = non_black_count_threshold

    def extract(self, aligned_image: cv2.Mat) -> list:
        """Return the ``"type"`` of every region with enough edge pixels.

        A region is reported when the number of non-zero pixels of the edge
        map inside its rectangle exceeds the configured threshold.
        """
        edge_map = self._get_edges(aligned_image)

        def edge_pixel_count(shape):
            # Rectangle given as top-left corner plus width/height.
            left = shape["x"]
            top = shape["y"]
            right = left + shape["width"]
            bottom = top + shape["height"]
            return cv2.countNonZero(edge_map[top:bottom, left:right])

        return [
            region["type"]
            for shape, region in self._roi
            if edge_pixel_count(shape) > self._non_black_count_threshold
        ]


class MorphDataExtractor(TRFDataExtractor):
    """Extractor that removes box outlines and counts the remaining ink.

    For each region the crop is binarised, near-horizontal and near-vertical
    line segments are painted out, small blobs are removed, and the region is
    reported when the number of remaining non-zero pixels exceeds the
    threshold.
    """

    def __init__(self, roi_data, non_black_count_threshold=0) -> None:
        """Load the regions and remember the non-zero pixel threshold."""
        super().__init__(roi_data)
        self._non_black_count_threshold = non_black_count_threshold

    def _clean_image(self, im: cv2.Mat) -> cv2.Mat:
        """Blank every contour whose area is below ``AREA_THRESHOLD``.

        ``im`` is modified in place and also returned.
        """
        # [-2:] keeps (contours, hierarchy) whether findContours returns two
        # or three values.
        cnts, _ = cv2.findContours(im, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]

        for c in cnts:
            if cv2.contourArea(c) < AREA_THRESHOLD:
                cv2.drawContours(im, [c], -1, 0, -1)

        return im

    @staticmethod
    def _erase_lines(im, lines, is_target_angle) -> None:
        """Paint black over each segment in ``lines`` whose angle is selected.

        Args:
            im: Binary image that is drawn on in place.
            lines: Result of ``cv2.HoughLinesP`` (``None`` when nothing was found).
            is_target_angle: Predicate receiving the segment angle in degrees.
        """
        if lines is None:
            return
        for line in lines:
            for x1, y1, x2, y2 in line:
                # The small epsilon avoids a division by zero on vertical segments.
                theta = np.arctan((y2 - y1) / (x2 - x1 + 1e-8)) * 180 / np.pi
                if is_target_angle(theta):
                    cv2.line(im, (x1, y1), (x2, y2), (0, 0, 0), 2)

    def extract(self, aligned_image: cv2.Mat) -> list:
        """Return the ``"type"`` of every region that still contains ink.

        Args:
            aligned_image: BGR image already aligned to the template.

        Returns:
            The ``region["type"]`` values of the regions whose cleaned crop has
            more non-zero pixels than the configured threshold.
        """
        entity_list = []
        gray = cv2.cvtColor(aligned_image, cv2.COLOR_BGR2GRAY)

        # The structuring elements do not depend on the region, so build them once.
        repair_kernel_horizontal = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 1))
        repair_kernel_vertical = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 4))

        for shape, region in self._roi:
            x = shape["x"]
            y = shape["y"]
            w = shape["width"]
            h = shape["height"]

            im = gray[y : y + h, x : x + w]

            thresh = cv2.adaptiveThreshold(im, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 10)
            thresh = self._clean_image(thresh)

            # The Python signature is HoughLinesP(image, rho, theta, threshold,
            # lines, minLineLength, maxLineGap): the fifth positional slot is
            # the ``lines`` output, so the tuning values are passed by keyword.
            # NOTE(review): the earlier positional calls (..., 20, 10, 10) and
            # (..., 12, 40, 4) effectively ran with minLineLength=10 and
            # minLineLength=4 and the default maxLineGap; those effective values
            # are kept here. Confirm whether a maxLineGap was intended.
            repair_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, repair_kernel_horizontal, iterations=2)
            horizontal_lines = cv2.HoughLinesP(repair_horizontal, 1, np.pi / 180, 20, minLineLength=10)
            self._erase_lines(thresh, horizontal_lines, lambda theta: theta > -6 and theta < 6)

            # The vertical pass runs on the image with horizontal lines already removed.
            repair_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, repair_kernel_vertical, iterations=1)
            vertical_lines = cv2.HoughLinesP(repair_vertical, 1, np.pi / 180, 12, minLineLength=4)
            self._erase_lines(thresh, vertical_lines, lambda theta: 84 < np.abs(theta) < 96)

            # Erasing lines can leave small fragments behind; drop them too.
            thresh = self._clean_image(thresh)

            non_black_count = cv2.countNonZero(thresh)
            print(f"ROI: {region['type']}, Non-Black Count: {non_black_count}")

            if non_black_count > self._non_black_count_threshold:
                entity_list.append(region["type"])

        return entity_list
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def is_checked(self, image_path, threshold=18000):

diff = cv2.absdiff(self.template, checkbox_image)

# Calculate the sum of the absolute differences
# Calculate difference
diff_sum = np.sum(diff)
print(diff_sum)
print(threshold)
Expand Down
88 changes: 88 additions & 0 deletions OCR/ocr/services/checkbox_image_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from abc import ABC, abstractmethod
import cv2
import numpy as np


class ImageAligner(ABC):
    """Base class for aligners that warp a query image onto a template.

    The template is read from disk once and kept in grayscale in
    ``self._template``.
    """

    def __init__(self, template: str) -> None:
        """Load the template image.

        Args:
            template: Path of the template image.

        Raises:
            OSError: If the template image cannot be read.
        """
        super().__init__()
        template_im = cv2.imread(template)
        # cv2.imread returns None on failure; report the path instead of
        # letting cvtColor fail on an empty input.
        if template_im is None:
            raise OSError(f"Could not read template image: {template!r}")
        self._template: cv2.Mat = cv2.cvtColor(template_im, cv2.COLOR_BGR2GRAY)

    @abstractmethod
    def align(self, query_image: cv2.Mat) -> cv2.Mat:
        """Return ``query_image`` warped into the template's frame."""


class SIFTAligner(ImageAligner):
    """Aligner that matches SIFT features and warps with a RANSAC homography."""

    # A homography has 8 degrees of freedom, so at least 4 point pairs are needed.
    _MIN_MATCHES = 4

    def __init__(self, template: str) -> None:
        """Load the template and precompute its SIFT keypoints and descriptors."""
        super().__init__(template)
        self._h, self._w = self._template.shape
        self._sift = cv2.SIFT_create()
        self._kps_template, self._desc_template = self._sift.detectAndCompute(self._template, None)
        self._matcher = cv2.BFMatcher()

    def align(self, query_image: cv2.Mat) -> cv2.Mat:
        """Warp ``query_image`` onto the template.

        Args:
            query_image: BGR image to align.

        Returns:
            The query image warped to the template's width and height.

        Raises:
            ValueError: If no features are found, fewer than four matches
                survive the ratio test, or no homography can be estimated.
        """
        query = cv2.cvtColor(query_image, cv2.COLOR_BGR2GRAY)
        kps_query, desc_query = self._sift.detectAndCompute(query, None)
        if desc_query is None or self._desc_template is None:
            raise ValueError("No SIFT features found in the query or template image")

        matches = self._matcher.knnMatch(desc_query, self._desc_template, k=2)

        # Ratio test: keep a match only when it is clearly better than the
        # runner-up. Entries with fewer than two neighbours cannot be tested.
        good = [
            pair[0] for pair in matches if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
        ]
        if len(good) < self._MIN_MATCHES:
            raise ValueError(
                f"Only {len(good)} good matches found; at least {self._MIN_MATCHES} are needed for alignment"
            )

        pts_query = np.array([kps_query[m.queryIdx].pt for m in good], dtype="float")
        pts_template = np.array([self._kps_template[m.trainIdx].pt for m in good], dtype="float")

        H, _ = cv2.findHomography(pts_query, pts_template, method=cv2.RANSAC)
        if H is None:
            raise ValueError("Could not estimate a homography between the query and template images")

        return cv2.warpPerspective(query_image, H, (self._w, self._h))


class ORBAligner(ImageAligner):
    """Aligner that matches ORB features and warps with a RANSAC homography."""

    # A homography has 8 degrees of freedom, so at least 4 point pairs are needed.
    _MIN_MATCHES = 4

    def __init__(self, template: str, max_features=3000, keep_percent=0.2) -> None:
        """Load the template and precompute its ORB keypoints and descriptors.

        Args:
            template: Path of the template image.
            max_features: Value passed to ``cv2.ORB_create``.
            keep_percent: Fraction of the best (lowest-distance) matches kept
                for the homography estimate.
        """
        super().__init__(template)
        self._h, self._w = self._template.shape
        self._orb = cv2.ORB_create(max_features)

        self._kps_template, self._desc_template = self._orb.detectAndCompute(self._template, None)
        self._matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
        self._keep_percent = keep_percent

    def align(self, query_image: cv2.Mat) -> cv2.Mat:
        """Warp ``query_image`` onto the template.

        Args:
            query_image: BGR image to align.

        Returns:
            The query image warped to the template's width and height.

        Raises:
            ValueError: If no features are found, fewer than four matches
                remain after truncation, or no homography can be estimated.
        """
        query = cv2.cvtColor(query_image, cv2.COLOR_BGR2GRAY)
        kps_query, desc_query = self._orb.detectAndCompute(query, None)
        if desc_query is None or self._desc_template is None:
            raise ValueError("No ORB features found in the query or template image")

        # Best matches first, then keep only the configured fraction.
        matches = sorted(self._matcher.match(desc_query, self._desc_template), key=lambda m: m.distance)
        keep = int(len(matches) * self._keep_percent)
        matches = matches[:keep]
        if len(matches) < self._MIN_MATCHES:
            raise ValueError(
                f"Only {len(matches)} matches kept; at least {self._MIN_MATCHES} are needed for alignment"
            )

        pts_query = np.array([kps_query[m.queryIdx].pt for m in matches], dtype="float")
        pts_template = np.array([self._kps_template[m.trainIdx].pt for m in matches], dtype="float")

        H, _ = cv2.findHomography(pts_query, pts_template, method=cv2.RANSAC)
        if H is None:
            raise ValueError("Could not estimate a homography between the query and template images")

        return cv2.warpPerspective(query_image, H, (self._w, self._h))
Binary file removed OCR/ocr/services/checkbox_template.png
Binary file not shown.
Binary file added OCR/tests/assets/TRF.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit df8096d

Please sign in to comment.