Skip to content

Commit

Permalink
Fix new older monocot traits for edge cases (#77)
Browse files Browse the repository at this point in the history
* add helper function for getting points from geometry

* lint

* Black

* pydoc style
  • Loading branch information
eberrigan authored Apr 23, 2024
1 parent 035b153 commit deaf0d4
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 12 deletions.
24 changes: 16 additions & 8 deletions sleap_roots/convhull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from scipy.spatial import ConvexHull
from scipy.spatial.distance import pdist
from typing import Tuple, Optional, Union
from sleap_roots.points import get_line_equation_from_points
from sleap_roots.points import (
extract_points_from_geometry,
get_line_equation_from_points,
)
from shapely import box, LineString, normalize, Polygon


Expand Down Expand Up @@ -382,13 +385,9 @@ def get_chull_areas_via_intersection(
# Find the intersection between the hull perimeter and the extended line
intersection = extended_line.intersection(hull_perimeter)

# Add intersection points to both lists
# Compute the intersection points and add to lists
if not intersection.is_empty:
intersect_points = (
np.array([[point.x, point.y] for point in intersection.geoms])
if intersection.geom_type == "MultiPoint"
else np.array([[intersection.x, intersection.y]])
)
intersect_points = extract_points_from_geometry(intersection)
above_line.extend(intersect_points)
below_line.extend(intersect_points)

Expand Down Expand Up @@ -452,6 +451,12 @@ def get_chull_intersection_vectors(
Raises:
ValueError: If pts does not have the expected shape.
"""
if r0_pts.ndim == 1 or rn_pts.ndim == 1 or pts.ndim == 2:
print(
"Not enough instances or incorrect format to compute convex hull intersections."
)
return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]]))

# Check for valid pts input
if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2:
raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).")
Expand All @@ -460,7 +465,7 @@ def get_chull_intersection_vectors(
raise ValueError("rn_pts must be a numpy array of shape (instances, 2).")
# Ensure r0_pts is a numpy array of shape (instances, 2)
if not isinstance(r0_pts, np.ndarray) or r0_pts.ndim != 2 or r0_pts.shape[-1] != 2:
raise ValueError("r0_pts must be a numpy array of shape (instances, 2).")
raise ValueError(f"r0_pts must be a numpy array of shape (instances, 2).")

# Flatten pts to 2D array and remove NaN values
flattened_pts = pts.reshape(-1, 2)
Expand All @@ -481,6 +486,9 @@ def get_chull_intersection_vectors(

# Ensuring r0_pts does not contain NaN values
r0_pts_valid = r0_pts[~np.isnan(r0_pts).any(axis=1)]
# Expect two vectors in the end
if len(r0_pts_valid) < 2:
return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]]))

# Get the vertices of the convex hull
hull_vertices = hull.points[hull.vertices]
Expand Down
40 changes: 39 additions & 1 deletion sleap_roots/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,49 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from shapely.geometry import LineString
from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
from shapely.ops import nearest_points
from typing import List, Optional, Tuple


def extract_points_from_geometry(geometry):
"""Extracts coordinates as a list of numpy arrays from any given Shapely geometry object.
This function supports Point, MultiPoint, LineString, and GeometryCollection types.
It recursively extracts coordinates from complex geometries and aggregates them into a single list.
For unsupported geometry types, it returns an empty list.
Args:
geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points.
Returns:
List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point.
The list will be empty if the geometry type is unsupported or contains no coordinates.
Example:
>>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
>>> point = Point(1, 2)
>>> multipoint = MultiPoint([(1, 2), (3, 4)])
>>> linestring = LineString([(0, 0), (1, 1), (2, 2)])
>>> geom_col = GeometryCollection([point, multipoint, linestring])
>>> extract_points_from_geometry(geom_col)
[array([1, 2]), array([1, 2]), array([3, 4]), array([0, 0]), array([1, 1]), array([2, 2])]
"""
if isinstance(geometry, Point):
return [np.array([geometry.x, geometry.y])]
elif isinstance(geometry, MultiPoint):
return [np.array([point.x, point.y]) for point in geometry.geoms]
elif isinstance(geometry, LineString):
return [np.array([x, y]) for x, y in zip(*geometry.xy)]
elif isinstance(geometry, GeometryCollection):
points = []
for geom in geometry.geoms:
points.extend(extract_points_from_geometry(geom))
return points
else:
raise TypeError(f"Unsupported geometry type: {type(geometry).__name__}")


def get_count(pts: np.ndarray):
"""Get number of roots.
Expand Down
1 change: 0 additions & 1 deletion tests/test_convhull.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def test_basic_functionality(pts_shape_3_6_2):
@pytest.mark.parametrize(
"invalid_input",
[
(np.array([1, 2]), np.array([3, 4]), np.array([[[1, 2], [3, 4]]]), None),
(np.array([[1, 2, 3]]), np.array([[3, 4]]), np.array([[[1, 2], [3, 4]]]), None),
# Add more invalid inputs as needed
],
Expand Down
61 changes: 59 additions & 2 deletions tests/test_points.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import numpy as np
import pytest
from shapely.geometry import LineString
from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
from sleap_roots import Series
from sleap_roots.lengths import get_max_length_pts
from sleap_roots.points import filter_plants_with_unexpected_ct, get_count, join_pts
from sleap_roots.points import (
extract_points_from_geometry,
filter_plants_with_unexpected_ct,
get_count,
join_pts,
)
from sleap_roots.points import (
get_all_pts_array,
get_nodes,
Expand Down Expand Up @@ -738,3 +743,55 @@ def test_filter_plants_with_unexpected_ct_incorrect_input_types():
expected_count = "not a float"
with pytest.raises(ValueError):
filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count)


def test_extract_from_point():
point = Point(1, 2)
expected = [np.array([1, 2])]
assert np.array_equal(extract_points_from_geometry(point), expected)


def test_extract_from_multipoint():
multipoint = MultiPoint([(1, 2), (3, 4)])
expected = [np.array([1, 2]), np.array([3, 4])]
results = extract_points_from_geometry(multipoint)
assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))


def test_extract_from_linestring():
linestring = LineString([(0, 0), (1, 1), (2, 2)])
expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])]
results = extract_points_from_geometry(linestring)
assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))


def test_extract_from_geometrycollection():
geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])])
expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])]
results = extract_points_from_geometry(geom_collection)
assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))


def test_extract_from_empty_multipoint():
empty_multipoint = MultiPoint()
expected = []
assert extract_points_from_geometry(empty_multipoint) == expected


def test_extract_from_empty_linestring():
empty_linestring = LineString()
expected = []
assert extract_points_from_geometry(empty_linestring) == expected


def test_extract_from_unsupported_type():
with pytest.raises(NameError):
extract_points_from_geometry(
Polygon([(0, 0), (1, 1), (1, 0)])
) # Polygon is unsupported


def test_extract_from_empty_geometrycollection():
empty_geom_collection = GeometryCollection()
expected = []
assert extract_points_from_geometry(empty_geom_collection) == expected

0 comments on commit deaf0d4

Please sign in to comment.