diff --git a/rows/plugins/plugin_ocr.py b/rows/plugins/plugin_ocr.py index 89501427..b60a8892 100644 --- a/rows/plugins/plugin_ocr.py +++ b/rows/plugins/plugin_ocr.py @@ -21,10 +21,56 @@ from pytesseract import image_to_boxes from PIL import Image -from rows.plugins.plugin_pdf import PDFBackend, TextObject, pdf_table_lines +from rows.plugins.plugin_pdf import group_objects, PDFBackend, TextObject, pdf_table_lines from rows.plugins.utils import create_table +def join_text_group(group): + """Join a list of `TextObject`s into one""" + + obj = group[0] + max_between = (obj.x1 - obj.x0) / len(obj.text) # Average letter size + text, last_x1 = [], obj.x0 + for obj in group: + if last_x1 + max_between <= obj.x0: + text.append(" ") + text.append(obj.text) + last_x1 = obj.x1 + text = "".join(text) + + return TextObject( + x0=min(obj.x0 for obj in group), + y0=min(obj.y0 for obj in group), + x1=max(obj.x1 for obj in group), + y1=max(obj.y1 for obj in group), + text=text + ) + + +def group_contiguous_objects(objs, x_threshold, y_threshold): + """Merge contiguous objects if they're closer enough""" + + objs.sort(key=lambda obj: obj.y0) + y_groups = group_objects(objs, y_threshold, "y") + for y_group, y_items in y_groups.items(): + y_items.sort(key=lambda obj: obj.x0) + + x_groups, current_group, last_x1 = [], [], None + for obj in y_items: + if not current_group or last_x1 + x_threshold >= obj.x0: + current_group.append(obj) + elif current_group: + x_groups.append(current_group) + current_group = [obj] + last_x1 = obj.x1 + if current_group: + x_groups.append(current_group) + + for group in x_groups: + if group: + yield join_text_group(group) + + class TesseractBackend(PDFBackend): name = "tesseract" @@ -36,12 +82,7 @@ def __init__(self, filename_or_fobj, language): @cached_property def document(self): - if hasattr(self.filename_or_fobj, "read"): - image = Image.open(self.filename_or_fobj) - else: - image = self.filename_or_fobj - - return image + return Image.open(self.filename_or_fobj) @cached_property def number_of_pages(self): @@ -51,6 +92,7 @@ def extract_text(self, page_numbers=None): return "" # TODO: image_to_string def objects(self, page_numbers=None, starts_after=None, ends_before=None): + _, total_y = self.document.size header = "char left bottom right top page".split() boxes = image_to_boxes(self.document, lang=self.language).splitlines() text_objs = [] @@ -60,18 +102,18 @@ def objects(self, page_numbers=None, starts_after=None, ends_before=None): if key != "char": value = int(value) row[key] = value - obj = TextObject( - x0=row["left"], - y0=row["bottom"], - x1=row["right"], - y1=row["top"], - text=row["char"], + text_objs.append( + TextObject( + x0=row["left"], + y0=total_y - row["bottom"], + x1=row["right"], + y1=total_y - row["top"], + text=row["char"], + ) ) - text_objs.append(obj) - text_objs.sort(key=lambda obj: (obj.y0, obj.x0)) - # TODO: group contiguous objects before yielding - yield text_objs + # TODO: custom thresholds + yield list(group_contiguous_objects(text_objs, 30, 12)) text_objects = objects