Skip to content

Commit

Permalink
Implement Younger Monocot Pipeline (#64)
Browse files Browse the repository at this point in the history
* Start `YoungerMonocotPipeline`

* Add `main_grav_indices`

* Add younger monocot data

* Edit `YoungMonocotPipeline`

* Test `YoungerMonocotPipeline`

* Fix trait definitions

* Fix description of `main_grav_indices`

* Test gravitropism index

* Fixed test for `grav_index` for NaN values

* Add tests for `get_grav_index` for expected output shape
  • Loading branch information
eberrigan authored Sep 15, 2023
1 parent 0e7f824 commit 86fca8a
Show file tree
Hide file tree
Showing 7 changed files with 666 additions and 50 deletions.
84 changes: 38 additions & 46 deletions sleap_roots/lengths.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Get length-related traits."""
import numpy as np
from sleap_roots.bases import get_base_tip_dist
from typing import Optional
from typing import Union


def get_max_length_pts(pts: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -113,57 +112,50 @@ def get_root_lengths_max(pts: np.ndarray) -> np.ndarray:


def get_grav_index(
primary_length: Optional[float] = None,
primary_base_tip_dist: Optional[float] = None,
pts: Optional[np.ndarray] = None,
) -> float:
"""Calculate the gravitropism index of a primary root.
lengths: Union[float, np.ndarray], base_tip_dists: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
"""Calculate the gravitropism index of a root.
The gravitropism index quantifies the curviness of the root's growth. A higher
gravitropism index indicates a curvier root (less responsive to gravity), while a
lower index indicates a straighter root (more responsive to gravity). The index is
computed as the difference between the maximum primary root length and straight-line
distance from the base to the tip of the primary root, normalized by the root length.
computed as the difference between the maximum root length and straight-line
distance from the base to the tip of the root, normalized by the root length.
Args:
primary_length: Maximum length of the primary root. Used if `pts` is not
provided.
primary_base_tip_dist: The straight-line distance from the base to the tip of
the primary root. Used if `pts` is not provided.
pts: Landmarks of the primary root of shape `(instances, nodes, 2)`. If
provided, `primary_length` and `primary_base_tip_dist` are ignored.
lengths: Maximum length of the root(s). Can be a scalar or a 1D numpy array
of shape `(instances,)`.
base_tip_dists: The straight-line distance from the base to the tip of the
root(s). Can be a scalar or a 1D numpy array of shape `(instances,)`.
Returns:
float: Gravitropism index of the primary root, quantifying its curviness.
Gravitropism index of the root(s), quantifying its/their curviness. Will be a
scalar if input is scalar, or a 1D numpy array of shape `(instances,)`
otherwise.
"""
# Use provided scalar values if available
if primary_length is not None and primary_base_tip_dist is not None:
max_primary_length = primary_length
max_base_tip_distance = primary_base_tip_dist

# Use provided pts array to compute required values if available
elif pts is not None:
if np.isnan(pts).all():
return np.nan
primary_length_max = get_root_lengths_max(pts=pts)
primary_base_tip_dist = get_base_tip_dist(pts=pts)
max_primary_length = np.nanmax(primary_length_max)
max_base_tip_distance = np.nanmax(primary_base_tip_dist)

# Check if the input is scalar or array
is_scalar_input = np.isscalar(lengths) and np.isscalar(base_tip_dists)

# Convert scalars to numpy arrays for uniform handling
lengths = np.atleast_1d(np.asarray(lengths, dtype=float))
base_tip_dists = np.atleast_1d(np.asarray(base_tip_dists, dtype=float))

# Check for shape mismatch
if lengths.shape != base_tip_dists.shape:
raise ValueError("The shapes of lengths and base_tip_dists must match.")

# Calculate the gravitropism index where possible
grav_index = np.where(
(~np.isnan(lengths))
& (~np.isnan(base_tip_dists))
& (lengths > 0)
& (lengths >= base_tip_dists),
(lengths - base_tip_dists) / lengths,
np.nan,
)

# Return scalar or array based on the input type
if is_scalar_input:
return grav_index.item()
else:
raise ValueError(
"Either both primary_length and primary_base_tip_dist, or pts"
"must be provided."
)

# Check for invalid values (NaN or zero lengths)
if (
np.isnan(max_primary_length)
or np.isnan(max_base_tip_distance)
or max_primary_length == 0
):
return np.nan

# Calculate and return gravitropism index
grav_index = (max_primary_length - max_base_tip_distance) / max_primary_length
return grav_index
return grav_index
Loading

0 comments on commit 86fca8a

Please sign in to comment.