Skip to content

Commit

Permalink
Fix get_network_distribution (#67)
Browse files Browse the repository at this point in the history
* Added check to `get_network_distribution` for root length > 1

* Added tests for `get_network_distribution`
  • Loading branch information
eberrigan authored Sep 22, 2023
1 parent d82c7d9 commit 705bae8
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
13 changes: 7 additions & 6 deletions sleap_roots/networklength.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import numpy as np
from shapely import LineString, Polygon
from sleap_roots.lengths import get_root_lengths, get_max_length_pts
from typing import Optional, Tuple, Union
from sleap_roots.lengths import get_max_length_pts
from typing import Tuple, Union


def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]:
Expand Down Expand Up @@ -198,10 +198,11 @@ def get_network_distribution(
# Calculate length of roots within the lower bounding box
network_length = 0
for root in all_roots:
root_poly = LineString(root)
lower_intersection = root_poly.intersection(lower_box)
root_length = lower_intersection.length
network_length += root_length if ~np.isnan(root_length) else 0
if len(root) > 1: # Ensure that root has more than one point
root_poly = LineString(root)
lower_intersection = root_poly.intersection(lower_box)
root_length = lower_intersection.length
network_length += root_length if ~np.isnan(root_length) else 0

return network_length

Expand Down
98 changes: 98 additions & 0 deletions tests/test_networklength.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
from shapely import LineString, Polygon
from sleap_roots import Series
from sleap_roots.convhull import get_chull_area, get_convhull
from sleap_roots.lengths import get_max_length_pts, get_root_lengths
Expand Down Expand Up @@ -169,6 +170,103 @@ def test_get_network_solidity_rice(rice_h5):
np.testing.assert_almost_equal(ratio, 0.03366254601775008, decimal=7)


def test_get_network_distribution_one_point():
# Define inputs
primary_pts = np.array([[[1, 1], [2, 2], [3, 3]]])
lateral_pts = np.array(
[[[4, 4], [5, 5]], [[6, 6], [np.nan, np.nan]]]
) # One of the roots has only one point
bounding_box = (0, 0, 10, 10)
fraction = 2 / 3
monocots = False

# Call the function
network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction, monocots
)

# Define the expected result
# Only the valid roots should be considered in the calculation
lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[0]).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

# Assert that the result is as expected
assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_empty_arrays():
primary_pts = np.full((2, 2), np.nan)
lateral_pts = np.full((2, 2, 2), np.nan)
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)
assert network_length == 0


def test_get_network_distribution_with_nans():
primary_pts = np.array([[1, 1], [2, 2], [np.nan, np.nan]])
lateral_pts = np.array([[[4, 4], [5, 5], [np.nan, np.nan]]])
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[:-1]).intersection(lower_box).length
+ LineString(lateral_pts[0, :-1]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_monocots():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
monocots = True

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, monocots=monocots
)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(lateral_pts[0]).intersection(lower_box).length
) # Only lateral_pts are considered

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_different_fraction():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
fraction = 0.5

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction=fraction
)

lower_box = Polygon(
[(0, 10 - 10 * fraction), (0, 10), (10, 10), (10, 10 - 10 * fraction)]
)
expected_length = (
LineString(primary_pts).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution(canola_h5):
series = Series.load(
canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
Expand Down

0 comments on commit 705bae8

Please sign in to comment.