From 35272e724ef1c48304de0e400478585d816dfeb1 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Thu, 18 Apr 2024 16:24:16 -0700 Subject: [PATCH 1/4] add helper function for getting points from geometry --- sleap_roots/convhull.py | 19 ++++++++++------- sleap_roots/points.py | 43 ++++++++++++++++++++++++++++++++++++- tests/test_convhull.py | 1 - tests/test_points.py | 47 +++++++++++++++++++++++++++++++++++++++-- 4 files changed, 98 insertions(+), 12 deletions(-) diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py index c15a169..12b3df2 100644 --- a/sleap_roots/convhull.py +++ b/sleap_roots/convhull.py @@ -4,7 +4,7 @@ from scipy.spatial import ConvexHull from scipy.spatial.distance import pdist from typing import Tuple, Optional, Union -from sleap_roots.points import get_line_equation_from_points +from sleap_roots.points import extract_points_from_geometry, get_line_equation_from_points from shapely import box, LineString, normalize, Polygon @@ -382,13 +382,9 @@ def get_chull_areas_via_intersection( # Find the intersection between the hull perimeter and the extended line intersection = extended_line.intersection(hull_perimeter) - # Add intersection points to both lists + # Compute the intersection points and add to lists if not intersection.is_empty: - intersect_points = ( - np.array([[point.x, point.y] for point in intersection.geoms]) - if intersection.geom_type == "MultiPoint" - else np.array([[intersection.x, intersection.y]]) - ) + intersect_points = extract_points_from_geometry(intersection) above_line.extend(intersect_points) below_line.extend(intersect_points) @@ -452,6 +448,10 @@ def get_chull_intersection_vectors( Raises: ValueError: If pts does not have the expected shape. """ + if r0_pts.ndim == 1 or rn_pts.ndim == 1 or pts.ndim == 2: + print("Not enough instances or incorrect format to compute convex hull intersections.") + return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) + # Check for valid pts input if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2: raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).") @@ -460,7 +460,7 @@ def get_chull_intersection_vectors( raise ValueError("rn_pts must be a numpy array of shape (instances, 2).") # Ensure r0_pts is a numpy array of shape (instances, 2) if not isinstance(r0_pts, np.ndarray) or r0_pts.ndim != 2 or r0_pts.shape[-1] != 2: - raise ValueError("r0_pts must be a numpy array of shape (instances, 2).") + raise ValueError(f"r0_pts must be a numpy array of shape (instances, 2).") # Flatten pts to 2D array and remove NaN values flattened_pts = pts.reshape(-1, 2) @@ -481,6 +481,9 @@ def get_chull_intersection_vectors( # Ensuring r0_pts does not contain NaN values r0_pts_valid = r0_pts[~np.isnan(r0_pts).any(axis=1)] + # Expect two vectors in the end + if len(r0_pts_valid) < 2: + return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) # Get the vertices of the convex hull hull_vertices = hull.points[hull.vertices] diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 479f564..22bdfcf 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -3,11 +3,52 @@ import numpy as np from matplotlib import pyplot as plt from matplotlib.lines import Line2D -from shapely.geometry import LineString +from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection from shapely.ops import nearest_points from typing import List, Optional, Tuple +def extract_points_from_geometry(geometry): + """Extracts coordinates as a list of numpy arrays from any given Shapely geometry object. + + This function supports Point, MultiPoint, LineString, and GeometryCollection types. + It recursively extracts coordinates from complex geometries and aggregates them into a single list. + For unsupported geometry types, it returns an empty list. + + Parameters: + - geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points. + + Returns: + - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. + The list will be empty if the geometry type is unsupported or contains no coordinates. + + Raises: + - TypeError: If the input is not a recognized Shapely geometry type. + + Example: + >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection + >>> point = Point(1, 2) + >>> multipoint = MultiPoint([(1, 2), (3, 4)]) + >>> linestring = LineString([(0, 0), (1, 1), (2, 2)]) + >>> geom_col = GeometryCollection([point, multipoint, linestring]) + >>> extract_points_from_geometry(geom_col) + [array([1, 2]), array([1, 2]), array([3, 4]), array([0, 0]), array([1, 1]), array([2, 2])] + """ + if isinstance(geometry, Point): + return [np.array([geometry.x, geometry.y])] + elif isinstance(geometry, MultiPoint): + return [np.array([point.x, point.y]) for point in geometry.geoms] + elif isinstance(geometry, LineString): + return [np.array([x, y]) for x, y in zip(*geometry.xy)] + elif isinstance(geometry, GeometryCollection): + points = [] + for geom in geometry.geoms: + points.extend(extract_points_from_geometry(geom)) + return points + else: + raise TypeError(f"Unsupported geometry type: {type(geometry).__name__}") + + def get_count(pts: np.ndarray): """Get number of roots. diff --git a/tests/test_convhull.py b/tests/test_convhull.py index a33c279..f506312 100644 --- a/tests/test_convhull.py +++ b/tests/test_convhull.py @@ -314,7 +314,6 @@ def test_basic_functionality(pts_shape_3_6_2): @pytest.mark.parametrize( "invalid_input", [ - (np.array([1, 2]), np.array([3, 4]), np.array([[[1, 2], [3, 4]]]), None), (np.array([[1, 2, 3]]), np.array([[3, 4]]), np.array([[[1, 2], [3, 4]]]), None), # Add more invalid inputs as needed ], diff --git a/tests/test_points.py b/tests/test_points.py index 6ac3a1d..ed042a8 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -1,9 +1,9 @@ import numpy as np import pytest -from shapely.geometry import LineString +from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection from sleap_roots import Series from sleap_roots.lengths import get_max_length_pts -from sleap_roots.points import filter_plants_with_unexpected_ct, get_count, join_pts +from sleap_roots.points import extract_points_from_geometry, filter_plants_with_unexpected_ct, get_count, join_pts from sleap_roots.points import ( get_all_pts_array, get_nodes, @@ -738,3 +738,46 @@ def test_filter_plants_with_unexpected_ct_incorrect_input_types(): expected_count = "not a float" with pytest.raises(ValueError): filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count) + + +def test_extract_from_point(): + point = Point(1, 2) + expected = [np.array([1, 2])] + assert np.array_equal(extract_points_from_geometry(point), expected) + +def test_extract_from_multipoint(): + multipoint = MultiPoint([(1, 2), (3, 4)]) + expected = [np.array([1, 2]), np.array([3, 4])] + results = extract_points_from_geometry(multipoint) + assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + +def test_extract_from_linestring(): + linestring = LineString([(0, 0), (1, 1), (2, 2)]) + expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])] + results = extract_points_from_geometry(linestring) + assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + +def test_extract_from_geometrycollection(): + geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])]) + expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])] + results = extract_points_from_geometry(geom_collection) + assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + +def test_extract_from_empty_multipoint(): + empty_multipoint = MultiPoint() + expected = [] + assert extract_points_from_geometry(empty_multipoint) == expected + +def test_extract_from_empty_linestring(): + empty_linestring = LineString() + expected = [] + assert extract_points_from_geometry(empty_linestring) == expected + +def test_extract_from_unsupported_type(): + with pytest.raises(NameError): + extract_points_from_geometry(Polygon([(0, 0), (1, 1), (1, 0)])) # Polygon is unsupported + +def test_extract_from_empty_geometrycollection(): + empty_geom_collection = GeometryCollection() + expected = [] + assert extract_points_from_geometry(empty_geom_collection) == expected \ No newline at end of file From 93741fcecaf0c2c6dd478b10dde0ed05ff6d8976 Mon Sep 17 00:00:00 2001 From: eberrigan Date: Tue, 23 Apr 2024 10:15:05 -0700 Subject: [PATCH 2/4] lint --- sleap_roots/convhull.py | 9 +++++++-- sleap_roots/points.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py index 12b3df2..88ad97c 100644 --- a/sleap_roots/convhull.py +++ b/sleap_roots/convhull.py @@ -4,7 +4,10 @@ from scipy.spatial import ConvexHull from scipy.spatial.distance import pdist from typing import Tuple, Optional, Union -from sleap_roots.points import extract_points_from_geometry, get_line_equation_from_points +from sleap_roots.points import ( + extract_points_from_geometry, + get_line_equation_from_points, +) from shapely import box, LineString, normalize, Polygon @@ -449,7 +452,9 @@ def get_chull_intersection_vectors( ValueError: If pts does not have the expected shape. """ if r0_pts.ndim == 1 or rn_pts.ndim == 1 or pts.ndim == 2: - print("Not enough instances or incorrect format to compute convex hull intersections.") + print( + "Not enough instances or incorrect format to compute convex hull intersections." + ) return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) # Check for valid pts input diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 22bdfcf..fd9acbc 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -10,21 +10,21 @@ def extract_points_from_geometry(geometry): """Extracts coordinates as a list of numpy arrays from any given Shapely geometry object. - - This function supports Point, MultiPoint, LineString, and GeometryCollection types. - It recursively extracts coordinates from complex geometries and aggregates them into a single list. + + This function supports Point, MultiPoint, LineString, and GeometryCollection types. + It recursively extracts coordinates from complex geometries and aggregates them into a single list. For unsupported geometry types, it returns an empty list. - + Parameters: - geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points. - + Returns: - - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. + - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. The list will be empty if the geometry type is unsupported or contains no coordinates. - + Raises: - TypeError: If the input is not a recognized Shapely geometry type. - + Example: >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection >>> point = Point(1, 2) From 5033de5979837741bc76f816ae99d04c13bdfae5 Mon Sep 17 00:00:00 2001 From: eberrigan Date: Tue, 23 Apr 2024 10:52:16 -0700 Subject: [PATCH 3/4] Black --- tests/test_points.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/test_points.py b/tests/test_points.py index ed042a8..54c37d6 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -3,7 +3,12 @@ from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection from sleap_roots import Series from sleap_roots.lengths import get_max_length_pts -from sleap_roots.points import extract_points_from_geometry, filter_plants_with_unexpected_ct, get_count, join_pts +from sleap_roots.points import ( + extract_points_from_geometry, + filter_plants_with_unexpected_ct, + get_count, + join_pts, +) from sleap_roots.points import ( get_all_pts_array, get_nodes, @@ -745,39 +750,48 @@ def test_extract_from_point(): expected = [np.array([1, 2])] assert np.array_equal(extract_points_from_geometry(point), expected) + def test_extract_from_multipoint(): multipoint = MultiPoint([(1, 2), (3, 4)]) expected = [np.array([1, 2]), np.array([3, 4])] results = extract_points_from_geometry(multipoint) assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + def test_extract_from_linestring(): linestring = LineString([(0, 0), (1, 1), (2, 2)]) expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])] results = extract_points_from_geometry(linestring) assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + def test_extract_from_geometrycollection(): geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])]) expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])] results = extract_points_from_geometry(geom_collection) assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + def test_extract_from_empty_multipoint(): empty_multipoint = MultiPoint() expected = [] assert extract_points_from_geometry(empty_multipoint) == expected + def test_extract_from_empty_linestring(): empty_linestring = LineString() expected = [] assert extract_points_from_geometry(empty_linestring) == expected + def test_extract_from_unsupported_type(): with pytest.raises(NameError): - extract_points_from_geometry(Polygon([(0, 0), (1, 1), (1, 0)])) # Polygon is unsupported + extract_points_from_geometry( + Polygon([(0, 0), (1, 1), (1, 0)]) + ) # Polygon is unsupported + def test_extract_from_empty_geometrycollection(): empty_geom_collection = GeometryCollection() expected = [] - assert extract_points_from_geometry(empty_geom_collection) == expected \ No newline at end of file + assert extract_points_from_geometry(empty_geom_collection) == expected From 5a51eab3f530fd242eeb5d18a124dd9457b8f411 Mon Sep 17 00:00:00 2001 From: eberrigan Date: Tue, 23 Apr 2024 10:58:53 -0700 Subject: [PATCH 4/4] pydoc style --- sleap_roots/points.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sleap_roots/points.py b/sleap_roots/points.py index fd9acbc..6d5c5c1 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -15,15 +15,12 @@ def extract_points_from_geometry(geometry): It recursively extracts coordinates from complex geometries and aggregates them into a single list. For unsupported geometry types, it returns an empty list. - Parameters: - - geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points. + Args: + geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points. Returns: - - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. - The list will be empty if the geometry type is unsupported or contains no coordinates. - - Raises: - - TypeError: If the input is not a recognized Shapely geometry type. + List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. + The list will be empty if the geometry type is unsupported or contains no coordinates. Example: >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection