diff --git a/sleap_roots/networklength.py b/sleap_roots/networklength.py index c089ec3..90c0688 100644 --- a/sleap_roots/networklength.py +++ b/sleap_roots/networklength.py @@ -2,8 +2,8 @@ import numpy as np from shapely import LineString, Polygon -from sleap_roots.lengths import get_root_lengths, get_max_length_pts -from typing import Optional, Tuple, Union +from sleap_roots.lengths import get_max_length_pts +from typing import Tuple, Union def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]: @@ -198,10 +198,11 @@ def get_network_distribution( # Calculate length of roots within the lower bounding box network_length = 0 for root in all_roots: - root_poly = LineString(root) - lower_intersection = root_poly.intersection(lower_box) - root_length = lower_intersection.length - network_length += root_length if ~np.isnan(root_length) else 0 + if len(root) > 1: # Ensure that root has more than one point + root_poly = LineString(root) + lower_intersection = root_poly.intersection(lower_box) + root_length = lower_intersection.length + network_length += root_length if ~np.isnan(root_length) else 0 return network_length diff --git a/tests/test_networklength.py b/tests/test_networklength.py index 3126ceb..f244cec 100644 --- a/tests/test_networklength.py +++ b/tests/test_networklength.py @@ -1,5 +1,6 @@ import pytest import numpy as np +from shapely import LineString, Polygon from sleap_roots import Series from sleap_roots.convhull import get_chull_area, get_convhull from sleap_roots.lengths import get_max_length_pts, get_root_lengths @@ -169,6 +170,103 @@ def test_get_network_solidity_rice(rice_h5): np.testing.assert_almost_equal(ratio, 0.03366254601775008, decimal=7) +def test_get_network_distribution_one_point(): + # Define inputs + primary_pts = np.array([[[1, 1], [2, 2], [3, 3]]]) + lateral_pts = np.array( + [[[4, 4], [5, 5]], [[6, 6], [np.nan, np.nan]]] + ) # One of the roots has only one point + bounding_box = (0, 0, 10, 10) + fraction = 2 / 3 + monocots = False + + # Call the function + network_length = get_network_distribution( + primary_pts, lateral_pts, bounding_box, fraction, monocots + ) + + # Define the expected result + # Only the valid roots should be considered in the calculation + lower_box = Polygon( + [(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))] + ) + expected_length = ( + LineString(primary_pts[0]).intersection(lower_box).length + + LineString(lateral_pts[0]).intersection(lower_box).length + ) + + # Assert that the result is as expected + assert network_length == pytest.approx(expected_length) + + +def test_get_network_distribution_empty_arrays(): + primary_pts = np.full((2, 2), np.nan) + lateral_pts = np.full((2, 2, 2), np.nan) + bounding_box = (0, 0, 10, 10) + + network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box) + assert network_length == 0 + + +def test_get_network_distribution_with_nans(): + primary_pts = np.array([[1, 1], [2, 2], [np.nan, np.nan]]) + lateral_pts = np.array([[[4, 4], [5, 5], [np.nan, np.nan]]]) + bounding_box = (0, 0, 10, 10) + + network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box) + + lower_box = Polygon( + [(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))] + ) + expected_length = ( + LineString(primary_pts[:-1]).intersection(lower_box).length + + LineString(lateral_pts[0, :-1]).intersection(lower_box).length + ) + + assert network_length == pytest.approx(expected_length) + + +def test_get_network_distribution_monocots(): + primary_pts = np.array([[1, 1], [2, 2], [3, 3]]) + lateral_pts = np.array([[[4, 4], [5, 5]]]) + bounding_box = (0, 0, 10, 10) + monocots = True + + network_length = get_network_distribution( + primary_pts, lateral_pts, bounding_box, monocots=monocots + ) + + lower_box = Polygon( + [(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))] + ) + expected_length = ( + LineString(lateral_pts[0]).intersection(lower_box).length + ) # Only lateral_pts are considered + + assert network_length == pytest.approx(expected_length) + + +def test_get_network_distribution_different_fraction(): + primary_pts = np.array([[1, 1], [2, 2], [3, 3]]) + lateral_pts = np.array([[[4, 4], [5, 5]]]) + bounding_box = (0, 0, 10, 10) + fraction = 0.5 + + network_length = get_network_distribution( + primary_pts, lateral_pts, bounding_box, fraction=fraction + ) + + lower_box = Polygon( + [(0, 10 - 10 * fraction), (0, 10), (10, 10), (10, 10 - 10 * fraction)] + ) + expected_length = ( + LineString(primary_pts).intersection(lower_box).length + + LineString(lateral_pts[0]).intersection(lower_box).length + ) + + assert network_length == pytest.approx(expected_length) + + def test_get_network_distribution(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"