diff --git a/pdfplumber/utils/geometry.py b/pdfplumber/utils/geometry.py index 407da602..0ac9b55b 100644 --- a/pdfplumber/utils/geometry.py +++ b/pdfplumber/utils/geometry.py @@ -1,52 +1,53 @@ import itertools from operator import itemgetter -from typing import Dict, List, Optional +from typing import Dict, Iterable, Optional from .._typing import T_bbox, T_num, T_obj, T_obj_list from .clustering import cluster_objects -from .generic import to_list -def objects_to_rect(objects: T_obj_list) -> Dict[str, T_num]: - return { - "x0": min(map(itemgetter("x0"), objects)), - "x1": max(map(itemgetter("x1"), objects)), - "top": min(map(itemgetter("top"), objects)), - "bottom": max(map(itemgetter("bottom"), objects)), - } +def objects_to_rect(objects: Iterable[T_obj]) -> Dict[str, T_num]: + """ + Given an iterable of objects, return the smallest rectangle (i.e. a + dict with "x0", "top", "x1", and "bottom" keys) that contains them + all. + """ + return bbox_to_rect(objects_to_bbox(objects)) -def objects_to_bbox(objects: T_obj_list) -> T_bbox: - return ( - min(map(itemgetter("x0"), objects)), - min(map(itemgetter("top"), objects)), - max(map(itemgetter("x1"), objects)), - max(map(itemgetter("bottom"), objects)), - ) +def objects_to_bbox(objects: Iterable[T_obj]) -> T_bbox: + """ + Given an iterable of objects, return the smallest bounding box that + contains them all. + """ + return merge_bboxes(map(bbox_getter, objects)) bbox_getter = itemgetter("x0", "top", "x1", "bottom") def obj_to_bbox(obj: T_obj) -> T_bbox: + """ + Return the bounding box for an object. + """ return bbox_getter(obj) def bbox_to_rect(bbox: T_bbox) -> Dict[str, T_num]: + """ + Return the rectangle (i.e. a dict with keys "x0", "top", "x1", + "bottom") for a bounding box.
+ """ return {"x0": bbox[0], "top": bbox[1], "x1": bbox[2], "bottom": bbox[3]} -def merge_bboxes(bboxes: List[T_bbox]) -> T_bbox: +def merge_bboxes(bboxes: Iterable[T_bbox]) -> T_bbox: """ - Given a set of bounding boxes, return the smallest bounding box that - contains them all. + Given an iterable of bounding boxes, return the smallest bounding box + that contains them all. """ - return ( - min(map(itemgetter(0), bboxes)), - min(map(itemgetter(1), bboxes)), - max(map(itemgetter(2), bboxes)), - max(map(itemgetter(3), bboxes)), - ) + x0, top, x1, bottom = zip(*bboxes) + return (min(x0), min(top), max(x1), max(bottom)) def get_bbox_overlap(a: T_bbox, b: T_bbox) -> Optional[T_bbox]: @@ -72,7 +73,6 @@ def calculate_area(bbox: T_bbox) -> T_num: def clip_obj(obj: T_obj, bbox: T_bbox) -> Optional[T_obj]: - overlap = get_bbox_overlap(obj_to_bbox(obj), bbox) if overlap is None: return None @@ -91,19 +91,14 @@ def clip_obj(obj: T_obj, bbox: T_bbox) -> Optional[T_obj]: return copy -def intersects_bbox(objs: T_obj_list, bbox: T_bbox) -> T_obj_list: +def intersects_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list: """ Filters objs to only those intersecting the bbox """ - initial_type = type(objs) - objs = to_list(objs) - matching = [ - obj for obj in objs if get_bbox_overlap(obj_to_bbox(obj), bbox) is not None - ] - return initial_type(matching) + return [obj for obj in objs if get_bbox_overlap(obj_to_bbox(obj), bbox) is not None] -def within_bbox(objs: T_obj_list, bbox: T_bbox) -> T_obj_list: +def within_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list: """ Filters objs to only those fully within the bbox """ @@ -114,14 +109,14 @@ def within_bbox(objs: T_obj_list, bbox: T_bbox) -> T_obj_list: ] -def outside_bbox(objs: T_obj_list, bbox: T_bbox) -> T_obj_list: +def outside_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list: """ Filters objs to only those fully outside the bbox """ return [obj for obj in objs if get_bbox_overlap(obj_to_bbox(obj), bbox) is 
None] -def crop_to_bbox(objs: T_obj_list, bbox: T_bbox) -> T_obj_list: +def crop_to_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list: """ Filters objs to only those intersecting the bbox, and crops the extent of the objects to the bbox. @@ -151,10 +146,11 @@ def move_object(obj: T_obj, axis: str, value: T_num) -> T_obj: return obj.__class__(tuple(obj.items()) + tuple(new_items)) -def snap_objects(objs: T_obj_list, attr: str, tolerance: T_num) -> T_obj_list: +def snap_objects(objs: Iterable[T_obj], attr: str, tolerance: T_num) -> T_obj_list: axis = {"x0": "h", "x1": "h", "top": "v", "bottom": "v"}[attr] - clusters = cluster_objects(objs, itemgetter(attr), tolerance) - avgs = [sum(map(itemgetter(attr), objs)) / len(objs) for objs in clusters] + list_objs = list(objs) + clusters = cluster_objects(list_objs, itemgetter(attr), tolerance) + avgs = [sum(map(itemgetter(attr), cluster)) / len(cluster) for cluster in clusters] snapped_clusters = [ [move_object(obj, axis, avg - obj[attr]) for obj in cluster] for cluster, avg in zip(clusters, avgs) @@ -264,12 +260,11 @@ def obj_to_edges(obj: T_obj) -> T_obj_list: def filter_edges( - edges: T_obj_list, + edges: Iterable[T_obj], orientation: Optional[str] = None, edge_type: Optional[str] = None, min_length: T_num = 1, ) -> T_obj_list: - if orientation not in ("v", "h", None): raise ValueError("Orientation must be 'v' or 'h'") diff --git a/requirements-dev.txt b/requirements-dev.txt index ff8ab4d4..a2230c3c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,7 +5,7 @@ flake8==4.0.1 black==22.3.0 isort==5.10.1 pandas==2.0.3 -mypy==0.942 +mypy==0.981 pandas-stubs==1.2.0.58 types-Pillow==9.0.14 jupyterlab==3.4.2 diff --git a/tests/test_utils.py b/tests/test_utils.py index c0e7784a..cbd65a4b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -385,6 +385,7 @@ def test_intersects_bbox(self): bbox = utils.obj_to_bbox(objs[0]) assert utils.intersects_bbox(objs, bbox) == objs[:4] + assert 
utils.intersects_bbox(iter(objs), bbox) == objs[:4] def test_merge_bboxes(self): bboxes = [ @@ -393,6 +394,8 @@ def test_merge_bboxes(self): ] merged = utils.merge_bboxes(bboxes) assert merged == (0, 5, 20, 30) + merged = utils.merge_bboxes(iter(bboxes)) + assert merged == (0, 5, 20, 30) def test_resize_object(self): obj = { @@ -494,6 +497,8 @@ def test_snap_objects(self): a_new, b_new, c_new = utils.snap_objects([a, b, c], "x0", 1) assert a_new == b_new == c_new + a_new, b_new, c_new = utils.snap_objects(iter([a, b, c]), "x0", 1) + assert a_new == b_new == c_new def test_filter_edges(self): with pytest.raises(ValueError): @@ -515,6 +520,7 @@ def test_to_list(self): }, ] assert utils.to_list(objs) == objs + assert utils.to_list(iter(objs)) == objs assert utils.to_list(tuple(objs)) == objs assert utils.to_list((o for o in objs)) == objs assert utils.to_list(pd.DataFrame(objs)) == objs