diff --git a/sleap_roots/lengths.py b/sleap_roots/lengths.py index a7938b6..e4af92d 100644 --- a/sleap_roots/lengths.py +++ b/sleap_roots/lengths.py @@ -125,7 +125,7 @@ def get_curve_index( & (~np.isnan(base_tip_dists)) & (lengths > 0) & (lengths >= base_tip_dists), - (lengths - base_tip_dists) / lengths, + (lengths - base_tip_dists) / np.where(lengths != 0, lengths, np.nan), np.nan, ) diff --git a/tests/test_lengths.py b/tests/test_lengths.py index 263f805..004fd61 100644 --- a/tests/test_lengths.py +++ b/tests/test_lengths.py @@ -288,7 +288,6 @@ def test_invalid_scalar_values(): assert np.isnan(get_curve_index(0, 8)) -# tests for `get_root_lengths` def test_curve_index_float(): assert get_curve_index(10.0, 5.0) == 0.5