generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added tests and examples for checkbox dectector
- Loading branch information
1 parent
3513db8
commit df8096d
Showing
16 changed files
with
1,362 additions
and
2 deletions.
There are no files selected for viewing
2 changes: 1 addition & 1 deletion
2
OCR/ocr/checkbox_detector_main.py → ...cr/checkbox_detector__subtraction_main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import cv2 | ||
import argparse | ||
from ocr.services.checkbox_image_align import SIFTAligner | ||
from ocr.services.checkbox_detector_image_extractor import MorphDataExtractor | ||
|
||
|
||
def extract_data(im, aligner, extractor): | ||
query = cv2.imread(im) | ||
aligned_sift = aligner.align(query) | ||
|
||
entities = extractor.extract(aligned_sift) | ||
|
||
return entities | ||
|
||
|
||
def main(args): | ||
query_image_location = args.query | ||
template_location = args.template | ||
roi_json_file = args.roi | ||
|
||
# roi_data = "./Users/arindamkulshi/IDWA/IDWA/OCR/tests/assets/checkbox1.json" | ||
|
||
morph_extractor = MorphDataExtractor(roi_json_file) | ||
|
||
# lifecell_template_path = "./Users/arindamkulshi/IDWA/IDWA/OCR/tests/assets/checkbox_template.png" | ||
|
||
sift_aligner = SIFTAligner(template_location) | ||
|
||
entities = extract_data(im=query_image_location, aligner=sift_aligner, extractor=morph_extractor) | ||
|
||
print(entities) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser("Parsing arguments for information extraction") | ||
|
||
parser.add_argument("--query", type=str, help="Image from where the data needs to be extracted") | ||
|
||
parser.add_argument("--template", type=str, help="The template document which is used for alignment") | ||
|
||
parser.add_argument("--roi", type=str, help="json file that contains the roi data") | ||
|
||
args = parser.parse_args() | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import cv2 | ||
from abc import ABC, abstractmethod | ||
import json | ||
import numpy as np | ||
|
||
AREA_THRESHOLD = 40 | ||
|
||
|
||
class TRFDataExtractor(ABC): | ||
def __init__(self, roi_data) -> None: | ||
super().__init__() | ||
with open(roi_data, "rb") as roi_data_file: | ||
roi = json.load(roi_data_file) | ||
shapes = [] | ||
regions = [] | ||
for item in roi["data"]["regions"]: | ||
shapes.append(item["shape_attributes"]) | ||
regions.append(item["region_attributes"]) | ||
|
||
self._roi = list(zip(shapes, regions)) | ||
|
||
def _get_edges(self, im: cv2.Mat) -> cv2.Mat: | ||
gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) | ||
blur = cv2.GaussianBlur(gray, (5, 5), 0) | ||
_, thresh = cv2.threshold(blur, 127, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) | ||
|
||
edges = cv2.Canny(blur, 220, 250) | ||
|
||
return edges | ||
|
||
@abstractmethod | ||
def extract(self, aligned_image: cv2.Mat) -> list: | ||
pass | ||
|
||
|
||
class BoxDataExtractor(TRFDataExtractor): | ||
def __init__(self, roi_data, non_black_count_threshold=90) -> None: | ||
super().__init__(roi_data) | ||
self._non_black_count_threshold = non_black_count_threshold | ||
|
||
def extract(self, aligned_image: cv2.Mat) -> list: | ||
entity_list = [] | ||
|
||
edges = self._get_edges(aligned_image) | ||
|
||
for shape, region in self._roi: | ||
x = shape["x"] | ||
y = shape["y"] | ||
w = shape["width"] | ||
h = shape["height"] | ||
|
||
im = edges[y : y + h, x : x + w] | ||
non_black_count = cv2.countNonZero(im) | ||
# print( | ||
# f'Region: {region["type"]}, Non Black Count: {non_black_count}') | ||
if non_black_count > self._non_black_count_threshold: | ||
entity_list.append(region["type"]) | ||
|
||
return entity_list | ||
|
||
|
||
class MorphDataExtractor(TRFDataExtractor): | ||
def __init__(self, roi_data, non_black_count_threshold=0) -> None: | ||
super().__init__(roi_data) | ||
self._non_black_count_threshold = non_black_count_threshold | ||
|
||
def _clean_image(self, im: cv2.Mat) -> cv2.Mat: | ||
cnts, _ = cv2.findContours(im, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] | ||
|
||
for c in cnts: | ||
area = cv2.contourArea(c) | ||
if area < AREA_THRESHOLD: | ||
cv2.drawContours(im, [c], -1, 0, -1) | ||
|
||
return im | ||
|
||
def extract(self, aligned_image: cv2.Mat) -> list: | ||
entity_list = [] | ||
gray = cv2.cvtColor(aligned_image, cv2.COLOR_BGR2GRAY) | ||
|
||
for shape, region in self._roi: | ||
x = shape["x"] | ||
y = shape["y"] | ||
w = shape["width"] | ||
h = shape["height"] | ||
|
||
im = gray[y : y + h, x : x + w] | ||
|
||
thresh = cv2.adaptiveThreshold(im, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 10) | ||
thresh = self._clean_image(thresh) | ||
|
||
repair_kernel_horizontal = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 1)) | ||
repair_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, repair_kernel_horizontal, iterations=2) | ||
|
||
horizontal_lines = cv2.HoughLinesP(repair_horizontal, 1, np.pi / 180, 20, 10, 10) | ||
if horizontal_lines is not None: | ||
for line in horizontal_lines: | ||
for x1, y1, x2, y2 in line: | ||
theta = np.arctan((y2 - y1) / (x2 - x1 + 1e-8)) * 180 / np.pi | ||
if theta > -6 and theta < 6: | ||
cv2.line(thresh, (x1, y1), (x2, y2), (0, 0, 0), 2) | ||
|
||
repair_kernel_vertical = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 4)) | ||
repair_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, repair_kernel_vertical, iterations=1) | ||
|
||
vertical_lines = cv2.HoughLinesP(repair_vertical, 1, np.pi / 180, 12, 40, 4) | ||
if vertical_lines is not None: | ||
for line in vertical_lines: | ||
for x1, y1, x2, y2 in line: | ||
theta = np.abs(np.arctan((y2 - y1) / (x2 - x1 + 1e-8)) * 180 / np.pi) | ||
if theta > 84 and theta < 96: | ||
cv2.line(thresh, (x1, y1), (x2, y2), (0, 0, 0), 2) | ||
|
||
thresh = self._clean_image(thresh) | ||
|
||
non_black_count = cv2.countNonZero(thresh) | ||
print(f"ROI: {region['type']}, Non-Black Count: {non_black_count}") | ||
|
||
if non_black_count > self._non_black_count_threshold: | ||
entity_list.append(region["type"]) | ||
|
||
return entity_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from abc import ABC, abstractmethod | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
class ImageAligner(ABC): | ||
def __init__(self, template: str) -> None: | ||
super().__init__() | ||
template_im: cv2.Mat = cv2.imread(template) | ||
self._template: cv2.Mat = cv2.cvtColor(template_im, cv2.COLOR_BGR2GRAY) | ||
|
||
@abstractmethod | ||
def align(self, query_image: cv2.Mat) -> cv2.Mat: | ||
pass | ||
|
||
|
||
class SIFTAligner(ImageAligner): | ||
def __init__(self, template: str) -> None: | ||
super().__init__(template) | ||
self._h, self._w = self._template.shape | ||
self._sift = cv2.SIFT_create() | ||
self._kps_template, self._desc_template = self._sift.detectAndCompute(self._template, None) | ||
self._matcher = cv2.BFMatcher() | ||
|
||
# def align(self, query_image: str) -> cv2.Mat: | ||
# query = cv2.imread(query_image) | ||
# self.align(query) | ||
|
||
def align(self, query_image: cv2.Mat) -> cv2.Mat: | ||
# query_image = cv2.imread(query_image_path) | ||
query = cv2.cvtColor(query_image, cv2.COLOR_BGR2GRAY) | ||
kps_query, desc_query = self._sift.detectAndCompute(query, None) | ||
|
||
matches = self._matcher.knnMatch(desc_query, self._desc_template, k=2) | ||
# Apply ratio test | ||
good = [] | ||
for m, n in matches: | ||
if m.distance < 0.75 * n.distance: | ||
good.append([m]) | ||
|
||
pts_template = np.zeros((len(good), 2), dtype="float") | ||
pts_query = np.zeros((len(good), 2), dtype="float") | ||
|
||
for i, m in enumerate(good): | ||
try: | ||
pts_query[i, :] = kps_query[m[0].queryIdx].pt | ||
pts_template[i, :] = self._kps_template[m[0].trainIdx].pt | ||
except IndexError: | ||
print(f"Index Error at {i}") | ||
|
||
H, mask = cv2.findHomography(pts_query, pts_template, method=cv2.RANSAC) | ||
aligned = cv2.warpPerspective(query_image, H, (self._w, self._h)) | ||
|
||
return aligned | ||
|
||
|
||
class ORBAligner(ImageAligner): | ||
def __init__(self, template: str, max_features=3000, keep_percent=0.2) -> None: | ||
super().__init__(template) | ||
self._h, self._w = self._template.shape | ||
self._orb = cv2.ORB_create(max_features) | ||
|
||
self._kps_template, self._desc_template = self._orb.detectAndCompute(self._template, None) | ||
self._matcher = cv2.BFMatcher(cv2.NORM_HAMMING) | ||
self._keep_percent = keep_percent | ||
|
||
def align(self, query_image: cv2.Mat) -> cv2.Mat: | ||
query = cv2.cvtColor(query_image, cv2.COLOR_BGR2GRAY) | ||
kps_query, desc_query = self._orb.detectAndCompute(query, None) | ||
|
||
matches = self._matcher.match(desc_query, self._desc_template) | ||
|
||
matches = sorted(matches, key=lambda x: x.distance, reverse=False) | ||
|
||
keep = int(len(matches) * self._keep_percent) | ||
matches = matches[:keep] | ||
|
||
pts_template = np.zeros((len(matches), 2), dtype="float") | ||
pts_query = np.zeros((len(matches), 2), dtype="float") | ||
|
||
for i, m in enumerate(matches): | ||
pts_query[i, :] = kps_query[m.queryIdx].pt | ||
pts_template[i, :] = self._kps_template[m.trainIdx].pt | ||
|
||
H, mask = cv2.findHomography(pts_query, pts_template, method=cv2.RANSAC) | ||
aligned_orb = cv2.warpPerspective(query_image, H, (self._w, self._h)) | ||
|
||
return aligned_orb |
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.