diff --git a/sleap_roots/__init__.py b/sleap_roots/__init__.py index 3f0c719..ac2cbf9 100644 --- a/sleap_roots/__init__.py +++ b/sleap_roots/__init__.py @@ -10,9 +10,8 @@ import sleap_roots.scanline import sleap_roots.series import sleap_roots.summary -import sleap_roots.traitsgraph -import sleap_roots.graphpipeline -from sleap_roots.graphpipeline import get_all_plants_traits +import sleap_roots.trait_pipelines +from sleap_roots.trait_pipelines import DicotPipeline, TraitDef from sleap_roots.series import Series # Define package version. diff --git a/sleap_roots/angle.py b/sleap_roots/angle.py index b50ed10..c8784e5 100644 --- a/sleap_roots/angle.py +++ b/sleap_roots/angle.py @@ -4,51 +4,112 @@ import math -def get_node_ind(pts: np.ndarray, proximal=True) -> np.ndarray: - """Find nproximal/distal node index. +def get_node_ind(pts: np.ndarray, proximal: bool = True) -> np.ndarray: + """Find proximal/distal node index. Args: - pts: Numpy array of points of shape (instances, nodes, 2). + pts: Numpy array of points of shape (instances, nodes, 2) or (nodes, 2). proximal: Boolean value, where true is proximal (default), false is distal. Returns: An array of shape (instances,) of proximal or distal node index. """ - node_ind = [] - for i in range(pts.shape[0]): - ind = 1 if proximal else pts.shape[1] - 1 # set initial proximal/distal node - while np.isnan(pts[i, ind]).any(): - ind += 1 if proximal else -1 - if (ind == pts.shape[1] and proximal) or (ind == 0 and not proximal): - break - node_ind.append(ind) + # Check if pts is a numpy array + if not isinstance(pts, np.ndarray): + raise TypeError("Input pts should be a numpy array.") + + # Check if pts has 2 or 3 dimensions + if pts.ndim not in [2, 3]: + raise ValueError("Input pts should have 2 or 3 dimensions.") + + # Check if the last dimension of pts has size 2 + if pts.shape[-1] != 2: + raise ValueError( + "The last dimension of the input pts should have size 2," + "representing x and y coordinates." 
+ ) + + # Check if pts is 2D, if so, reshape to 3D + if pts.ndim == 2: + pts = pts[np.newaxis, ...] + + # Identify where NaN values exist + nan_mask = np.isnan(pts).any(axis=-1) + + # If only NaN values, return NaN + if nan_mask.all(): + return np.nan + + if proximal: + # For proximal, we want the first non-NaN node in the first half root + # get the first half nan mask (exclude the base node) + node_proximal = nan_mask[:, 1 : int((nan_mask.shape[1] + 1) / 2)] + # get the nearest non-Nan node index + node_ind = np.argmax(~node_proximal, axis=-1) + # if there is no non-Nan node, set value of 99 + node_ind[node_proximal.all(axis=1)] = 99 + node_ind = node_ind + 1 # adjust indices by adding one (base node) + else: + # For distal, we want the last non-NaN node in the last half root + # get the last half nan mask + node_distal = nan_mask[:, int(nan_mask.shape[1] / 2) :] + # get the farest non-Nan node + node_ind = (node_distal[:, ::-1] == False).argmax(axis=1) + node_ind[node_distal.all(axis=1)] = -95 # set value if no non-Nan node + node_ind = pts.shape[1] - node_ind - 1 # adjust indices by reversing + + # reset indices of 0 (base node) if no non-Nan node + node_ind[node_ind == 100] = 0 + + # If pts was originally 2D, return a scalar instead of a single-element array + if pts.shape[0] == 1: + return node_ind[0] + + # If only one root, return a scalar instead of a single-element array + if node_ind.shape[0] == 1: + return node_ind[0] + return node_ind -def get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray: +def get_root_angle( + pts: np.ndarray, node_ind: np.ndarray, proximal: bool = True, base_ind=0 +) -> np.ndarray: """Find angles for each root. Args: pts: Numpy array of points of shape (instances, nodes, 2). + node_ind: Primary or lateral root node index. proximal: Boolean value, where true is proximal (default), false is distal. base_ind: Index of base node in the skeleton (default: 0). 
Returns: An array of shape (instances,) of angles in degrees, modulo 360. """ - node_ind = get_node_ind(pts, proximal) # get proximal or distal node index + # if node_ind is a single int value, make it as array to keep consistent + if not isinstance(node_ind, np.ndarray): + node_ind = [node_ind] + + if np.isnan(node_ind).all(): + return np.nan + + if pts.ndim == 2: + pts = np.expand_dims(pts, axis=0) + angs_root = [] for i in range(len(node_ind)): - # filter out the cases if all nan nodes in last/first half part - # to calculate proximal/distal angle - if (node_ind[i] < math.ceil(pts.shape[1] / 2) and proximal) or ( - node_ind[i] >= math.floor(pts.shape[1] / 2) and not (proximal) - ): + # if the node_ind is 0, do NOT calculate angs + if node_ind[i] == 0: + angs = np.nan + else: xy = pts[i, node_ind[i], :] - pts[i, base_ind, :] # center on base node # calculate the angle and convert to the start with gravity direction ang = np.arctan2(-xy[1], xy[0]) * 180 / np.pi angs = abs(ang + 90) if ang < 90 else abs(-(360 - 90 - ang)) - else: - angs = np.nan angs_root.append(angs) - return np.array(angs_root) + angs_root = np.array(angs_root) + + # If only one root, return a scalar instead of a single-element array + if angs_root.shape[0] == 1: + return angs_root[0] + return angs_root diff --git a/sleap_roots/bases.py b/sleap_roots/bases.py index a834b97..70663a2 100644 --- a/sleap_roots/bases.py +++ b/sleap_roots/bases.py @@ -1,110 +1,82 @@ """Trait calculations that rely on bases (i.e., dicot-only).""" import numpy as np -import shapely from shapely.geometry import LineString, Point from shapely.ops import nearest_points +from typing import Union def get_bases(pts: np.ndarray, monocots: bool = False) -> np.ndarray: - """Return bases (r1) from each lateral root. + """Return bases (r1) from each root. Args: - pts: Root landmarks as array of shape (instances, nodes, 2) + pts: Root landmarks as array of shape `(instances, nodes, 2)` or `(nodes, 2)`. 
monocots: Boolean value, where false is dicot (default), true is rice. Returns: - Array of bases (instances, (x, y)). + Array of bases `(instances, (x, y))`. If the input is `(nodes, 2)`, an array of + shape `(2,)` will be returned. """ - # Get the first point of each instance. Shape is (instances, 2) if monocots: return np.nan - else: - base_pts = pts[:, 0] - return base_pts + # If the input has shape `(nodes, 2)`, reshape it for consistency + if pts.ndim == 2: + pts = pts[np.newaxis, ...] -def get_root_lengths(pts: np.ndarray) -> np.ndarray: - """Return root lengths for all roots in a frame. + # Get the first point of each instance + base_pts = pts[:, 0] # Shape is `(instances, 2)` - Args: - pts: Root landmarks as array of shape (instances, nodes, 2). + # If the input was `(nodes, 2)`, return an array of shape `(2,)` instead of `(1, 2)` + if base_pts.shape[0] == 1: + return base_pts[0] - Returns: - Array of root lengths of shape (instances,). - If there is no root, or the roots is one point only (all of the rest of the - points are NaNs), an array of NaNs with shape (len(pts),) is returned. - This is also the case for non-contiguous points at the moment. - """ - # Get the (x,y) differences of segments for each instance. - segment_diffs = np.diff(pts, axis=1) - # Get the lengths of each segment by taking the norm. - segment_lengths = np.linalg.norm(segment_diffs, axis=-1) - # Add the segments together to get the total length using nansum. - total_lengths = np.nansum(segment_lengths, axis=-1) - # Find the NaN segment lengths and record NaN in place of 0 when finding the total - # length. - total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan - return total_lengths + return base_pts -def get_root_lengths_max(lengths: np.ndarray) -> np.ndarray: - """Return maximum root length for all roots in a frame. 
+def get_base_tip_dist( + base_pts: np.ndarray, tip_pts: np.ndarray +) -> Union[np.ndarray, float]: + """Calculate the straight-line distance(s) from the base(s) to the tip(s). Args: - lengths: root lengths with shape of (instances,). + base_pts: The x and y coordinates of the base point(s) of the root(s). Shape can + be either `(2,)` for a single point or `(instances, 2)` for multiple + instances. + tip_pts: The x and y coordinates of the tip point(s) of the root(s). Shape + should match that of `base_pts`. Returns: - Scalar of the maximum root length. + Distance(s) from the base(s) to the tip(s) of the root(s). If there's only one + distance (i.e., shape is `(1,)`), a scalar is returned. Otherwise, an array + matching the first dimension of the input arrays is returned. """ - max_length = np.nanmax(lengths) - return max_length - + # Check if the shapes of the two input arrays match + if base_pts.shape != tip_pts.shape: + raise ValueError("The shapes of base_pts and tip_pts must match.") -def get_base_tip_dist(pts: np.ndarray) -> np.ndarray: - """Return distance from root base to tip. - - Args: - pts: Root landmarks as array of shape (instances, nodes, 2) - - Returns: - Array of distances from base to tip of shape (instances,). - """ - base_pt = pts[:, 0] - tip_pt = pts[:, -1] - distance = np.linalg.norm(base_pt - tip_pt, axis=-1) - return distance + # Compute the Euclidean distance(s) between the point(s) + distances = np.linalg.norm(base_pts - tip_pts, axis=-1) + # If distances is a scalar, check if either base_pts or tip_pts is NaN, and + # return NaN if true + if np.isscalar(distances): + if np.isnan(base_pts).any() or np.isnan(tip_pts).any(): + return np.nan + return distances -def get_grav_index(pts: np.ndarray): - """Get gravitropism index based on primary_length_max and primary_base_tip_dist. - - Args: - pts: primary root landmarks as array of shape (1, node, 2) - - Returns: - Scalar of primary root gravity index. 
- """ - # get primary root length, if predicted >1 primary roots, use the longest one - primary_length = get_root_lengths(pts) - primary_length_max = get_root_lengths_max(primary_length) - - # get the distance between base and tip in y axis - primary_base_tip_dist = get_base_tip_dist(pts) + # If distances is an array, create and apply the nan_mask + nan_mask = np.isnan(base_pts).any(axis=-1) | np.isnan(tip_pts).any(axis=-1) + distances[nan_mask] = np.nan - # calculate gravitropism index - pl_max = np.nanmax(primary_length_max) - if pl_max == 0: - return np.nan - grav_index = (pl_max - np.nanmax(primary_base_tip_dist)) / pl_max - return grav_index + return distances def get_lateral_count(pts: np.ndarray): """Get number of lateral roots. Args: - pts: lateral root landmarks as array of shape (instance, node, 2) + pts: lateral root landmarks as array of shape `(instance, node, 2)`. Return: Scalar of number of lateral roots. @@ -114,233 +86,276 @@ def get_lateral_count(pts: np.ndarray): def get_base_xs(pts: np.ndarray, monocots: bool = False) -> np.ndarray: - """Get x coordinations of base points. + """Get x coordinates of the base of each lateral root. Args: - pts: root landmarks as array of shape (instance, point, 2) + pts: root landmarks as array of shape `(instances, point, 2)` or bases + `(instances, 2)`. monocots: Boolean value, where false is dicot (default), true is rice. Return: - An array of bases in x axis (instance,). + An array of the x-coordinates of bases `(instance,)`. """ - _base_pts = get_bases(pts, monocots) + # If the input is a single number (float or integer), return np.nan + if isinstance(pts, (np.floating, float, np.integer, int)): + return np.nan + + # If the input array doesn't have 2 or 3 dimensions, raise an error + if pts.ndim not in (2, 3): + raise ValueError( + "Input array must be 2-dimensional (n_bases, 2) or " + "3-dimensional (n_roots, n_nodes, 2)." 
+ ) + + # If the input array has 3 dimensions, calculate the base points, + # otherwise, assume the input array already contains the base points + if pts.ndim == 3: + _base_pts = get_bases( + pts, monocots + ) # Assuming get_bases returns an array of shape (instance, 2) + else: + _base_pts = pts + + # If _base_pts is a single number (float or integer), return np.nan if isinstance(_base_pts, (np.floating, float, np.integer, int)): return np.nan + + # If the base points array doesn't have exactly 2 dimensions or + # the second dimension is not of size 2, raise an error + if _base_pts.ndim != 2 or _base_pts.shape[1] != 2: + raise ValueError( + "Array of base points must be 2-dimensional with shape (instance, 2)." + ) + + # If everything is fine, extract and return the x-coordinates of the base points else: base_xs = _base_pts[:, 0] return base_xs -def get_base_ys(pts: np.ndarray, monocots: bool = False) -> np.ndarray: - """Get y coordinations of base points. +def get_base_ys(base_pts: np.ndarray, monocots: bool = False) -> np.ndarray: + """Get y coordinates of the base of each root. Args: - pts: root landmarks as array of shape (instance, point, 2) + base_pts: root bases as array of shape `(instances, 2)` or `(2)` + when there is only one root, as is the case for primary roots. monocots: Boolean value, where false is dicot (default), true is rice. Return: - An array of bases in y axis (instance,). + An array of the y-coordinates of bases (instances,). 
""" - _base_pts = get_bases(pts, monocots) - if isinstance(_base_pts, (np.floating, float, np.integer, int)): + # If the input is a single number (float or integer), return np.nan + if isinstance(base_pts, (np.floating, float, np.integer, int)): return np.nan - else: - base_ys = _base_pts[:, 1] - return base_ys + + # Check for the 2D shape of the input array + if base_pts.ndim == 1: + # If shape is `(2,)`, then reshape it to `(1, 2)` for consistency + base_pts = base_pts.reshape(1, 2) + elif base_pts.ndim != 2: + raise ValueError("Input array must be of shape `(instances, 2)` or `(2, )`.") + + # At this point, `base_pts` should be of shape `(instances, 2)`. + base_ys = base_pts[:, 1] + return base_ys -def get_base_length(pts: np.ndarray, monocots: bool = False): - """Get lateral roots top and deepest bases distance in y axis. +def get_base_length(lateral_base_ys: np.ndarray, monocots: bool = False) -> float: + """Get the y-axis difference from the top lateral base to the bottom lateral base. Args: - pts: lateral root landmarks as array of shape (instance, point, 2) + lateral_base_ys: y-coordinates of the base points of lateral roots of shape + `(instances,)`. monocots: Boolean value, where false is dicot (default), true is rice. Return: - Top and deepest bases distance y-axis. + The distance between the top base y-coordinate and the deepest + base y-coordinate. """ - base_ys = get_base_ys(pts, monocots) - base_length = np.nanmax(base_ys) - np.nanmin(base_ys) + # If the roots are monocots, return NaN + if monocots: + return np.nan + + # Compute the difference between the maximum and minimum y-coordinates + base_length = np.nanmax(lateral_base_ys) - np.nanmin(lateral_base_ys) + return base_length -def get_base_ct_density(primary_pts, lateral_pts, monocots: bool = False): - """Get number of base points to maximum primary root length. 
+def get_base_ct_density( + primary_length_max: float, lateral_base_pts: np.ndarray, monocots: bool = False +): + """Get a ratio of the number of base points to maximum primary root length. Args: - primary_pts: primary root points - lateral_pts: lateral root points + primary_length_max: Scalar of maximum primary root length. + lateral_base_pts: Base points of lateral roots of shape (instances, 2). monocots: Boolean value, where false is dicot (default), true is rice. Return: Scalar of base count density. """ - # get number of base points of lateral roots - _base_pts = get_bases(lateral_pts, monocots) - if isinstance(_base_pts, (np.floating, float, np.integer, int)): + # Check if the input is valid for lateral_base_pts or if monocots is True + if ( + monocots + or isinstance(lateral_base_pts, (np.floating, float, np.integer, int)) + or np.isnan(lateral_base_pts).all() + ): return np.nan - else: - base_ct = len(_base_pts[~np.isnan(_base_pts[:, 0])]) - # get primary root length - lengths_primary = get_root_lengths(primary_pts) - base_ct_density = base_ct / np.nanmax(lengths_primary) - return base_ct_density + # Get the number of base points of lateral roots + base_ct = len(lateral_base_pts[~np.isnan(lateral_base_pts[:, 0])]) -def get_primary_depth(primary_pts): - """Get primary root tip depth. + # Handle cases where maximum primary length is zero or NaN to avoid division by zero + if primary_length_max == 0 or np.isnan(primary_length_max): + return np.nan - Args: - primary_pts: primary root points. + # Calculate base_ct_density + base_ct_density = base_ct / primary_length_max - Return: - Scalar of primary root tip depth. - """ - primary_depth = np.nanmax(primary_pts[:, :, 1]) - return primary_depth + return base_ct_density -def get_base_length_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray): - """Get ratio of top-deep base length to primary root length. 
+def get_base_length_ratio( + primary_length: float, base_length: float, monocots: bool = False +) -> float: + """Calculate the ratio of the length of the bases to the primary root length. Args: - primary_pts: primary root points. - lateral_pts: lateral root points. + primary_length (float): Length of the primary root. + base_length (float): Length of the bases along the primary root. + monocots (bool): True if the roots are monocots, False if they are dicots. - Return: - Scalar of base length ratio. + Returns: + Ratio of the length of the bases along the primary root to the primary root + length. """ - base_length = get_base_length(lateral_pts) - primary_length = get_root_lengths(primary_pts) - primary_length_max = get_root_lengths_max(primary_length) - if primary_length_max == 0: + # If roots are monocots or either of the lengths are NaN, return NaN + if monocots or np.isnan(primary_length) or np.isnan(base_length): return np.nan - else: - base_length_ratio = base_length / primary_length_max - return base_length_ratio + # Handle case where primary length is zero to avoid division by zero + if primary_length == 0: + return np.nan -def get_base_median_ratio( - primary_pts: np.ndarray, lateral_pts: np.ndarray, monocots: bool = False -): + # Compute and return the base length ratio + base_length_ratio = base_length / primary_length + return base_length_ratio + + +def get_base_median_ratio(lateral_base_ys, primary_tip_pt_y, monocots: bool = False): """Get ratio of median value in all base points to tip of primary root in y axis. Args: - primary_pts: primary root points. - lateral_pts: lateral root points. + lateral_base_ys: y-coordinates of the base points of lateral roots of shape + `(instances,)`. + primary_tip_pt_y: y-coordinate of the tip point of the primary root of shape + `(1)`. monocots: Boolean value, where false is dicot (default), true is rice. Return: - Scalar of base median ratio. + Scalar of base median ratio. 
If all y-coordinates of the lateral root bases are + NaN, the function returns NaN. """ - _base_pts = get_bases(lateral_pts, monocots) - pr_tip_depth = np.nanmax(primary_pts[:, :, 1]) - if np.isnan(_base_pts).all(): + # Check if the roots are monocots, if so return NaN + if monocots: return np.nan - else: - base_median_ratio = np.nanmedian(_base_pts[:, 1]) / pr_tip_depth - return base_median_ratio + + # Check if all y-coordinates of lateral root bases are NaN, if so return NaN + if np.isnan(lateral_base_ys).all(): + return np.nan + + # Calculate the median of all y-coordinates of lateral root bases + median_base_y = np.nanmedian(lateral_base_ys) + + # If primary_tip_pt_y is an array of shape (1), extract the scalar value + if isinstance(primary_tip_pt_y, np.ndarray) and primary_tip_pt_y.shape == (1,): + primary_tip_pt_y = primary_tip_pt_y[0] + + # Compute the ratio of the median y-coordinate of lateral root bases to the + # y-coordinate of the primary root tip + base_median_ratio = median_base_y / primary_tip_pt_y + + return base_median_ratio def get_root_pair_widths_projections( - lateral_pts, primary_pts, tolerance, monocots: bool = False -): - """Return estimation of stem width using bases of lateral roots. + primary_max_length_pts: np.ndarray, + lateral_pts: np.ndarray, + tolerance: float, + monocots: bool = False, +) -> float: + """Return estimation of root width using bases of lateral roots. Args: - lateral_pts: Lateral roots as arrays of shape (n, nodes, 2). - primary_pts: longest primary root as arrays of shape (n, nodes, 2). - tolerance: difference in projection norm between the right and left side (~0.02). - monocots: Boolean value, where false is dicot (default), true is rice. + primary_max_length_pts: Longest primary root as an array of shape (nodes, 2). + lateral_pts: Lateral roots as an array of shape (n, nodes, 2). + tolerance: Difference in projection norm between the right and left side + (~0.02). 
+ monocots: Boolean value, where False is dicot (default), True is rice. Returns: - A match_dists is the distance in pixels between the bases of matched - roots as a vector of size (n_matches,). + float: The distance in pixels between the bases of matched roots, or NaN + if no matches were found or all input points were NaN. + Raises: + ValueError: If the input arrays are of incorrect shape. """ - if monocots: + if primary_max_length_pts.ndim != 2 or lateral_pts.ndim != 3: + raise ValueError("Input arrays should be 2-dimensional and 3-dimensional") + + if ( + monocots + or np.isnan(primary_max_length_pts).all() + or np.isnan(lateral_pts).all() + ): return np.nan - else: - if np.isnan(primary_pts).all(): - return np.nan - else: - primary_pts_filtered = primary_pts[~np.isnan(primary_pts).any(axis=2)] - primary_line = LineString(primary_pts_filtered) - - # Make a line of the primary points - primary_line = LineString(primary_pts_filtered) - - # Filter by whether the base node is present. - has_base = ~np.isnan(lateral_pts[:, 0, 0]) - valid_inds = np.argwhere(has_base).squeeze() - lateral_pts = lateral_pts[has_base] - - # Find roots facing left based on whether the base x-coord - # is larger than the tip x-coord. - is_left = lateral_pts[:, 0, 0] > np.nanmin(lateral_pts[:, 1:, 0], axis=1) - - # Edge Case: Only found roots on one side. - if is_left.all() or (~is_left).all(): - return np.nan - - # Get left and right base points. 
- left_bases, right_bases = lateral_pts[is_left, 0], lateral_pts[~is_left, 0] - - # Find the nearest point to each right lateral base on the primary root line - nearest_primary_right = [ - nearest_points(primary_line, Point(right_base))[0] - for right_base in right_bases - ] - - # Find the nearest point to each left lateral base on the primary root line - nearest_primary_left = [ - nearest_points(primary_line, Point(left_base))[0] - for left_base in left_bases - ] - - # Returns the distance along the primary line of point in nearest_primary_right, normalized to the length of the object. - nearest_primary_norm_right = np.array( - [ - primary_line.project(pt, normalized=True) - for pt in nearest_primary_right - ] - ) - # Returns the distance along the primary line of point in nearest_primary_left, normalized to the length of the object. - nearest_primary_norm_left = np.array( - [ - primary_line.project(pt, normalized=True) - for pt in nearest_primary_left - ] - ) - - # get all possible differences in projections from all base pairs - projection_diffs = np.abs( - nearest_primary_norm_left.reshape(-1, 1) - - nearest_primary_norm_right.reshape(1, -1) - ) - - # shape is [# of valid base pairs, 2 [left right]] - indices = np.argwhere(projection_diffs <= tolerance) - - left_inds = indices[:, 0] - right_inds = indices[:, 1] - - # Find pairwise distances. (shape is (# of left bases, # of right bases)) - dists = np.linalg.norm( - np.expand_dims(left_bases, axis=1) - - np.expand_dims(right_bases, axis=0), - axis=-1, - ) - - # Pull out match distances. - match_dists = np.array([dists[l, r] for l, r in zip(left_inds, right_inds)]) - - # Convert matches to indices before splitting by side. - left_inds = np.argwhere(is_left).reshape(-1)[left_inds] - right_inds = np.argwhere(~is_left).reshape(-1)[right_inds] - - # Convert matches to indices before filtering. 
- left_inds = valid_inds[left_inds] - right_inds = valid_inds[right_inds] - - return match_dists + + primary_pts_filtered = primary_max_length_pts[ + ~np.isnan(primary_max_length_pts).any(axis=-1) + ] + primary_line = LineString(primary_pts_filtered) + + has_base = ~np.isnan(lateral_pts[:, 0, 0]) + valid_inds = np.argwhere(has_base).squeeze() + lateral_pts = lateral_pts[has_base] + + is_left = lateral_pts[:, 0, 0] > np.nanmin(lateral_pts[:, 1:, 0], axis=1) + + if is_left.all() or (~is_left).all(): + return np.nan + + left_bases, right_bases = lateral_pts[is_left, 0], lateral_pts[~is_left, 0] + + nearest_primary_right = [ + nearest_points(primary_line, Point(right_base))[0] for right_base in right_bases + ] + + nearest_primary_left = [ + nearest_points(primary_line, Point(left_base))[0] for left_base in left_bases + ] + + nearest_primary_norm_right = np.array( + [primary_line.project(pt, normalized=True) for pt in nearest_primary_right] + ) + + nearest_primary_norm_left = np.array( + [primary_line.project(pt, normalized=True) for pt in nearest_primary_left] + ) + + projection_diffs = np.abs( + nearest_primary_norm_left.reshape(-1, 1) + - nearest_primary_norm_right.reshape(1, -1) + ) + + indices = np.argwhere(projection_diffs <= tolerance) + + left_inds = indices[:, 0] + right_inds = indices[:, 1] + + match_dists = np.linalg.norm( + left_bases[left_inds] - right_bases[right_inds], + axis=-1, + ) + + return match_dists diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py index 321227f..5ff3f92 100644 --- a/sleap_roots/convhull.py +++ b/sleap_roots/convhull.py @@ -1,185 +1,192 @@ """Convex hull fitting and derived trait calculation.""" import numpy as np -from scipy.spatial import ConvexHull, convex_hull_plot_2d +from scipy.spatial import ConvexHull from scipy.spatial.distance import pdist from typing import Tuple, Optional, Union def get_convhull(pts: np.ndarray) -> Optional[ConvexHull]: - """Get the convex hull for the points per frame. 
+ """Compute the convex hull for the points per frame. Args: - pts: Root landmarks as array of shape (..., 2). + pts: Root landmarks as an array of shape (..., 2). Returns: - An object of convex hull. + An object representing the convex hull or None if a hull can't be formed. """ + # Ensure the input is an array of shape (n, 2) + if pts.ndim < 2 or pts.shape[-1] != 2: + raise ValueError("Input points should be of shape (..., 2).") + + # Reshape and filter out NaN values pts = pts.reshape(-1, 2) - pts = pts[~(np.isnan(pts).any(axis=-1))] + pts = pts[~np.isnan(pts).any(axis=-1)] - if len(pts) <= 2: + # Check for NaNs or infinite values + if np.isnan(pts).any() or np.isinf(pts).any(): return None - # Get convex hull - hull = ConvexHull(pts) + # Ensure there are at least 3 unique non-collinear points + if len(np.unique(pts, axis=0)) < 3: + return None - return hull + # Compute and return the convex hull + return ConvexHull(pts) -def get_convhull_features( - pts: Union[np.ndarray, ConvexHull] -) -> Tuple[float, float, float, float]: - """Get the convex hull features for the points per frame. +def get_chull_perimeter(hull: Union[np.ndarray, ConvexHull, None]) -> float: + """Calculate the perimeter of the convex hull formed by the given points. Args: - pts: Root landmarks as array of shape (..., 2). + hull: Either an array of landmark points, a pre-computed convex hull, or None. Returns: - A tuple of 4 convex hull features - perimeters, perimeter of the convex hull - areas, area of the convex hull - max_widths, maximum width of convex hull - max_heights, maximum height of convex hull - - If the convex hull fitting fails, NaNs are returned. + Scalar value representing the perimeter of the convex hull. Returns NaN if + unable to compute the convex hull or if the input is None. 
""" - hull = pts if type(pts) == ConvexHull else get_convhull(pts) + # If the input hull is None, return NaN + if hull is None: + return np.nan + + # If the input is an array, compute its convex hull + if isinstance(hull, np.ndarray): + hull = get_convhull(hull) + # If hull becomes None after attempting to compute the convex hull, return NaN if hull is None: - return np.full((4,), np.nan) + return np.nan - # perimeter - perimeter = hull.area - # area - area = hull.volume + # Ensure that the hull is of type ConvexHull + if not isinstance(hull, ConvexHull): + raise TypeError("After processing, the input must be a ConvexHull object.") - pts = pts.reshape(-1, 2) - pts = pts[~(np.isnan(pts).any(axis=-1))] - - # max 'width' - max_width = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]) - # max 'height' - max_height = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]) - - return ( - perimeter, - area, - max_width, - max_height, - ) + # Compute the perimeter of the convex hull + return hull.area -def get_chull_perimeter( - pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]] -): - """Get convex hull perimeter. +def get_chull_area(hull: Union[np.ndarray, ConvexHull]) -> float: + """Calculate the area of the convex hull formed by the given points. Args: - pts: landmark points, or convex hull, or tuple of convex hull results + hull: Either an array of landmark points or a pre-computed convex hull. - Return: - Scalar of convex hull perimeter. + Returns: + Scalar value representing the area of the convex hull. Returns NaN if unable + to compute the convex hull. 
""" - if type(pts) == tuple: - return pts[0] - elif type(pts) == ConvexHull: - hull = pts - else: - hull = get_convhull(pts) + # If the input hull is None, return NaN if hull is None: return np.nan - return hull.area + # If the input is an array, compute its convex hull + if isinstance(hull, np.ndarray): + hull = get_convhull(hull) -def get_chull_area( - pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]] -): - """Get convex hull area. + # If hull becomes None after attempting to compute the convex hull, return NaN + if hull is None: + return np.nan - Args: - pts: landmark points, or convex hull, or tuple of convex hull results + # Ensure that the hull is of type ConvexHull + if not isinstance(hull, ConvexHull): + raise TypeError("After processing, the input must be a ConvexHull object.") - Return: - Scalar of convex hull area. - """ - if type(pts) == tuple: - return pts[1] - elif type(pts) == ConvexHull: - hull = pts - else: - hull = get_convhull(pts) + # If hull couldn't be formed, return NaN if hull is None: return np.nan + + # Return the area of the convex hull return hull.volume -def get_chull_max_width( - pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]] -): - """Get maximum width of convex hull. +def get_chull_max_width(hull: Union[np.ndarray, ConvexHull]) -> float: + """Calculate the maximum width (in the x-axis direction) of the convex hull. Args: - pts: landmark points, or convex hull, or tuple of convex hull results + hull: Either an array of landmark points or a pre-computed convex hull. - Return: - Scalar of convex hull maximum width. + Returns: + Scalar value representing the maximum width of the convex hull. Returns NaN if + unable to compute the convex hull. 
""" - if type(pts) == tuple: - return pts[2] - elif type(pts) == ConvexHull: - hull = pts - else: - hull = get_convhull(pts) + # If hull is None, return NaN if hull is None: return np.nan - pts = pts.reshape(-1, 2) - pts = pts[~(np.isnan(pts).any(axis=-1))] - max_width = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]) + + # If the input is an array, compute its convex hull + if isinstance(hull, np.ndarray): + hull = get_convhull(hull) + if hull is None: + return np.nan + # Extract the convex hull points + hull_pts = hull.points[hull.vertices] + elif isinstance(hull, ConvexHull): + hull_pts = hull.points[hull.vertices] + else: + raise TypeError( + "Input must be either an array of points or a ConvexHull object." + ) + + # Calculate the maximum width (difference in x-coordinates) + max_width = np.nanmax(hull_pts[:, 0]) - np.nanmin(hull_pts[:, 0]) + return max_width -def get_chull_max_height( - pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]] -): +def get_chull_max_height(hull: Union[np.ndarray, ConvexHull]) -> float: """Get maximum height of convex hull. Args: - pts: landmark points, or convex hull, or tuple of convex hull results + hull: landmark points or a precomputed convex hull. Return: - Scalar of convex hull maximum height. + Scalar of convex hull maximum height. If the hull cannot be computed (e.g., + insufficient valid points), NaN is returned. 
""" - if type(pts) == tuple: - return pts[3] - elif type(pts) == ConvexHull: - hull = pts + # If hull is None, return NaN + if hull is None: + return np.nan + + # If the input is a ConvexHull object, use it directly + if isinstance(hull, ConvexHull): + hull = hull else: - hull = get_convhull(pts) + # Otherwise, compute the convex hull + hull = get_convhull(hull) + + # If no valid convex hull could be computed, return NaN if hull is None: return np.nan - pts = pts.reshape(-1, 2) - pts = pts[~(np.isnan(pts).any(axis=-1))] - max_height = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]) + + # Use the convex hull's vertices to compute the maximum height + max_height = np.nanmax(hull.points[hull.vertices, 1]) - np.nanmin( + hull.points[hull.vertices, 1] + ) + return max_height -def get_chull_line_lengths(pts: Union[np.ndarray, ConvexHull]) -> np.ndarray: - """Get the convex hull line lengths per frame. +def get_chull_line_lengths(hull: Union[np.ndarray, ConvexHull]) -> np.ndarray: + """Get the pairwise distances between all vertices of the convex hull. Args: - pts: Root landmarks as array of shape (..., 2) or ConvexHull object. + hull: Root landmarks as array of shape (..., 2) or a ConvexHull object. Returns: - Lengths of lines connecting any two vertices on the convex hull. - If the convex hull fitting fails, NaNs are returned. + An array containing the pairwise distances between all vertices of the convex + hull. If the convex hull fitting fails, an empty array is returned. 
""" - hull = pts if type(pts) == ConvexHull else get_convhull(pts) - + # If hull is None, return NaN if hull is None: return np.nan - # Lengths of lines connecting any two vertices on the convex hull + # Ensure pts is a ConvexHull object, otherwise get the convex hull + hull = hull if isinstance(hull, ConvexHull) else get_convhull(hull) + + if hull is None: + return np.array([]) + + # Compute the pairwise distances between all vertices of the convex hull chull_line_lengths = pdist(hull.points[hull.vertices], "euclidean") return chull_line_lengths diff --git a/sleap_roots/graphpipeline.py b/sleap_roots/graphpipeline.py deleted file mode 100644 index a6e3200..0000000 --- a/sleap_roots/graphpipeline.py +++ /dev/null @@ -1,712 +0,0 @@ -"""Extract traits based on the networkx graph.""" - -import numpy as np -import pandas as pd -import os -from typing import List -from fractions import Fraction -from pathlib import Path -from sleap_roots.traitsgraph import get_traits_graph -from sleap_roots.angle import get_root_angle -from sleap_roots.bases import ( - get_bases, - get_base_ct_density, - get_base_length, - get_base_length_ratio, - get_base_median_ratio, - get_base_tip_dist, - get_base_xs, - get_base_ys, - get_grav_index, - get_lateral_count, - get_primary_depth, - get_root_lengths, - get_root_pair_widths_projections, -) -from sleap_roots.convhull import ( - get_chull_area, - get_chull_line_lengths, - get_chull_max_width, - get_chull_max_height, - get_chull_perimeter, - get_convhull_features, -) -from sleap_roots.ellipse import ( - fit_ellipse, - get_ellipse_a, - get_ellipse_b, - get_ellipse_ratio, -) -from sleap_roots.networklength import ( - get_bbox, - get_network_distribution_ratio, - get_network_distribution, - get_network_solidity, - get_network_width_depth_ratio, -) -from sleap_roots.points import get_all_pts_array, get_all_pts -from sleap_roots.scanline import ( - count_scanline_intersections, - get_scanline_first_ind, - get_scanline_last_ind, -) -from 
sleap_roots.series import Series, find_all_series -from sleap_roots.summary import get_summary -from sleap_roots.tips import get_tips, get_tip_xs, get_tip_ys -from typing import Dict, Tuple -import warnings - - -SCALAR_TRAITS = ( - "primary_angle_proximal", - "primary_angle_distal", - "primary_length", - "primary_base_tip_dist", - "primary_depth", - "lateral_count", - "grav_index", - "base_length", - "base_length_ratio", - "primary_tip_pt_y", - "base_median_ratio", - "base_ct_density", - "chull_perimeter", - "chull_area", - "chull_max_width", - "chull_max_height", - "ellipse_a", - "ellipse_b", - "ellipse_ratio", - "network_width_depth_ratio", - "network_solidity", - "network_length_lower", - "network_distribution_ratio", - "scanline_first_ind", - "scanline_last_ind", -) - -NON_SCALAR_TRAITS = ( - "lateral_angles_proximal", - "lateral_angles_distal", - "lateral_lengths", - "stem_widths", - "lateral_base_xs", - "lateral_base_ys", - "lateral_tip_xs", - "lateral_tip_ys", - "chull_line_lengths", - "scanline_intersection_counts", -) - - -warnings.filterwarnings( - "ignore", - message="invalid value encountered in intersection", - category=RuntimeWarning, - module="shapely", -) -warnings.filterwarnings( - "ignore", message="All-NaN slice encountered", category=RuntimeWarning -) -warnings.filterwarnings( - "ignore", message="All-NaN axis encountered", category=RuntimeWarning -) -warnings.filterwarnings( - "ignore", - message="Degrees of freedom <= 0 for slice.", - category=RuntimeWarning, - module="numpy", -) -warnings.filterwarnings( - "ignore", message="Mean of empty slice", category=RuntimeWarning -) -warnings.filterwarnings( - "ignore", - message="invalid value encountered in sqrt", - category=RuntimeWarning, - module="skimage", -) -warnings.filterwarnings( - "ignore", - message="invalid value encountered in double_scalars", - category=RuntimeWarning, -) -warnings.filterwarnings( - "ignore", - message="invalid value encountered in scalar divide", - 
category=RuntimeWarning, - module="ellipse", -) - - -def get_traits_value_frame( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - pts_all_array: np.ndarray, - pts_all_list: list, - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: float = 2 / 3, - monocots: bool = False, -) -> Dict: - """Get SLEAP traits per frame based on graph. - - Args: - primary_pts: primary points - lateral_pts: lateral points - pts_all_array: all points in array format - pts_all_list: all points in list format - stem_width_tolerance: difference in projection norm between right and left side. - n_line: number of scan lines, np.nan for no interaction. - network_fraction: length found in the lower fration value of the network. - monocots: Boolean value, where false is dicot (default), true is rice. - - Return: - A dictionary with all traits per frame. - """ - trait_map = { - # get_bases(pts: np.ndarray,monocots) -> np.ndarray - "primary_base_pt": (get_bases, [primary_pts, monocots]), - # get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray - "primary_angle_proximal": (get_root_angle, [primary_pts, True, 0]), - "primary_angle_distal": (get_root_angle, [primary_pts, False, 0]), - # get_root_lengths(pts: np.ndarray) -> np.ndarray - "primary_length": (get_root_lengths, [primary_pts]), - # get_tips(pts) - "primary_tip_pt": (get_tips, [primary_pts]), - # fit_ellipse(pts: np.ndarray) -> Tuple[float, float, float] - "ellipse": (fit_ellipse, [pts_all_array]), - # get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float] - "bounding_box": (get_bbox, [pts_all_array]), - # get_root_pair_widths_projections(lateral_pts, primary_pts, tolerance,monocots) - "stem_widths": ( - get_root_pair_widths_projections, - [lateral_pts, primary_pts, stem_width_tolerance, monocots], - ), - # get_convhull_features(pts: Union[np.ndarray, ConvexHull]) -> Tuple[float, float, float, float] - "convex_hull": (get_convhull_features, [pts_all_array]), - # 
get_lateral_count(pts: np.ndarray) - "lateral_count": (get_lateral_count, [lateral_pts]), - # # get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray - "lateral_angles_proximal": (get_root_angle, [lateral_pts, True, 0]), - "lateral_angles_distal": (get_root_angle, [lateral_pts, False, 0]), - # get_root_lengths(pts: np.ndarray) -> np.ndarray - "lateral_lengths": (get_root_lengths, [lateral_pts]), - # get_bases(pts: np.ndarray,monocots) -> np.ndarray - "lateral_base_pts": (get_bases, [lateral_pts, monocots]), - # get_tips(pts) - "lateral_tip_pts": (get_tips, [lateral_pts]), - # get_base_ys(pts: np.ndarray) -> np.ndarray - # or just based on primary_base_pt, but the primary_base_pt trait must generate before - # "primary_base_pt_y": (get_pt_ys, [data["primary_base_pt"]]), - "primary_base_pt_y": (get_base_ys, [primary_pts]), - # get_base_ct_density(primary_pts, lateral_pts) - "base_ct_density": (get_base_ct_density, [primary_pts, lateral_pts, monocots]), - # get_network_solidity(primary_pts: np.ndarray, lateral_pts: np.ndarray, pts_all_array: np.ndarray, monocots: bool = False,) -> float - "network_solidity": ( - get_network_solidity, - [primary_pts, lateral_pts, pts_all_array, monocots], - ), - # get_network_distribution_ratio(primary_pts: np.ndarray,lateral_pts: np.ndarray,pts_all_array: np.ndarray,fraction: float = 2 / 3, monocots: bool = False) -> float: - "network_distribution_ratio": ( - get_network_distribution_ratio, - [primary_pts, lateral_pts, pts_all_array, network_fraction, monocots], - ), - # get_network_distribution(primary_pts: np.ndarray,lateral_pts: np.ndarray,pts_all_array: np.ndarray,fraction: float = 2 / 3, monocots: bool = False) -> float: - "network_length_lower": ( - get_network_distribution, - [primary_pts, lateral_pts, pts_all_array, network_fraction, monocots], - ), - # get_tip_ys(pts: np.ndarray) -> np.ndarray - "primary_tip_pt_y": (get_tip_ys, [primary_pts]), - # get_ellipse_a(pts_all_array: Union[np.ndarray, Tuple[float, 
float, float]]) - "ellipse_a": (get_ellipse_a, [pts_all_array]), - # get_ellipse_b(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) - "ellipse_b": (get_ellipse_b, [pts_all_array]), - # get_network_width_depth_ratio(pts: np.ndarray) -> float - "network_width_depth_ratio": (get_network_width_depth_ratio, [pts_all_array]), - # get_chull_perimeter(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_perimeter": (get_chull_perimeter, [pts_all_array]), - # get_chull_area(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_area": (get_chull_area, [pts_all_array]), - # get_chull_max_width(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_max_width": (get_chull_max_width, [pts_all_array]), - # get_chull_max_height(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_max_height": (get_chull_max_height, [pts_all_array]), - # get_chull_line_lengths(pts: Union[np.ndarray, ConvexHull]) -> np.ndarray - "chull_line_lengths": (get_chull_line_lengths, [pts_all_array]), - # count_scanline_intersections(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080,width: int = 2048,n_line: int = 50,monocots: bool = False,) -> np.ndarray - "scanline_intersection_counts": ( - count_scanline_intersections, - [primary_pts, lateral_pts, 1080, 2048, 50, monocots], - ), - # get_base_xs(pts: np.ndarray) -> np.ndarray - "lateral_base_xs": (get_base_xs, [lateral_pts, monocots]), - # get_base_ys(pts: np.ndarray) -> np.ndarray - "lateral_base_ys": (get_base_ys, [lateral_pts, monocots]), - # get_tip_xs(pts: np.ndarray) -> np.ndarray - "lateral_tip_xs": (get_tip_xs, [lateral_pts]), - # get_tip_ys(pts: np.ndarray) -> np.ndarray - "lateral_tip_ys": (get_tip_ys, [lateral_pts]), - # get_base_tip_dist(pts: np.ndarray) -> np.ndarray - "primary_base_tip_dist": (get_base_tip_dist, [primary_pts]), - # get_primary_depth(primary_pts) - "primary_depth": (get_primary_depth, 
[primary_pts]), - # get_base_median_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray) - "base_median_ratio": ( - get_base_median_ratio, - [primary_pts, lateral_pts, monocots], - ), - # get_ellipse_ratio(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) - "ellipse_ratio": (get_ellipse_ratio, [pts_all_array]), - # get_scanline_last_ind(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False) - "scanline_last_ind": ( - get_scanline_last_ind, - [primary_pts, lateral_pts, 1080, 2048, n_line, monocots], - ), - # get_scanline_first_ind(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False) - "scanline_first_ind": ( - get_scanline_first_ind, - [primary_pts, lateral_pts, 1080, 2048, n_line, monocots], - ), - # get_base_length(pts: np.ndarray) - "base_length": (get_base_length, [lateral_pts, monocots]), - # get_grav_index(pts: np.ndarray) - "grav_index": (get_grav_index, [primary_pts]), - # get_base_length_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray) - "base_length_ratio": (get_base_length_ratio, [primary_pts, lateral_pts]), - } - - dts = get_traits_graph() - - data = {} - for trait_name in dts: - fn, inputs = trait_map[trait_name] - fn_outputs = fn(*[input_trait for input_trait in inputs]) - if type(fn_outputs) == tuple: - fn_outputs = np.array(fn_outputs).reshape((1, -1)) - if isinstance(fn_outputs, (np.floating, float)) or isinstance( - fn_outputs, (np.integer, int) - ): - fn_outputs = np.array(fn_outputs)[np.newaxis] - data[trait_name] = fn_outputs - return data - - -def get_traits_value_plant( - h5, - monocots: bool = False, - primary_name: str = "primary_multi_day", - lateral_name: str = "lateral_3_nodes", - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: float = 2 / 3, - write_csv: bool = False, - csv_suffix: str = ".traits.csv", -) -> Tuple[Dict, pd.DataFrame, str]: - 
"""Get detailed SLEAP traits for every frame of a plant, based on the graph. - - Args: - h5: The h5 file representing the plant image series. - monocots: A boolean value indicating whether the plant is a monocot (True) - or a dicot (False) (default). - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. - stem_width_tolerance: The difference in the projection norm between - the right and left side of the stem. - n_line: The number of scan lines. Use np.nan for no interaction. - network_fraction: The length found in the lower fraction value of the network. - write_csv: A boolean value. If True, it writes per plant detailed - CSVs with traits for every instance on every frame. - csv_suffix: If write_csv=True, the CSV file will be saved with the - h5 path + csv_suffix. - - Returns: - A tuple containing a dictionary and a DataFrame with all traits per plant, - and the plant name. The Dataframe has root traits per instance and frame - where each row corresponds to a frame in the H5 file. The plant_name is - given by the h5 file. 
- """ - plant = Series.load(h5, primary_name=primary_name, lateral_name=lateral_name) - plant_name = plant.series_name - # get number of frames per plant - n_frame = len(plant) - - data_plant = [] - # get traits for each frames in a row - for frame in range(n_frame): - primary, lateral = plant[frame] - - gt_instances_pr = primary.user_instances + primary.unused_predictions - gt_instances_lr = lateral.user_instances + lateral.unused_predictions - - if len(gt_instances_lr) == 0: - lateral_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - lateral_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) - - if len(gt_instances_pr) == 0: - primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) - - pts_all_array = get_all_pts_array(plant=plant, frame=frame, monocots=False) - if len(pts_all_array) == 0: - pts_all_array = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - pts_all_list = [] - - if get_root_lengths(primary_pts).shape[0] > 0 and not len(gt_instances_pr) == 0: - max_length_idx = np.nanargmax(get_root_lengths(primary_pts)) - long_primary_pts = primary_pts[max_length_idx] - primary_pts = np.reshape( - long_primary_pts, - (1, long_primary_pts.shape[0], long_primary_pts.shape[1]), - ) - else: - # if no primary root, just give two nan points - primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - - data = get_traits_value_frame( - primary_pts, - lateral_pts, - pts_all_array, - pts_all_list, - stem_width_tolerance, - n_line, - network_fraction, - monocots, - ) - - data["plant_name"] = plant_name - data["frame_idx"] = frame - data_plant.append(data) - data_plant_df = pd.DataFrame(data_plant) - - # reorganize the column position - column_names = data_plant_df.columns.tolist() - column_names = [column_names[-2]] + [column_names[-1]] + column_names[:-2] - data_plant_df = data_plant_df[column_names] - - # convert the data in scalar column to 
the value without [] - columns_to_convert = data_plant_df.columns[ - data_plant_df.apply( - lambda x: all( - isinstance(val, np.ndarray) and val.shape == (1,) for val in x - ) - ) - ] - data_plant_df[columns_to_convert] = data_plant_df[columns_to_convert].apply( - lambda x: x.apply(lambda val: val[0]) - ) - - if write_csv: - csv_name = Path(h5).with_suffix(f"{csv_suffix}") - data_plant_df.to_csv(csv_name, index=False) - return data_plant, data_plant_df, plant_name - - -def get_traits_value_plant_summary( - h5, - monocots: bool = False, - primary_name: str = "longest_3do_6nodes", - lateral_name: str = "main_3do_6nodes", - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: float = 2 / 3, - write_csv: bool = False, - csv_suffix: str = ".traits.csv", - write_summary_csv: bool = False, - summary_csv_suffix: str = ".summary_traits.csv", -) -> pd.DataFrame: - """Get summary statistics of SLEAP traits per plant based on graph. - - Args: - h5: The h5 file representing the plant image series. - monocots: A boolean value indicating whether the plant is a monocot (True) - or a dicot (False) (default). - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. - stem_width_tolerance: The difference in the projection norm between - the right and left side of the stem. - n_line: The number of scan lines. Use np.nan for no interaction. - network_fraction: The length found in the lower fraction value of the network. - write_csv: A boolean value. If True, it writes per plant detailed - CSVs with traits for every instance on every frame. - csv_suffix: If write_csv=True, the CSV file will be saved with the name - h5 path + csv_suffix. - write_summary_csv: Boolean value, where true is write summarized csv file. 
- summary_csv_suffix: If write_summary_csv=True, the CSV file with the summary - statistics per plant will be saved with the name - h5 path + summary_csv_suffix. - - Return: - A DataFrame with summary statistics of all traits per plant. - """ - data_plant, data_plant_df, plant_name = get_traits_value_plant( - h5, - monocots, - primary_name, - lateral_name, - stem_width_tolerance, - n_line, - network_fraction, - write_csv, - csv_suffix, - ) - - # get summarized non-scalar traits per frame - data_plant_frame_summary = [] - data_plant_frame_summary_non_scalar = {} - - for i in range(len(NON_SCALAR_TRAITS)): - trait = data_plant_df[NON_SCALAR_TRAITS[i]] - - if not trait.isna().all(): - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmin" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmin(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmax" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmax(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmean" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmean(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmedian" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmedian(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fstd" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanstd(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc5" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nan if np.isnan(x).all() else 
np.percentile(x[~pd.isna(x)], 5)) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc25" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else ( - np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 25) - ) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc75" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else ( - np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 75) - ) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc95" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else ( - np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 95) - ) - ) - else: - data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fmin"] = np.nan - data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fmax"] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmean" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmedian" - ] = np.nan - data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fstd"] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc5" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc25" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc75" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc95" - ] = np.nan - - # get summarized scalar traits per plant - column_names = data_plant_df.columns.tolist() - data_plant_frame_summary = {} - for i in range(len(SCALAR_TRAITS)): - if SCALAR_TRAITS[i] in column_names: - trait = data_plant_df[SCALAR_TRAITS[i]] - if trait.shape[0] > 0: - if not ( - isinstance(trait[0], (np.floating, float)) - or isinstance(trait[0], (np.integer, int)) - ): - values = np.array([element[0] for element in trait]) - trait = 
values - trait = trait.astype(float) - trait = np.reshape(trait, (len(trait), 1)) - ( - trait_min, - trait_max, - trait_mean, - trait_median, - trait_std, - trait_prc5, - trait_prc25, - trait_prc75, - trait_prc95, - ) = get_summary(trait) - - data_plant_frame_summary[SCALAR_TRAITS[i] + "_min"] = trait_min - data_plant_frame_summary[SCALAR_TRAITS[i] + "_max"] = trait_max - data_plant_frame_summary[SCALAR_TRAITS[i] + "_mean"] = trait_mean - data_plant_frame_summary[SCALAR_TRAITS[i] + "_median"] = trait_median - data_plant_frame_summary[SCALAR_TRAITS[i] + "_std"] = trait_std - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc5"] = trait_prc5 - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc25"] = trait_prc25 - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc75"] = trait_prc75 - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc95"] = trait_prc95 - - # append the summarized non-scalar traits per plant - data_plant_frame_summary_key = list(data_plant_frame_summary_non_scalar.keys()) - for j in range(len(data_plant_frame_summary_non_scalar)): - trait = data_plant_frame_summary_non_scalar[data_plant_frame_summary_key[j]] - ( - trait_min, - trait_max, - trait_mean, - trait_median, - trait_std, - trait_prc5, - trait_prc25, - trait_prc75, - trait_prc95, - ) = get_summary(trait) - - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_min"] = trait_min - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_max"] = trait_max - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_mean"] = trait_mean - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_median" - ] = trait_median - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_std"] = trait_std - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_prc5"] = trait_prc5 - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_prc25" - ] = trait_prc25 - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_prc75" - ] = trait_prc75 - 
data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_prc95" - ] = trait_prc95 - data_plant_frame_summary["plant_name"] = [plant_name] - data_plant_frame_summary_df = pd.DataFrame(data_plant_frame_summary) - - # reorganize the column position - column_names = data_plant_frame_summary_df.columns.tolist() - column_names = [column_names[-1]] + column_names[:-1] - data_plant_frame_summary_df = data_plant_frame_summary_df[column_names] - - if write_summary_csv: - summary_csv_name = Path(h5).with_suffix(f"{summary_csv_suffix}") - data_plant_frame_summary_df.to_csv(summary_csv_name, index=False) - return data_plant_frame_summary_df - - -def get_all_plants_traits( - data_folders: List[str], - primary_name: str, - lateral_name: str, - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: Fraction = Fraction(2, 3), - write_per_plant_details: bool = False, - per_plant_details_csv_suffix: str = ".traits.csv", - write_per_plant_summary: bool = False, - per_plant_summary_csv_suffix: str = ".summary_traits.csv", - monocots: bool = False, - all_plants_csv_name: str = "all_plants_traits.csv", -) -> pd.DataFrame: - """Get a DataFrame with summary traits from all plants in the given data folders. - - Args: - h5: The h5 file representing the plant image series. - monocots: A boolean value indicating whether the plant is a monocot (True) - or a dicot (False) (default). - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. - stem_width_tolerance: The difference in the projection norm between - the right and left side of the stem. - n_line: The number of scan lines. Use np.nan for no interaction. - network_fraction: The length found in the lower fraction value of the network. - write_per_plant_details: A boolean value. 
If True, it writes per plant detailed - CSVs with traits for every instance. - per_plant_details_csv_suffix: If write_csv=True, the CSV file will be saved - with the name h5 path + csv_suffix. - write_per_plant_summary: A boolean value. If True, it writes per plant summary - CSVs. - per_plant_summary_csv_suffix: If write_summary_csv=True, the CSV file with the - summary statistics per plant will be saved with the name - h5 path + summary_csv_suffix. - all_plants_csv_name: The name of the output CSV file containing all plants' - summary traits. - - Returns: - A pandas DataFrame with summary root traits for all plants in the data folders. - Each row is a sample. - """ - h5_series = find_all_series(data_folders) - - all_traits = [] - for h5 in h5_series: - plant_traits = get_traits_value_plant_summary( - h5, - monocots=monocots, - primary_name=primary_name, - lateral_name=lateral_name, - stem_width_tolerance=stem_width_tolerance, - n_line=n_line, - network_fraction=network_fraction, - write_csv=write_per_plant_details, - csv_suffix=per_plant_details_csv_suffix, - write_summary_csv=write_per_plant_summary, - summary_csv_suffix=per_plant_summary_csv_suffix, - ) - plant_traits["path"] = h5 - all_traits.append(plant_traits) - - all_traits_df = pd.concat(all_traits, ignore_index=True) - - all_traits_df.to_csv(all_plants_csv_name, index=False) - return all_traits_df diff --git a/sleap_roots/lengths.py b/sleap_roots/lengths.py new file mode 100644 index 0000000..ac258fc --- /dev/null +++ b/sleap_roots/lengths.py @@ -0,0 +1,169 @@ +"""Get length-related traits.""" +import numpy as np +from sleap_roots.bases import get_base_tip_dist +from typing import Optional + + +def get_max_length_pts(pts: np.ndarray) -> np.ndarray: + """Points of the root with maximum length (intended for primary root traits). + + Args: + pts (np.ndarray): Root landmarks as array of shape `(instances, nodes, 2)`. 
+ + Returns: + np.ndarray: Array of points with shape `(nodes, 2)` from the root with maximum + length. + """ + # Return NaN points if the input array is empty + if len(pts) == 0: + return np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + + # Check if pts has the correct shape, raise error if it does not + if pts.ndim != 3 or pts.shape[2] != 2: + raise ValueError("Input array should have shape (instances, nodes, 2)") + + # Calculate the differences between consecutive points in each root + segment_diffs = np.diff(pts, axis=1) + + # Calculate the length of each segment (the Euclidean distance between consecutive + # points) + segment_lengths = np.linalg.norm(segment_diffs, axis=-1) + + # Sum the lengths of the segments for each root + total_lengths = np.nansum(segment_lengths, axis=-1) + + # Handle roots where all segment lengths are NaN, recording NaN in place of the + # total length for these roots + total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan + + # Return NaN points if all total lengths are NaN + if np.isnan(total_lengths).all(): + return np.array([[np.nan, np.nan]]) + + # Find the index of the root with the maximum total length + max_length_idx = np.nanargmax(total_lengths) + + # Return the points of the root with this index + return pts[max_length_idx] + + +def get_root_lengths(pts: np.ndarray) -> np.ndarray: + """Return root lengths for all roots in a frame. + + Args: + pts: Root landmarks as array of shape `(instances, nodes, 2)` or `(nodes, 2)`. + + Returns: + Array of root lengths of shape `(instances,)`. If there is no root, or the root + is one point only (all of the rest of the points are NaNs), an array of NaNs + with shape (len(pts),) is returned. This is also the case for non-contiguous + points. + """ + # If the input has shape `(nodes, 2)`, reshape it for consistency + if pts.ndim == 2: + pts = pts[np.newaxis, ...] 
+ + # Get the (x,y) differences of segments for each instance + segment_diffs = np.diff(pts, axis=1) + # Get the lengths of each segment by taking the norm + segment_lengths = np.linalg.norm(segment_diffs, axis=-1) + # Add the segments together to get the total length using nansum + total_lengths = np.nansum(segment_lengths, axis=-1) + # Find the NaN segment lengths and record NaN in place of 0 when finding the total length + total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan + + # If there is 1 instance, return a scalar instead of an array of length 1 + if len(total_lengths) == 1: + return total_lengths[0] + + return total_lengths + + +def get_root_lengths_max(pts: np.ndarray) -> np.ndarray: + """Return maximum root length for all roots in a frame. + + Args: + pts: root landmarks as array of shape `(instance, nodes, 2)` or lengths + `(instances)`. + + Returns: + Scalar of the maximum root length. + """ + # If the pts are NaNs, return NaN + if np.isnan(pts).all(): + return np.nan + + if pts.ndim not in (1, 3): + raise ValueError( + "Input array must be 1-dimensional (n_lengths) or " + "3-dimensional (n_roots, n_nodes, 2)." + ) + + # If the input array has 3 dimensions, calculate the root lengths, + # otherwise, assume the input array already contains the root lengths + if pts.ndim == 3: + root_lengths = get_root_lengths( + pts + ) # Assuming get_root_lengths returns an array of shape (instances) + max_length = np.nanmax(root_lengths) + else: + max_length = np.nanmax(pts) + + return max_length + + +def get_grav_index( + primary_length: Optional[float] = None, + primary_base_tip_dist: Optional[float] = None, + pts: Optional[np.ndarray] = None, +) -> float: + """Calculate the gravitropism index of a primary root. + + The gravitropism index quantifies the curviness of the root's growth. A higher + gravitropism index indicates a curvier root (less responsive to gravity), while a + lower index indicates a straighter root (more responsive to gravity). 
The index is + computed as the difference between the maximum primary root length and straight-line + distance from the base to the tip of the primary root, normalized by the root length. + + Args: + primary_length: Maximum length of the primary root. Used if `pts` is not + provided. + primary_base_tip_dist: The straight-line distance from the base to the tip of + the primary root. Used if `pts` is not provided. + pts: Landmarks of the primary root of shape `(instances, nodes, 2)`. If + provided, `primary_length` and `primary_base_tip_dist` are ignored. + + Returns: + float: Gravitropism index of the primary root, quantifying its curviness. + """ + # Use provided scalar values if available + if primary_length is not None and primary_base_tip_dist is not None: + max_primary_length = primary_length + max_base_tip_distance = primary_base_tip_dist + + # Use provided pts array to compute required values if available + elif pts is not None: + if np.isnan(pts).all(): + return np.nan + primary_length_max = get_root_lengths_max(pts=pts) + primary_base_tip_dist = get_base_tip_dist(pts=pts) + max_primary_length = np.nanmax(primary_length_max) + max_base_tip_distance = np.nanmax(primary_base_tip_dist) + + else: + raise ValueError( + "Either both primary_length and primary_base_tip_dist, or pts" + "must be provided." 
+ ) + + # Check for invalid values (NaN or zero lengths) + if ( + np.isnan(max_primary_length) + or np.isnan(max_base_tip_distance) + or max_primary_length == 0 + ): + return np.nan + + # Calculate and return gravitropism index + grav_index = (max_primary_length - max_base_tip_distance) / max_primary_length + return grav_index diff --git a/sleap_roots/networklength.py b/sleap_roots/networklength.py index 1bda1dc..c089ec3 100644 --- a/sleap_roots/networklength.py +++ b/sleap_roots/networklength.py @@ -2,9 +2,8 @@ import numpy as np from shapely import LineString, Polygon -from sleap_roots.bases import get_root_lengths -from sleap_roots.convhull import get_convhull_features -from typing import Tuple +from sleap_roots.lengths import get_root_lengths, get_max_length_pts +from typing import Optional, Tuple, Union def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]: @@ -36,17 +35,22 @@ def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]: return bbox -def get_network_width_depth_ratio(pts: np.ndarray) -> float: +def get_network_width_depth_ratio( + pts: Union[np.ndarray, Tuple[float, float, float, float]] +) -> float: """Return width to depth ratio of bounding box for root network. Args: - pts: Root landmarks as array of shape (..., 2). + pts: Root landmarks as array of shape (..., 2) or boundary box. Returns: Float of bounding box width to depth ratio of root network. 
""" # get the bounding box - bbox = get_bbox(pts) + if type(pts) == tuple: + bbox = pts + else: + bbox = get_bbox(pts) width, height = bbox[2], bbox[3] if width > 0 and height > 0: ratio = width / height @@ -55,32 +59,62 @@ def get_network_width_depth_ratio(pts: np.ndarray) -> float: return np.nan -def get_network_solidity( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - pts_all_array: np.ndarray, +def get_network_length( + primary_length: float, + lateral_lengths: Union[float, np.ndarray], monocots: bool = False, ) -> float: - """Return the total network length divided by the network convex area. + """Return the total root network length given primary and lateral root lengths. Args: - primary_pts: primary root landmarks as array of shape (..., 2). - lateral_pts: lateral root landmarks as array of shape (..., 2). - pts_all_array: primary and lateral root landmarks. - monocots: a boolean value, where True is rice. + primary_length: Primary root length. + lateral_lengths: Either a float representing the length of a single lateral + root or an array of lateral root lengths with shape `(instances,)`. + monocots: A boolean value, where True is rice. Returns: - Float of the total network length divided by the network convex area. + Total length of root network. """ - # get the total network length - network_length = get_network_length(primary_pts, lateral_pts, monocots) + # Ensure primary_length is a scalar + if not isinstance(primary_length, (float, np.float64)): + raise ValueError("Input primary_length must be a scalar value.") + + # Ensure lateral_lengths is either a scalar or has the correct shape + if not ( + isinstance(lateral_lengths, (float, np.float64)) or lateral_lengths.ndim == 1 + ): + raise ValueError( + "Input lateral_lengths must be a scalar or have shape (instances,)." 
+ ) + + # Calculate the total lateral root length using np.nansum + total_lateral_length = np.nansum(lateral_lengths) + + if monocots: + length = total_lateral_length + else: + # Calculate the total root network length using np.nansum so the total length + # will not be NaN if one of primary or lateral lengths are NaN + length = np.nansum([primary_length, total_lateral_length]) + + return length - # get the convex hull area - convhull_features = get_convhull_features(pts_all_array) - conv_area = convhull_features[1] - if network_length > 0 and conv_area > 0: - ratio = network_length / conv_area +def get_network_solidity( + network_length: float, + chull_area: float, +) -> float: + """Return the total network length divided by the network convex area. + + Args: + network_length: Total root length of network. + chull_area: Convex hull area. + + Returns: + Float of the total network length divided by the network convex area. + """ + if network_length > 0 and chull_area > 0: + ratio = network_length / chull_area return ratio else: return np.nan @@ -89,133 +123,137 @@ def get_network_solidity( def get_network_distribution( primary_pts: np.ndarray, lateral_pts: np.ndarray, - pts_all_array: np.ndarray, + bounding_box: Tuple[float, float, float, float], fraction: float = 2 / 3, monocots: bool = False, ) -> float: """Return the root length in the lower fraction of the plant. Args: - primary_pts: primary root landmarks as array of shape (..., 2). - lateral_pts: lateral root landmarks as array of shape (..., 2). - pts_all_array: primary and lateral root landmarks. - fraction: the network length found in the lower fration value of the network. - monocots: a boolean value, where True is rice. + primary_pts: Array of primary root landmarks. Can have shape `(nodes, 2)` or + `(1, nodes, 2)`. + lateral_pts: Array of lateral root landmarks with shape `(instances, nodes, 2)`. + bounding_box: Tuple in the form `(left_x, top_y, width, height)`. + fraction: Lower fraction value. 
Defaults to 2/3. + monocots: A boolean value, where True indicates rice. Defaults to False. Returns: - Float of the root network length in the lower fraction of the plant. + Root network length in the lower fraction of the plant. """ - # get the bounding box - bbox = get_bbox(pts_all_array) - left_x, top_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3] + # Input validation + if primary_pts.ndim not in [2, 3]: + raise ValueError( + "primary_pts should have a shape of `(nodes, 2)` or `(1, nodes, 2)`." + ) + + if primary_pts.ndim == 2 and primary_pts.shape[-1] != 2: + raise ValueError("primary_pts should have a shape of `(nodes, 2)`.") + + if primary_pts.ndim == 3 and primary_pts.shape[-1] != 2: + raise ValueError("primary_pts should have a shape of `(1, nodes, 2)`.") - # get the bounding box of the lower fraction - lower_height = bbox[3] * fraction + if lateral_pts.ndim != 3 or lateral_pts.shape[-1] != 2: + raise ValueError("lateral_pts should have a shape of `(instances, nodes, 2)`.") + + if len(bounding_box) != 4: + raise ValueError( + "bounding_box should be in the form `(left_x, top_y, width, height)`." + ) + + # Make sure the longest primary root is used + if primary_pts.ndim == 3: + primary_pts = get_max_length_pts(primary_pts) # shape is (nodes, 2) + + # Make primary_pts and lateral_pts have the same dimension of 3 + primary_pts = ( + primary_pts[np.newaxis, :, :] if primary_pts.ndim == 2 else primary_pts + ) + + # Filter out NaN values + primary_pts = [root[~np.isnan(root).any(axis=1)] for root in primary_pts] + lateral_pts = [root[~np.isnan(root).any(axis=1)] for root in lateral_pts] + + # Collate root points. 
+ all_roots = primary_pts + lateral_pts if not monocots else lateral_pts + + # Get the vertices of the bounding box + left_x, top_y, width, height = bounding_box + + # Calculate the bounding box of the lower fraction + lower_height = height * fraction if np.isnan(lower_height): return np.nan - lower_bbox = (bbox[0], bbox[1] + (bbox[3] - lower_height), bbox[2], lower_height) - # convert bounding box to polygon - polygon = Polygon( + # Convert lower bounding box to polygon + # Vertices are in counter-clockwise order + lower_box = Polygon( [ - [bbox[0], bbox[1] + (bbox[3] - lower_height)], - [bbox[0], bbox[1] + height], - [bbox[0] + width, bbox[1] + height], - [bbox[0] + width, bbox[1] + (bbox[3] - lower_height)], + [left_x, top_y + (height - lower_height)], + [left_x, top_y + height], + [left_x + width, top_y + height], + [left_x + width, top_y + (height - lower_height)], ] ) - # filter out the nan nodes - if monocots: - points = list(primary_pts) - else: - points = list(primary_pts) + list(lateral_pts) - pts_nnan = [] - for j in range(len(points)): - pts_j = points[j][~np.isnan(points[j]).any(axis=1)] - pts_nnan.append(pts_j) - - # get length of lines within the lower bounding box - root_length = 0 - for j in range(len(points)): - # filter out nan nodes - pts_j = points[j][~np.isnan(points[j]).any(axis=1)] - if pts_j.shape[0] > 1: - linestring = LineString(pts_j) - if linestring.intersection(polygon): - intersection = linestring.intersection(polygon) - root_length += ( - intersection.length if ~np.isnan(intersection.length) else 0 - ) - - return root_length + # Calculate length of roots within the lower bounding box + network_length = 0 + for root in all_roots: + root_poly = LineString(root) + lower_intersection = root_poly.intersection(lower_box) + root_length = lower_intersection.length + network_length += root_length if ~np.isnan(root_length) else 0 - -def get_network_length( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - monocots: bool = False, -) -> 
float: - """Return all primary or lateral root length one frame. - - Args: - primary_pts: primary root landmarks as array of shape (..., 2). - lateral_pts: lateral root landmarks as array of shape (..., 2). - monocots: a boolean value, where True is rice. - - Returns: - Float of primary or lateral root network length. - """ - if ( - np.sum(get_root_lengths(primary_pts)) > 0 - or np.sum(get_root_lengths(lateral_pts)) > 0 - ): - if monocots: - length = np.nansum(get_root_lengths(primary_pts)) - else: - length = np.nansum(get_root_lengths(primary_pts)) + np.nansum( - get_root_lengths(lateral_pts) - ) - return length - else: - return 0 + return network_length def get_network_distribution_ratio( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - pts_all_array: np.ndarray, + primary_length: float, + lateral_lengths: Union[float, np.ndarray], + network_length_lower: float, fraction: float = 2 / 3, monocots: bool = False, ) -> float: """Return ratio of the root length in the lower fraction over all root length. Args: - primary_pts: primary root landmarks as array of shape (..., 2). - lateral_pts: lateral root landmarks as array of shape (..., 2). - pts_all_array: primary and lateral root landmarks. - fraction: the network length found in the lower fration value of the network. - monocots: a boolean value, where True is rice. + primary_length: Primary root length. + lateral_lengths: Lateral root lengths. Can be a single float (for one root) + or an array of floats (for multiple roots). + network_length_lower: The root length in the lower network. + fraction: The fraction of the network considered as 'lower'. Defaults to 2/3. + monocots: A boolean value, where True indicates rice. Defaults to False. Returns: Float of ratio of the root network length in the lower fraction of the plant over all root length. 
""" - if ( - np.sum(get_root_lengths(primary_pts)) + np.sum(get_root_lengths(lateral_pts)) - > 0 - ): - if monocots: - ratio = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction, monocots - ) / (np.sum(get_root_lengths(primary_pts))) - else: - ratio = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction - ) / ( - np.sum(get_root_lengths(primary_pts)) - + np.sum(get_root_lengths(lateral_pts)) - ) - return ratio + # Ensure primary_length is a scalar + if not isinstance(primary_length, (float, np.float64)): + raise ValueError("Input primary_length must be a scalar value.") + + # Ensure lateral_lengths is either a scalar or a 1-dimensional array + if not isinstance(lateral_lengths, (float, np.float64, np.ndarray)): + raise ValueError( + "Input lateral_lengths must be a scalar or a 1-dimensional array." + ) + + # If lateral_lengths is an ndarray, it must be one-dimensional + if isinstance(lateral_lengths, np.ndarray) and lateral_lengths.ndim != 1: + raise ValueError("Input lateral_lengths array must have shape (instances,).") + + # Ensure network_length_lower is a scalar + if not isinstance(network_length_lower, (float, np.float64)): + raise ValueError("Input network_length_lower must be a scalar value.") + + # Calculate the total lateral root length + total_lateral_length = np.nansum(lateral_lengths) + + # Determine total root length based on monocots flag + if monocots: + total_root_length = total_lateral_length else: - return np.nan + total_root_length = np.nansum([primary_length, total_lateral_length]) + + # Calculate the ratio + ratio = network_length_lower / total_root_length + return ratio diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 1beccc5..2620939 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -1,115 +1,52 @@ -"""Get points function.""" +"""Get traits related to the points.""" import numpy as np -from sleap_roots.bases import get_root_lengths -from sleap_roots.series 
import Series -from typing import List -def get_pt_ind(pts: np.ndarray, proximal: bool = True) -> np.ndarray: - """Find proximal/distal point index. +def get_all_pts_array( + primary_max_length_pts: np.ndarray, lateral_pts: np.ndarray, monocots: bool = False +) -> np.ndarray: + """Get all landmark points within a given frame as a flat array of coordinates. Args: - pts: Numpy array of points of shape (instances, point, 2). - proximal: Boolean value, where True is proximal (default), False is distal. + primary_max_length_pts: Points of the primary root with maximum length of shape + `(nodes, 2)`. + lateral_pts: Lateral root points of shape `(instances, nodes, 2)`. + monocots: If False (default), returns a combined array of primary and lateral + root points. If True, returns only lateral root points. Returns: - An array of shape (instances,) of proximal or distal point index. + A 2D array of shape (n_points, 2), containing the coordinates of all extracted + points. """ - pt_ind = [] - for i in range(pts.shape[0]): - ind = 1 if proximal else pts.shape[1] - 1 # set initial proximal/distal point - while np.isnan(pts[i, ind]).any(): - ind += 1 if proximal else -1 - if (ind == pts.shape[1] and proximal) or (ind == 0 and not proximal): - break - pt_ind.append(ind) - return pt_ind - + # Check if the input arrays have the right number of dimensions + if primary_max_length_pts.ndim != 2 or lateral_pts.ndim != 3: + raise ValueError( + "Input arrays should have the correct number of dimensions:" + "primary_max_length_pts should be 2-dimensional and lateral_pts should be" + "3-dimensional." + ) -def get_primary_pts(plant: Series, frame: int) -> np.ndarray: - """Get primary root points. + # Check if the last dimension of the input arrays has size 2 + # (representing x and y coordinates) + if primary_max_length_pts.shape[-1] != 2 or lateral_pts.shape[-1] != 2: + raise ValueError( + "The last dimension of the input arrays should have size 2, representing x" + "and y coordinates." 
+ ) - Args: - plant: Series object representing a plant image series. - frame: frame index + # Flatten the arrays to 2D + primary_max_length_pts = primary_max_length_pts.reshape(-1, 2) + lateral_pts = lateral_pts.reshape(-1, 2) - Return: - An array of primary root points of shape (1, n_points, 2). - If more than one primary root is present, the longest will be used. - """ - # get the primary root points, if >1 primary roots, return the longest primary root - pts_pr = plant.get_primary_points(frame_idx=frame) - if len(pts_pr) == 0: - return pts_pr + # Combine points + if monocots: + pts_all_array = lateral_pts else: - max_length_idx = np.nanargmax(get_root_lengths(pts_pr)) - pts_pr = pts_pr[np.newaxis, max_length_idx] - return pts_pr - - -def get_lateral_pts(plant: Series, frame: int) -> np.ndarray: - """Get lateral root points. - - Args: - plant: Series object representing a plant image series. - frame: frame index - - Return: - An array of primary root points in shape (instance, point, 2) - """ - pts_lr = plant.get_lateral_points(frame_idx=frame) - return pts_lr - - -def get_all_pts(plant: Series, frame: int, monocots: bool = False) -> List[np.ndarray]: - """Get all points within a frame. + # Check if the data types of the arrays are compatible + if primary_max_length_pts.dtype != lateral_pts.dtype: + raise ValueError("Input arrays should have the same data type.") - Args: - plant: Series object representing a plant image series. - frame: frame index - rice: boolean value, where True is rice frame - monocots: If False (the default), returns primary and lateral points - combined. If True, only lateral root points will be returned. This is useful for - monocot species such as rice. - - Return: - A list of numpy arrays containing sets of points of shape - (n_instances, n_points, 2). 
- """ - # get primary and lateral root points - pts_pr = get_primary_pts(plant, frame).tolist() - pts_lr = get_lateral_pts(plant, frame).tolist() - - pts_all = pts_lr if monocots else pts_pr + pts_lr - - return pts_all - - -def get_all_pts_array(plant: Series, frame: int, monocots: bool = False) -> np.ndarray: - """Get all points within a frame as a flat array of coordinates. - - Args: - plant: Series object representing a plant image series. - frame: frame index - monocots: If False (the default), returns primary and lateral points - combined. If True, only lateral root points will be returned. This is useful for - monocot species such as rice. - - Return: - An array of all points (primary and optionally lateral) as an array of shape - (n_points, 2). - """ - # get primary and lateral root points - pts_pr = get_primary_pts(plant, frame) - pts_lr = get_lateral_pts(plant, frame) - - pts_all_array = ( - pts_lr.reshape(-1, 2) - if monocots - else np.concatenate( - (np.array(pts_pr).reshape(-1, 2), np.array(pts_lr).reshape(-1, 2)), axis=0 - ) - ) + pts_all_array = np.concatenate((primary_max_length_pts, lateral_pts), axis=0) return pts_all_array diff --git a/sleap_roots/scanline.py b/sleap_roots/scanline.py index 83a3b76..1d39598 100644 --- a/sleap_roots/scanline.py +++ b/sleap_roots/scanline.py @@ -2,121 +2,109 @@ import numpy as np import math -from shapely import LineString, Point def count_scanline_intersections( primary_pts: np.ndarray, lateral_pts: np.ndarray, - depth: int = 1080, + height: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False, ) -> np.ndarray: - """Get intersection points of roots and scan lines. + """Count intersections of roots with a series of horizontal scanlines. + + This function calculates the number of intersections between the provided + primary and lateral root points and a set of horizontal scanlines. The scanlines + are equally spaced across the specified height. 
Args: - primary_pts: Numpy array of primary points of shape (instances, nodes, 2). - lateral_pts: Numpy array of lateral points of shape (instances, nodes, 2). - depth: the depth of cylinder, or number of rows of the image. - width: the width of cylinder, or number of columns of the image. - n_line: number of scan lines. - monocots: whether True: only lateral roots (e.g., rice), or False: dicots + primary_pts: Array of primary root landmarks of shape `(nodes, 2)`. + Will be reshaped internally to `(1, nodes, 2)`. + lateral_pts: Array of lateral root landmarks with shape + `(instances, nodes, 2)`. + height: The height of the image or cylinder. Defaults to 1080. + width: The width of the image or cylinder. Defaults to 2048. + n_line: Number of scanlines to use. Defaults to 50. + monocots: If `True`, only uses lateral roots (e.g., for rice). + If `False`, uses both primary and lateral roots (e.g., for dicots). + Defaults to `False`. Returns: - An array with shape of (#Nline,) of intersection numbers of each scan line. + An array with shape `(n_line,)` representing the number of intersections + of roots with each scanline. """ - # connect the points to lines using shapely - if monocots: - points = list(lateral_pts) - else: - points = list(primary_pts) + list(lateral_pts) + # Input validation + if primary_pts.ndim != 2 or primary_pts.shape[-1] != 2: + raise ValueError("primary_pts should have a shape of `(nodes, 2)`.") + + if lateral_pts.ndim != 3 or lateral_pts.shape[-1] != 2: + raise ValueError("lateral_pts should have a shape of `(instances, nodes, 2)`.") - # calculate interval between two scan lines - n_interval = math.ceil(depth / (n_line - 1)) + # Reshape primary_pts to have three dimensions + primary_pts = primary_pts[np.newaxis, :, :] - intersection = [] + # Collate root points. 
+ all_roots = list(primary_pts) + list(lateral_pts) if not monocots else lateral_pts + # Calculate the interval between two scanlines + interval = height / (n_line - 1) + + intersections = [] + + # Iterate over scanlines for i in range(n_line): - horizontal_line_y = n_interval * (i + 1) - intersection_line = 0 - - for j in range(len(points)): - intersection_counts_root = 0 - # filter out nan nodes - pts_j = np.array(points[j])[~np.isnan(points[j]).any(axis=1)] - current_root = 0 - if pts_j.shape[0] > 1: - for k in range(len(pts_j) - 1): - x1, y1 = pts_j[k] - x2, y2 = pts_j[k + 1] - if (y1 >= horizontal_line_y and y2 < horizontal_line_y) or ( - y1 < horizontal_line_y and y2 >= horizontal_line_y - ): - current_root += 1 - intersection_counts_root += current_root - intersection_line += intersection_counts_root - intersection.append(intersection_line) - return np.array(intersection) - - -def get_scanline_first_ind( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - depth: int = 1080, - width: int = 2048, - n_line: int = 50, - monocots: bool = False, -): + y_coord = interval * i + line_intersections = 0 + + for root_points in all_roots: + # Remove NaN values + valid_points = root_points[(~np.isnan(root_points)).any(axis=1)] + + if len(valid_points) > 1: + for j in range(len(valid_points) - 1): + y1 = valid_points[j][1] + y2 = valid_points[j + 1][1] + + if (y1 >= y_coord >= y2) or (y2 >= y_coord >= y1): + line_intersections += 1 + + intersections.append(line_intersections) + + return np.array(intersections) + + +def get_scanline_first_ind(scanline_intersection_counts: np.ndarray): """Get the index of count_scanline_interaction for the first interaction. Args: - primary_pts: Numpy array of primary points of shape (instances, nodes, 2). - lateral_pts: Numpy array of lateral points of shape (instances, nodes, 2). - depth: the depth of cylinder, or number of rows of the image. - width: the width of cylinder, or number of columns of the image. 
- n_line: number of scan lines, np.nan for no interaction. - monocots: whether True: only lateral roots (e.g., rice), or False: dicots. + scanline_intersection_counts: An array with shape of `(#Nline,)` of intersection + numbers of each scan line. Return: Scalar of count_scanline_interaction index for the first interaction. """ - count_scanline_interaction = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots - ) - if np.where((count_scanline_interaction > 0))[0].shape[0] > 0: - scanline_first_ind = np.where((count_scanline_interaction > 0))[0][0] + # get the first scanline index using scanline_intersection_counts + if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0: + scanline_first_ind = np.where((scanline_intersection_counts > 0))[0][0] return scanline_first_ind else: return np.nan -def get_scanline_last_ind( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - depth: int = 1080, - width: int = 2048, - n_line: int = 50, - monocots: bool = False, -): +def get_scanline_last_ind(scanline_intersection_counts: np.ndarray): """Get the index of count_scanline_interaction for the last interaction. Args: - primary_pts: Numpy array of primary points of shape (instances, nodes, 2). - lateral_pts: Numpy array of lateral points of shape (instances, nodes, 2). - depth: the depth of cylinder, or number of rows of the image. - width: the width of cylinder, or number of columns of the image. - n_line: number of scan lines, np.nan for no interaction. - monocots: whether True: only lateral roots (e.g., rice), or False: dicots. + scanline_intersection_counts: An array with shape of `(#Nline,)` of intersection + numbers of each scan line. Return: Scalar of count_scanline_interaction index for the last interaction. 
""" - count_scanline_interaction = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots - ) - if np.where((count_scanline_interaction > 0))[0].shape[0] > 0: - scanline_last_ind = np.where((count_scanline_interaction > 0))[0][-1] + # get the last scanline index using scanline_intersection_counts + if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0: + scanline_last_ind = np.where((scanline_intersection_counts > 0))[0][-1] return scanline_last_ind else: return np.nan diff --git a/sleap_roots/series.py b/sleap_roots/series.py index a434fad..45ac8ad 100644 --- a/sleap_roots/series.py +++ b/sleap_roots/series.py @@ -17,13 +17,15 @@ class Series: Attributes: h5_path: Path to the HDF5-formatted image series. - primary_labels: A `sleap.Labels` corresponding to the primary root predictions. - lateral_labels: A `sleap.Labels` corresponding to the lateral root predictions. + primary_labels: A `sio.Labels` corresponding to the primary root predictions. + lateral_labels: A `sio.Labels` corresponding to the lateral root predictions. + video: A `sio.Video` corresponding to the image series. 
""" h5_path: Optional[str] = None primary_labels: Optional[sio.Labels] = None lateral_labels: Optional[sio.Labels] = None + video: Optional[sio.Video] = None @classmethod def load( @@ -52,6 +54,7 @@ def load( h5_path, primary_labels=sio.load_slp(primary_path), lateral_labels=sio.load_slp(lateral_path), + video=sio.Video.from_filename(h5_path), ) @property @@ -59,11 +62,6 @@ def series_name(self) -> str: """Name of the series derived from the HDF5 filename.""" return Path(self.h5_path).name.split(".")[0] - @property - def video(self) -> sio.Video: - """The `sleap.Video` corresponding to the image series.""" - return self.primary_labels.video - def __len__(self) -> int: """Length of the series (number of images).""" return len(self.video) @@ -84,8 +82,8 @@ def get_frame(self, frame_idx: int) -> Tuple[sio.LabeledFrame, sio.LabeledFrame] frame_idx: Integer frame number. Returns: - Tuple of (primary_lf, lateral_lf) corresponding to the - `sleap.LabeledFrame` from each set of predictions on the same frame. + Tuple of (primary_lf, lateral_lf) corresponding to the `sio.LabeledFrame` + from each set of predictions on the same frame. """ lf_primary = self.primary_labels.find( self.primary_labels.video, frame_idx, return_new=True @@ -112,8 +110,10 @@ def get_primary_points(self, frame_idx: int) -> np.ndarray: """Get primary root points. Args: - frame_idx: frame index to get primary root points in shape (# instance, - # node, 2) + frame_idx: Frame index. + + Returns: + Primary root points as array of shape `(n_instances, n_nodes, 2)`. """ primary_lf, lateral_lf = self.get_frame(frame_idx) gt_instances_pr = primary_lf.user_instances + primary_lf.unused_predictions @@ -127,8 +127,10 @@ def get_lateral_points(self, frame_idx: int) -> np.ndarray: """Get lateral root points. Args: - frame_idx: frame index to get lateral root points in shape (# instance, - # node, 2) + frame_idx: Frame index. + + Returns: + Lateral root points as array of shape `(n_instances, n_nodes, 2)`. 
""" primary_lf, lateral_lf = self.get_frame(frame_idx) gt_instances_lr = lateral_lf.user_instances + lateral_lf.unused_predictions diff --git a/sleap_roots/summary.py b/sleap_roots/summary.py index d596814..de9c890 100644 --- a/sleap_roots/summary.py +++ b/sleap_roots/summary.py @@ -1,39 +1,54 @@ """Get summary of the traits.""" import numpy as np -from typing import Tuple +from typing import Dict, Optional + +SUMMARY_SUFFIXES = ["min", "max", "mean", "median", "std", "p5", "p25", "p75", "p95"] def get_summary( - trait: np.ndarray, -) -> Tuple[float, float, float, float, float, float, float, float, float]: - """Get summary of traits. + X: np.ndarray, + prefix: Optional[str] = None, +) -> Dict[str, float]: + """Get summary of a vector of observations. Args: - traits: Vector of trait values as a numpy array of shape (n,). + X: Vector of values as a numpy array of shape `(n,)`. + prefix: Prefix of the variable name. If not `None`, this string will be prepended + to the key names of the returned dictionary. Returns: - A tuple of 9 scalar statistical summary measures: - min, max, mean, median, standard deviation - percentiles: 5, 25, 75, 95 + A dictionary of summary statistics of the input vector with keys: + "min", "max", "mean", "median", "std", "p5", "p25", "p75", "p95" + + If `prefix` was specified, the keys will be prefixed with the string. 
""" - trait_min = np.nanmin(trait) - trait_max = np.nanmax(trait) - trait_mean = np.nanmean(trait) - trait_median = np.nanmedian(trait) - trait_std = np.nanstd(trait) - trait_prc5 = np.nanpercentile(trait, 5) - trait_prc25 = np.nanpercentile(trait, 25) - trait_prc75 = np.nanpercentile(trait, 75) - trait_prc95 = np.nanpercentile(trait, 95) - return ( - trait_min, - trait_max, - trait_mean, - trait_median, - trait_std, - trait_prc5, - trait_prc25, - trait_prc75, - trait_prc95, - ) + if prefix is None: + prefix = "" + + X = np.atleast_1d(X) + + if len(X) == 0: + return { + f"{prefix}min": np.nan, + f"{prefix}max": np.nan, + f"{prefix}mean": np.nan, + f"{prefix}median": np.nan, + f"{prefix}std": np.nan, + f"{prefix}p5": np.nan, + f"{prefix}p25": np.nan, + f"{prefix}p75": np.nan, + f"{prefix}p95": np.nan, + } + else: + return { + f"{prefix}min": np.nanmin(X), + f"{prefix}max": np.nanmax(X), + f"{prefix}mean": np.nanmean(X), + f"{prefix}median": np.nanmedian(X), + f"{prefix}std": np.nanstd(X), + f"{prefix}p5": np.nanpercentile(X, 5), + f"{prefix}p25": np.nanpercentile(X, 25), + f"{prefix}p75": np.nanpercentile(X, 75), + f"{prefix}p95": np.nanpercentile(X, 95), + } diff --git a/sleap_roots/tips.py b/sleap_roots/tips.py index da91131..1602e22 100644 --- a/sleap_roots/tips.py +++ b/sleap_roots/tips.py @@ -3,61 +3,85 @@ import numpy as np -def get_tips(pts): - """Return tips (last node) from each lateral root. +def get_tips(pts: np.ndarray) -> np.ndarray: + """Return tips (last node) from each root. Args: - pts: Root landmarks as array of shape (instances, nodes, 2) + pts: Root landmarks as array of shape `(instances, nodes, 2)` or `(nodes, 2)`. Returns: - Array of tips (instances, (x, y)). - If there is no root, or the roots don't have tips an array of shape - (instances, 2) of NaNs will be returned. + Array of tips. If the input is `(nodes, 2)`, an array of shape `(2,)` will be + returned. 
If the input is `(instances, nodes, 2)`, an array of shape + `(instances, 2)` will be returned. If there is no root, or the roots don't have + tips, an array of shape `(instances, 2)` of NaNs will be returned. """ - # Get the last point of each instance. Shape is (instances, 2) - tip_pts = pts[:, -1] - return tip_pts - + # If the input has shape `(nodes, 2)`, reshape it for consistency + if pts.ndim == 2: + pts = pts[np.newaxis, ...] -def get_primary_depth(pts: np.ndarray) -> np.ndarray: - """Get primary root tip depth. + # Get the last point of each instance + tip_pts = pts[:, -1] # Shape is `(instances, 2)` - Args: - pts: primary root landmarks as array of shape (1, point, 2) + # If the input was `(nodes, 2)`, return an array of shape `(2,)` instead of `(1, 2)` + if tip_pts.shape[0] == 1: + return tip_pts[0] - Returns: - Primary root tip depth (location in y-axis). - """ - # get the last point of primary root, if invisible, return nan - if pts[:, -1].any() == np.nan: - return np.nan - else: - return pts[:, -1, 1] + return tip_pts def get_tip_xs(pts: np.ndarray) -> np.ndarray: - """Get x coordinations of tip points. + """Get x coordinates of tip points. Args: - pts: root landmarks as array of shape (instance, point, 2) + pts: Root landmarks as array of shape (instances, nodes, 2) or tips + (instances, 2). Return: - An array of tips in x axis (instance,). + An array of tip x-coordinates (instances,). """ - _tip_pts = get_tips(pts) + if pts.ndim not in (2, 3): + raise ValueError( + "Input array must be 2-dimensional (n_tips, 2) or " + "3-dimensional (n_roots, n_nodes, 2)." + ) + + if pts.ndim == 3: + _tip_pts = get_tips( + pts + ) # Assuming get_tips returns an array of shape (instance, 2) + else: + _tip_pts = pts + + if _tip_pts.ndim != 2 or _tip_pts.shape[1] != 2: + raise ValueError( + "Array of tip points must be 2-dimensional with shape (instances, 2)." 
+ ) + tip_xs = _tip_pts[:, 0] return tip_xs -def get_tip_ys(pts: np.ndarray) -> np.ndarray: - """Get y coordinations of tip points. +def get_tip_ys(tip_pts: np.ndarray) -> np.ndarray: + """Get y coordinates of tip points. Args: - pts: root landmarks as array of shape (instance, point, 2) + tip_pts: Root tips as array of shape `(instances, 2)` or `(2)` + when there is only one tip. Return: - An array of tips in y axis (instance,) + An array of the y-coordinates of tips (instances,). """ - _tip_pts = get_tips(pts) - tip_ys = _tip_pts[:, 1] + # If the input is a single number (float or integer), raise an error + if isinstance(tip_pts, (np.floating, float, np.integer, int)): + raise ValueError("Input must be an array of shape `(instances, 2)` or `(2, )`.") + + # Check for the 2D shape of the input array + if tip_pts.ndim == 1: + # If shape is `(2,)`, then reshape it to `(1, 2)` for consistency + tip_pts = tip_pts.reshape(1, 2) + elif tip_pts.ndim != 2: + raise ValueError("Input array must be of shape `(instances, 2)` or `(2, )`.") + + # At this point, `tip_pts` should be of shape `(instances, 2)`. 
+ tip_ys = tip_pts[:, 1] return tip_ys diff --git a/sleap_roots/trait_pipelines.py b/sleap_roots/trait_pipelines.py new file mode 100644 index 0000000..02a748e --- /dev/null +++ b/sleap_roots/trait_pipelines.py @@ -0,0 +1,899 @@ +"""Extract traits in a pipeline based on a trait graph.""" + +import warnings +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import attrs +import networkx as nx +import numpy as np +import pandas as pd + +from sleap_roots.angle import get_node_ind, get_root_angle +from sleap_roots.bases import ( + get_base_ct_density, + get_base_length, + get_base_length_ratio, + get_base_median_ratio, + get_base_tip_dist, + get_base_xs, + get_base_ys, + get_bases, + get_lateral_count, + get_root_pair_widths_projections, +) +from sleap_roots.convhull import ( + get_chull_area, + get_chull_line_lengths, + get_chull_max_height, + get_chull_max_width, + get_chull_perimeter, + get_convhull, +) +from sleap_roots.ellipse import ( + fit_ellipse, + get_ellipse_a, + get_ellipse_b, + get_ellipse_ratio, +) +from sleap_roots.lengths import get_grav_index, get_max_length_pts, get_root_lengths +from sleap_roots.networklength import ( + get_bbox, + get_network_distribution, + get_network_distribution_ratio, + get_network_length, + get_network_solidity, + get_network_width_depth_ratio, +) +from sleap_roots.points import get_all_pts_array +from sleap_roots.scanline import ( + count_scanline_intersections, + get_scanline_first_ind, + get_scanline_last_ind, +) +from sleap_roots.series import Series +from sleap_roots.summary import SUMMARY_SUFFIXES, get_summary +from sleap_roots.tips import get_tip_xs, get_tip_ys, get_tips + +warnings.filterwarnings( + "ignore", + message="invalid value encountered in intersection", + category=RuntimeWarning, + module="shapely", +) +warnings.filterwarnings( + "ignore", message="All-NaN slice encountered", category=RuntimeWarning +) +warnings.filterwarnings( + "ignore", message="All-NaN axis encountered", 
category=RuntimeWarning +) +warnings.filterwarnings( + "ignore", + message="Degrees of freedom <= 0 for slice.", + category=RuntimeWarning, + module="numpy", +) +warnings.filterwarnings( + "ignore", message="Mean of empty slice", category=RuntimeWarning +) +warnings.filterwarnings( + "ignore", + message="invalid value encountered in sqrt", + category=RuntimeWarning, + module="skimage", +) +warnings.filterwarnings( + "ignore", + message="invalid value encountered in double_scalars", + category=RuntimeWarning, +) +warnings.filterwarnings( + "ignore", + message="invalid value encountered in scalar divide", + category=RuntimeWarning, + module="ellipse", +) + + +@attrs.define +class TraitDef: + """Definition of how to compute a trait. + + Attributes: + name: Unique identifier for the trait. + fn: Function used to compute the trait's value. + input_traits: List of trait names that should be computed before the current + trait and are expected as input positional arguments to `fn`. + scalar: Indicates if the trait is scalar (has a dimension of 0 per frame). If + `True`, the trait is also listed in `SCALAR_TRAITS`. + include_in_csv: `True `indicates the trait should be included in downstream CSV + files. + kwargs: Additional keyword arguments to be passed to the `fn` function. These + arguments are not reused from previously computed traits. + description: String describing the trait for documentation purposes. 
+ + Notes: + The `fn` specified will be called with a pattern like: + + ``` + trait_def = TraitDef( + name="my_trait", + fn=compute_my_trait, + input_traits=["input_trait_1", "input_trait_2"], + scalar=True, + include_in_csv=True, + kwargs={"kwarg1": True} + ) + traits[trait_def.name] = trait_def.fn( + *[traits[input_trait] for input_trait in trait_def.input_traits], + **trait_def.kwargs + ) + ``` + + For this example, the last line is equivalent to: + + ``` + traits["my_trait"] = trait_def.fn( + traits["input_trait_1"], traits["input_trait_2"], + kwarg1=True + ) + ``` + """ + + name: str + fn: Callable + input_traits: List[str] + scalar: bool + include_in_csv: bool + kwargs: Dict[str, Any] = attrs.field(factory=dict) + description: Optional[str] = None + + +@attrs.define +class Pipeline: + """Pipeline for computing traits. + + Attributes: + traits: List of `TraitDef` objects. + trait_map: Dictionary mapping trait names to their definitions. + trait_computation_order: List of trait names in the order they should be + computed. + """ + + traits: List[TraitDef] = attrs.field(init=False) + trait_map: Dict[str, TraitDef] = attrs.field(init=False) + trait_computation_order: List[str] = attrs.field(init=False) + + def __attrs_post_init__(self): + """Build pipeline objects from traits list.""" + # Build list of trait definitions. + self.traits = self.define_traits() + + # Check that trait names are unique. + trait_names = [trait.name for trait in self.traits] + if len(trait_names) != len(set(trait_names)): + raise ValueError("Trait names must be unique.") + + # Map trait names to their definitions. + self.trait_map = {trait_def.name: trait_def for trait_def in self.traits} + + # Determine computation order by topologically sorting the nodes. 
+ self.trait_computation_order = self.get_computation_order() + + def define_traits(self) -> List[TraitDef]: + """Return list of `TraitDef` objects.""" + raise NotImplementedError + + def get_computation_order(self) -> List[str]: + """Determine computation order by topologically sorting the nodes. + + Returns: + A list of trait names in the order they should be computed. + """ + # Infer edges from trait map. + edges = [] + for trait_def in self.traits: + for input_trait in trait_def.input_traits: + edges.append((input_trait, trait_def.name)) + + # Build networkx graph from inferred edges. + G = nx.DiGraph() + G.add_edges_from(edges) + + # Determine computation order by topologically sorting the nodes. + trait_computation_order = list(nx.topological_sort(G)) + + return trait_computation_order + + @property + def summary_traits(self) -> List[str]: + """List of traits to include in the summary CSV.""" + return [ + trait.name + for trait in self.traits + if trait.include_in_csv and not trait.scalar + ] + + @property + def csv_traits(self) -> List[str]: + """List of frame-level traits to include in the CSV.""" + csv_traits = [] + for trait in self.traits: + if trait.include_in_csv: + if trait.scalar: + csv_traits.append(trait.name) + else: + csv_traits.extend( + [f"{trait.name}_{suffix}" for suffix in SUMMARY_SUFFIXES] + ) + return csv_traits + + def compute_frame_traits(self, traits: Dict[str, Any]) -> Dict[str, Any]: + """Compute traits based on the pipeline. + + Args: + traits: Dictionary of traits where keys are trait names and values are + the trait values. + + Returns: + A dictionary of computed traits. + """ + # Initialize traits container with initial data. + traits = traits.copy() + + # Compute traits! + for trait_name in self.trait_computation_order: + if trait_name in traits: + # Skip traits already computed. + continue + + # Get trait definition. + trait_def = self.trait_map[trait_name] + + # Compute trait based on trait definition. 
+ traits[trait_name] = trait_def.fn( + *[traits[input_trait] for input_trait in trait_def.input_traits], + **trait_def.kwargs, + ) + + return traits + + def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]: + """Return initial traits for a plant frame. + + Args: + plant: The plant `Series` object. + frame_idx: The index of the current frame. + + Returns: + A dictionary of initial traits. + + This is defined on a per-pipeline basis as different plant species will have + different initial points to be used as starting traits. + + Most commonly, this will be the primary and lateral root points for the + current frame. + """ + raise NotImplementedError + + def compute_plant_traits( + self, + plant: Series, + write_csv: bool = False, + csv_suffix: str = ".traits.csv", + return_non_scalar: bool = False, + ) -> pd.DataFrame: + """Compute traits for a plant. + + Args: + plant: The plant image series as a `Series` object. + write_csv: A boolean value. If True, it writes per plant detailed + CSVs with traits for every instance on every frame. + csv_suffix: If `write_csv` is `True`, a CSV file will be saved with the same + name as the plant's `{plant.series_name}{csv_suffix}`. + return_non_scalar: If `True`, return all non-scalar traits as well as the + summarized traits. + + Returns: + The computed traits as a pandas DataFrame. + """ + traits = [] + for frame in range(len(plant)): + # Get initial traits for the frame. + initial_traits = self.get_initial_frame_traits(plant, frame) + + # Compute traits via the frame-level pipeline. + frame_traits = self.compute_frame_traits(initial_traits) + + # Compute trait summaries. + for trait_name in self.summary_traits: + trait_summary = get_summary( + frame_traits[trait_name], prefix=f"{trait_name}_" + ) + frame_traits.update(trait_summary) + + # Add metadata. 
+ frame_traits["plant_name"] = plant.series_name + frame_traits["frame_idx"] = frame + traits.append(frame_traits) + traits = pd.DataFrame(traits) + + # Move metadata columns to the front. + plant_name = traits.pop("plant_name") + frame_idx = traits.pop("frame_idx") + traits = pd.concat([plant_name, frame_idx, traits], axis=1) + + if write_csv: + csv_name = Path(plant.h5_path).with_suffix(csv_suffix) + traits[["plant_name", "frame_idx"] + self.csv_traits].to_csv( + csv_name, index=False + ) + + if return_non_scalar: + return traits + else: + return traits[["plant_name", "frame_idx"] + self.csv_traits] + + def compute_batch_traits( + self, + plants: List[Series], + write_csv: bool = False, + csv_path: str = "traits.csv", + ) -> pd.DataFrame: + """Compute traits for a batch of plants. + + Args: + plants: List of `Series` objects. + write_csv: If `True`, write the computed traits to a CSV file. + csv_path: Path to write the CSV file to. + + Returns: + A pandas DataFrame of computed traits summarized over all frames of each + plant. The resulting dataframe will have a row for each plant and a column + for each plant-level summarized trait. + + Summarized traits are prefixed with the trait name and an underscore, + followed by the summary statistic. + """ + all_traits = [] + for plant in plants: + # Compute frame level traits for the plant. + plant_traits = self.compute_plant_traits(plant) + + # Summarize frame level traits. + plant_summary = {"plant_name": plant.series_name} + for trait_name in self.csv_traits: + trait_summary = get_summary( + plant_traits[trait_name], prefix=f"{trait_name}_" + ) + plant_summary.update(trait_summary) + all_traits.append(plant_summary) + + # Build dataframe from list of frame-level summaries. + all_traits = pd.DataFrame(all_traits) + + if write_csv: + all_traits.to_csv(csv_path, index=False) + return all_traits + + +@attrs.define +class DicotPipeline(Pipeline): + """Pipeline for computing traits for dicot plants. 
+ + Attributes: + img_height: Image height. + img_width: Image width. + root_width_tolerance: Difference in projection norm between right and left side. + n_scanlines: Number of scan lines, np.nan for no interaction. + network_fraction: Length found in the lower fraction value of the network. + """ + + img_height: int = 1080 + img_width: int = 2048 + root_width_tolerance: float = 0.02 + n_scanlines: int = 50 + network_fraction: float = 2 / 3 + + def define_traits(self) -> List[TraitDef]: + """Define the trait computation pipeline for dicot plants.""" + trait_definitions = [ + TraitDef( + name="primary_max_length_pts", + fn=get_max_length_pts, + input_traits=["primary_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Points of the primary root with maximum length.", + ), + TraitDef( + name="pts_all_array", + fn=get_all_pts_array, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": False}, + description="Landmark points within a given frame as a flat array" + "of coordinates.", + ), + TraitDef( + name="root_widths", + fn=get_root_pair_widths_projections, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=True, + kwargs={"tolerance": self.root_width_tolerance, "monocots": False}, + description="Return estimation of root width using bases of lateral " + "roots.", + ), + TraitDef( + name="lateral_count", + fn=get_lateral_count, + input_traits=["lateral_pts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Get the number of lateral roots.", + ), + TraitDef( + name="lateral_proximal_node_inds", + fn=get_node_ind, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"proximal": True}, + description="Get the indices of the proximal nodes of lateral roots.", + ), + TraitDef( + name="lateral_distal_node_inds", + fn=get_node_ind, + input_traits=["lateral_pts"], + scalar=False, + 
include_in_csv=False, + kwargs={"proximal": False}, + description="Get the indices of the distal nodes of lateral roots.", + ), + TraitDef( + name="lateral_lengths", + fn=get_root_lengths, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of lateral root lengths of shape `(instances,)`.", + ), + TraitDef( + name="lateral_base_pts", + fn=get_bases, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": False}, + description="Array of lateral bases `(instances, (x, y))`.", + ), + TraitDef( + name="lateral_tip_pts", + fn=get_tips, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Array of lateral tips `(instances, (x, y))`.", + ), + TraitDef( + name="scanline_intersection_counts", + fn=count_scanline_intersections, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=True, + kwargs={ + "height": self.img_height, + "width": self.img_width, + "n_line": self.n_scanlines, + "monocots": False, + }, + description="Array of intersections of each scanline `(n_scanlines,)`.", + ), + TraitDef( + name="lateral_angles_distal", + fn=get_root_angle, + input_traits=["lateral_pts", "lateral_distal_node_inds"], + scalar=False, + include_in_csv=True, + kwargs={"proximal": False, "base_ind": 0}, + description="Array of lateral distal angles in degrees `(instances,)`.", + ), + TraitDef( + name="lateral_angles_proximal", + fn=get_root_angle, + input_traits=["lateral_pts", "lateral_proximal_node_inds"], + scalar=False, + include_in_csv=True, + kwargs={"proximal": True, "base_ind": 0}, + description="Array of lateral proximal angles in degrees " + "`(instances,)`.", + ), + TraitDef( + name="network_solidity", + fn=get_network_solidity, + input_traits=["network_length", "chull_area"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the total network length divided by the network" + 
"convex area.", + ), + TraitDef( + name="ellipse", + fn=fit_ellipse, + input_traits=["pts_all_array"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Tuple of (a, b, ratio) containing the semi-major axis " + "length, semi-minor axis length, and the ratio of the major to minor " + "lengths.", + ), + TraitDef( + name="bounding_box", + fn=get_bbox, + input_traits=["pts_all_array"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Tuple of four parameters in bounding box.", + ), + TraitDef( + name="convex_hull", + fn=get_convhull, + input_traits=["pts_all_array"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Convex hull of the points.", + ), + TraitDef( + name="primary_proximal_node_ind", + fn=get_node_ind, + input_traits=["primary_max_length_pts"], + scalar=True, + include_in_csv=False, + kwargs={"proximal": True}, + description="Get the indices of the proximal nodes of primary roots.", + ), + TraitDef( + name="primary_angle_proximal", + fn=get_root_angle, + input_traits=["primary_max_length_pts", "primary_proximal_node_ind"], + scalar=True, + include_in_csv=True, + kwargs={"proximal": True, "base_ind": 0}, + description="Array of primary proximal angles in degrees " + "`(instances,)`.", + ), + TraitDef( + name="primary_distal_node_ind", + fn=get_node_ind, + input_traits=["primary_max_length_pts"], + scalar=True, + include_in_csv=False, + kwargs={"proximal": False}, + description="Get the indices of the distal nodes of primary roots.", + ), + TraitDef( + name="primary_angle_distal", + fn=get_root_angle, + input_traits=["primary_max_length_pts", "primary_distal_node_ind"], + scalar=True, + include_in_csv=True, + kwargs={"proximal": False, "base_ind": 0}, + description="Array of primary distal angles in degrees `(instances,)`.", + ), + TraitDef( + name="primary_length", + fn=get_root_lengths, + input_traits=["primary_max_length_pts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar 
of primary root length.", + ), + TraitDef( + name="primary_base_pt", + fn=get_bases, + input_traits=["primary_max_length_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": False}, + description="Primary root base point.", + ), + TraitDef( + name="primary_tip_pt", + fn=get_tips, + input_traits=["primary_max_length_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Primary root tip point.", + ), + TraitDef( + name="network_length_lower", + fn=get_network_distribution, + input_traits=["primary_max_length_pts", "lateral_pts", "bounding_box"], + scalar=True, + include_in_csv=True, + kwargs={"fraction": self.network_fraction, "monocots": False}, + description="Scalar of the root network length in the lower fraction " + "of the plant.", + ), + TraitDef( + name="lateral_base_xs", + fn=get_base_xs, + input_traits=["lateral_base_pts"], + scalar=False, + include_in_csv=True, + kwargs={"monocots": False}, + description="Array of the x-coordinates of lateral bases " + "`(instances,)`.", + ), + TraitDef( + name="lateral_base_ys", + fn=get_base_ys, + input_traits=["lateral_base_pts"], + scalar=False, + include_in_csv=True, + kwargs={"monocots": False}, + description="Array of the y-coordinates of lateral bases " + "`(instances,)`.", + ), + TraitDef( + name="base_ct_density", + fn=get_base_ct_density, + input_traits=["primary_length", "lateral_base_pts"], + scalar=True, + include_in_csv=True, + kwargs={"monocots": False}, + description="Scalar of base count density.", + ), + TraitDef( + name="lateral_tip_xs", + fn=get_tip_xs, + input_traits=["lateral_tip_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of the x-coordinates of lateral tips `(instance,)`.", + ), + TraitDef( + name="lateral_tip_ys", + fn=get_tip_ys, + input_traits=["lateral_tip_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of the y-coordinates of lateral tips `(instance,)`.", + ), + TraitDef( + 
name="network_distribution_ratio", + fn=get_network_distribution_ratio, + input_traits=[ + "primary_length", + "lateral_lengths", + "network_length_lower", + ], + scalar=True, + include_in_csv=True, + kwargs={"fraction": self.network_fraction, "monocots": False}, + description="Scalar of ratio of the root network length in the lower " + "fraction of the plant over all root length.", + ), + TraitDef( + name="network_length", + fn=get_network_length, + input_traits=["primary_length", "lateral_lengths"], + scalar=True, + include_in_csv=False, + kwargs={"monocots": False}, + description="Scalar of all roots network length.", + ), + TraitDef( + name="primary_base_pt_y", + fn=get_base_ys, + input_traits=["primary_base_pt"], + scalar=True, + include_in_csv=False, + kwargs={"monocots": False}, + description="Y-coordinate of the primary root base node.", + ), + TraitDef( + name="primary_tip_pt_y", + fn=get_tip_ys, + input_traits=["primary_tip_pt"], + scalar=True, + include_in_csv=False, + kwargs={}, + description="Y-coordinate of the primary root tip node.", + ), + TraitDef( + name="ellipse_a", + fn=get_ellipse_a, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of semi-major axis length.", + ), + TraitDef( + name="ellipse_b", + fn=get_ellipse_b, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of semi-minor axis length.", + ), + TraitDef( + name="network_width_depth_ratio", + fn=get_network_width_depth_ratio, + input_traits=["bounding_box"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of bounding box width to depth ratio of root " + "network.", + ), + TraitDef( + name="chull_perimeter", + fn=get_chull_perimeter, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull perimeter.", + ), + TraitDef( + name="chull_area", + fn=get_chull_area, + input_traits=["convex_hull"], + 
scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull area.", + ), + TraitDef( + name="chull_max_width", + fn=get_chull_max_width, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull maximum width.", + ), + TraitDef( + name="chull_max_height", + fn=get_chull_max_height, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull maximum height.", + ), + TraitDef( + name="chull_line_lengths", + fn=get_chull_line_lengths, + input_traits=["convex_hull"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of line lengths connecting any two vertices on the" + "convex hull.", + ), + TraitDef( + name="base_length", + fn=get_base_length, + input_traits=["lateral_base_ys"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the distance between the top and deepest base" + "y-coordinates.", + ), + TraitDef( + name="base_median_ratio", + fn=get_base_median_ratio, + input_traits=["lateral_base_ys", "primary_tip_pt_y"], + scalar=True, + include_in_csv=True, + kwargs={"monocots": False}, + description="Scalar of base median ratio.", + ), + TraitDef( + name="grav_index", + fn=get_grav_index, + input_traits=["primary_length", "primary_base_tip_dist"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of primary root gravity index.", + ), + TraitDef( + name="base_length_ratio", + fn=get_base_length_ratio, + input_traits=["primary_length", "base_length"], + scalar=True, + include_in_csv=True, + kwargs={"monocots": False}, + description="Scalar of base length ratio.", + ), + TraitDef( + name="primary_base_tip_dist", + fn=get_base_tip_dist, + input_traits=["primary_base_pt", "primary_tip_pt"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of distance from primary root base to tip.", + ), + TraitDef( + name="ellipse_ratio", + 
fn=get_ellipse_ratio, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of ratio of the minor to major lengths.", + ), + TraitDef( + name="scanline_last_ind", + fn=get_scanline_last_ind, + input_traits=["scanline_intersection_counts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of count_scanline_interaction index for the last" + "interaction.", + ), + TraitDef( + name="scanline_first_ind", + fn=get_scanline_first_ind, + input_traits=["scanline_intersection_counts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of count_scanline_interaction index for the first" + "interaction.", + ), + ] + + return trait_definitions + + def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]: + """Return initial traits for a plant frame. + + Args: + plant: The plant `Series` object. + frame_idx: The index of the current frame. + + Returns: + A dictionary of initial traits with keys: + - "primary_pts": Array of primary root points. + - "lateral_pts": Array of lateral root points. + """ + # Get the root instances. + primary, lateral = plant[frame_idx] + gt_instances_pr = primary.user_instances + primary.unused_predictions + gt_instances_lr = lateral.user_instances + lateral.unused_predictions + + # Convert the instances to numpy arrays. 
+ if len(gt_instances_lr) == 0: + lateral_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + else: + lateral_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) + + if len(gt_instances_pr) == 0: + primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + else: + primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) + + return {"primary_pts": primary_pts, "lateral_pts": lateral_pts} diff --git a/sleap_roots/traitsgraph.py b/sleap_roots/traitsgraph.py deleted file mode 100644 index cf2d9d3..0000000 --- a/sleap_roots/traitsgraph.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Create traits graph and get destinations.""" - -import networkx as nx - - -def get_traits_graph(): - """Get traits graph using networkx. - - Args: - None - - Return: - Destination nodes. - """ - G = nx.DiGraph() - edge_list = [ - ("pts", "primary_pts"), - ("pts", "lateral_pts"), - ("primary_pts", "primary_base_pt"), - ("primary_pts", "primary_angle_proximal"), - ("primary_pts", "primary_angle_distal"), - ("primary_pts", "primary_length"), - ("primary_base_pt", "primary_base_pt_y"), - ("primary_pts", "primary_tip_pt"), - ("primary_base_pt_y", "primary_base_tip_dist"), - ("primary_tip_pt_y", "primary_base_tip_dist"), - ("primary_tip_pt_y", "primary_depth"), - ("lateral_pts", "lateral_count"), - ("primary_tip_pt", "primary_tip_pt_y"), - ("primary_length_max", "grav_index"), - ("primary_base_tip_dist", "grav_index"), - ("lateral_base_ys", "base_length"), - ("base_length", "base_length_ratio"), - ("primary_length_max", "base_length_ratio"), - ("lateral_base_ys", "base_median_ratio"), - ("primary_tip_pt_y", "base_median_ratio"), - ("lateral_base_pts", "base_ct_density"), - ("primary_length", "base_ct_density"), - ("convex_hull", "chull_perimeter"), - ("convex_hull", "chull_area"), - ("convex_hull", "chull_max_width"), - ("convex_hull", "chull_max_height"), - ("primary_pts", "ellipse"), - ("lateral_pts", "ellipse"), - ("ellipse", "ellipse_a"), - ("ellipse", 
"ellipse_b"), - ("ellipse_a", "ellipse_ratio"), - ("ellipse_b", "ellipse_ratio"), - ("chull_area", "network_solidity"), - ("lateral_lengths", "network_solidity"), - ("primary_length", "network_solidity"), - ("lateral_pts", "bounding_box"), - ("primary_pts", "bounding_box"), - ("bounding_box", "network_width_depth_ratio"), - ("lateral_lengths", "network_distribution_ratio"), - ("primary_length", "network_distribution_ratio"), - ("primary_length", "network_length_lower"), - ("lateral_lengths", "network_length_lower"), - ("bounding_box", "network_distribution_ratio"), - ("bounding_box", "network_length_lower"), - ("scanline_intersection_counts", "scanline_last_ind"), - ("scanline_intersection_counts", "scanline_first_ind"), - ("lateral_pts", "lateral_angles_proximal"), - ("lateral_pts", "lateral_angles_distal"), - ("lateral_pts", "lateral_lengths"), - ("primary_pts", "stem_widths"), - ("lateral_pts", "lateral_base_pts"), - ("lateral_base_pts", "stem_widths"), - ("lateral_base_pts", "lateral_base_xs"), - ("lateral_base_pts", "lateral_base_ys"), - ("lateral_pts", "lateral_tip_pts"), - ("lateral_tip_pts", "lateral_tip_xs"), - ("lateral_tip_pts", "lateral_tip_ys"), - ("primary_pts", "convex_hull"), - ("lateral_pts", "convex_hull"), - ("convex_hull", "chull_line_lengths"), - ("primary_pts", "scanline_intersection_counts"), - ("lateral_pts", "scanline_intersection_counts"), - ] - - G.add_edges_from(edge_list) - dts = [dst for (src, dst) in list(nx.bfs_tree(G, "pts").edges())[2:]] - return dts diff --git a/tests/test_angle.py b/tests/test_angle.py index 7abf8d3..666e667 100644 --- a/tests/test_angle.py +++ b/tests/test_angle.py @@ -123,42 +123,56 @@ def test_get_node_ind(canola_h5): pts = primary.numpy() proximal = True node_ind = get_node_ind(pts, proximal) - np.testing.assert_array_equal(node_ind, [1]) + np.testing.assert_array_equal(node_ind, 1) # test get_node_ind function using root that without second node def test_get_node_ind_nan2(pts_nan2): proximal = True node_ind 
= get_node_ind(pts_nan2, proximal) - np.testing.assert_array_equal(node_ind, [2]) + np.testing.assert_array_equal(node_ind, 2) # test get_node_ind function using root that without second and third nodes def test_get_node_ind_nan3(pts_nan3): proximal = True node_ind = get_node_ind(pts_nan3, proximal) - np.testing.assert_array_equal(node_ind, [3]) + np.testing.assert_array_equal(node_ind, 0) # test get_node_ind function using two roots/instances def test_get_node_ind_nan32(pts_nan32): proximal = True node_ind = get_node_ind(pts_nan32, proximal) - np.testing.assert_array_equal(node_ind, [3, 2]) + np.testing.assert_array_equal(node_ind, [0, 2]) # test get_node_ind function using root that without last node def test_get_node_ind_nan6(pts_nan6): proximal = False node_ind = get_node_ind(pts_nan6, proximal) - np.testing.assert_array_equal(node_ind, [4]) + np.testing.assert_array_equal(node_ind, 4) # test get_node_ind function using root with all nan node def test_get_node_ind_nanall(pts_nanall): proximal = False node_ind = get_node_ind(pts_nanall, proximal) - np.testing.assert_array_equal(node_ind, [0]) + np.testing.assert_array_equal(node_ind, np.nan) + + +# test get_node_ind function using root with pts_nan32_5node +def test_get_node_ind_5node(pts_nan32_5node): + proximal = False + node_ind = get_node_ind(pts_nan32_5node, proximal) + np.testing.assert_array_equal(node_ind, [4, 4]) + + +# test get_node_ind function (proximal) using root with pts_nan32_5node +def test_get_node_ind_5node_proximal(pts_nan32_5node): + proximal = True + node_ind = get_node_ind(pts_nan32_5node, proximal) + np.testing.assert_array_equal(node_ind, [0, 2]) # test canola get_root_angle function (base node to distal node angle) @@ -169,8 +183,8 @@ def test_get_root_angle_distal(canola_h5): primary, lateral = series[0] pts = primary.numpy() proximal = False - angs = get_root_angle(pts, proximal) - assert angs.shape == (1,) + node_ind = get_node_ind(pts, proximal) + angs = get_root_angle(pts, 
node_ind, proximal) assert pts.shape == (1, 6, 2) np.testing.assert_almost_equal(angs, 7.7511306, decimal=3) @@ -183,7 +197,8 @@ def test_get_root_angle_proximal_rice(rice_h5): primary, lateral = series[0] pts = primary.numpy() proximal = True - angs = get_root_angle(pts, proximal) + node_ind = get_node_ind(pts, proximal) + angs = get_root_angle(pts, node_ind, proximal) assert angs.shape == (2,) assert pts.shape == (2, 6, 2) np.testing.assert_almost_equal(angs, [17.3180819, 3.2692877], decimal=3) @@ -192,7 +207,8 @@ def test_get_root_angle_proximal_rice(rice_h5): # test get_root_angle function using two roots/instances (base node to proximal node angle) def test_get_root_angle_proximal(pts_nan32): proximal = True - angs = get_root_angle(pts_nan32, proximal) + node_ind = get_node_ind(pts_nan32, proximal) + angs = get_root_angle(pts_nan32, node_ind, proximal) assert angs.shape == (2,) np.testing.assert_almost_equal(angs, [np.nan, 1.7291381], decimal=3) @@ -200,14 +216,15 @@ def test_get_root_angle_proximal(pts_nan32): # test get_root_angle function using two roots/instances (base node to proximal node angle) def test_get_root_angle_proximal_5node(pts_nan32_5node): proximal = True - angs = get_root_angle(pts_nan32_5node, proximal) + node_ind = get_node_ind(pts_nan32_5node, proximal) + angs = get_root_angle(pts_nan32_5node, node_ind, proximal) assert angs.shape == (2,) np.testing.assert_almost_equal(angs, [np.nan, 2.3339111], decimal=3) # test get_root_angle function using root/instance with all nan value -def test_get_root_angle_proximal_5node(pts_nanall): +def test_get_root_angle_proximal_allnan(pts_nanall): proximal = True - angs = get_root_angle(pts_nanall, proximal) - assert angs.shape == (1,) + node_ind = get_node_ind(pts_nanall, proximal) + angs = get_root_angle(pts_nanall, node_ind, proximal) np.testing.assert_almost_equal(angs, np.nan, decimal=3) diff --git a/tests/test_bases.py b/tests/test_bases.py index 614188e..f5c2f82 100644 --- a/tests/test_bases.py +++ 
b/tests/test_bases.py @@ -2,18 +2,15 @@ get_bases, get_base_ct_density, get_base_tip_dist, - get_grav_index, get_lateral_count, - get_root_lengths, - get_root_lengths_max, get_base_xs, get_base_ys, get_base_length, get_base_length_ratio, - get_primary_depth, get_root_pair_widths_projections, ) -from sleap_roots.points import get_lateral_pts +from sleap_roots.lengths import get_max_length_pts, get_root_lengths_max +from sleap_roots.tips import get_tips from sleap_roots import Series import numpy as np import pytest @@ -180,97 +177,44 @@ def test_bases_no_roots(pts_no_roots): # test get_base_tip_dist with standard points def test_get_base_tip_dist_standard(pts_standard): - distance = get_base_tip_dist(pts_standard) + primary_pts = pts_standard + primary_base_pt = get_bases(primary_pts) + primary_tip_pt = get_tips(primary_pts) + distance = get_base_tip_dist(primary_base_pt, primary_tip_pt) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [2.82842712, 2.82842712], decimal=7) # test get_base_tip_dist with roots without bases def test_get_base_tip_dist_no_bases(pts_no_bases): - distance = get_base_tip_dist(pts_no_bases) + primary_pts = pts_no_bases + primary_base_pt = get_bases(primary_pts) + primary_tip_pt = get_tips(primary_pts) + distance = get_base_tip_dist(primary_base_pt, primary_tip_pt) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [np.nan, np.nan], decimal=7) # test get_base_tip_dist with roots with one base def test_get_base_tip_dist_one_base(pts_one_base): - distance = get_base_tip_dist(pts_one_base) + primary_pts = pts_one_base + primary_base_pt = get_bases(primary_pts) + primary_tip_pt = get_tips(primary_pts) + distance = get_base_tip_dist(primary_base_pt, primary_tip_pt) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [2.82842712, np.nan], decimal=7) # test get_base_tip_dist with no roots def test_get_base_tip_dist_no_roots(pts_no_roots): - distance = get_base_tip_dist(pts_no_roots) + 
primary_pts = pts_no_roots + primary_base_pt = get_bases(primary_pts) + primary_tip_pt = get_tips(primary_pts) + distance = get_base_tip_dist(primary_base_pt, primary_tip_pt) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [np.nan, np.nan], decimal=7) -# test get_grav_index function -def test_get_grav_index(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - grav_index = get_grav_index(pts) - np.testing.assert_almost_equal(grav_index, 0.08898137324716636) - - -def test_get_root_lengths(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - assert pts.shape == (1, 6, 2) - - root_lengths = get_root_lengths(pts) - assert root_lengths.shape == (1,) - np.testing.assert_array_almost_equal(root_lengths, [971.050417]) - - pts = lateral.numpy() - assert pts.shape == (5, 3, 2) - - root_lengths = get_root_lengths(pts) - assert root_lengths.shape == (5,) - np.testing.assert_array_almost_equal( - root_lengths, [20.129579, 62.782368, 80.268003, 34.925591, 3.89724] - ) - - -def test_get_root_lengths_no_roots(pts_no_bases): - root_lengths = get_root_lengths(pts_no_bases) - assert root_lengths.shape == (2,) - np.testing.assert_array_almost_equal(root_lengths, np.array([np.nan, np.nan])) - - -def test_get_root_lengths_one_point(pts_one_base): - root_lengths = get_root_lengths(pts_one_base) - assert root_lengths.shape == (2,) - np.testing.assert_array_almost_equal( - root_lengths, np.array([2.82842712475, np.nan]) - ) - - -# test get_root_lengths_max function with lengths_normal -def test_get_root_lengths_max_normal(lengths_normal): - max_length = get_root_lengths_max(lengths_normal) - np.testing.assert_array_almost_equal(max_length, 329.4) - - -# test get_root_lengths_max function with lengths_with_nan -def 
test_get_root_lengths_max_with_nan(lengths_with_nan): - max_length = get_root_lengths_max(lengths_with_nan) - np.testing.assert_array_almost_equal(max_length, 329.4) - - -# test get_root_lengths_max function with lengths_all_nan -def test_get_root_lengths_max_all_nan(lengths_all_nan): - max_length = get_root_lengths_max(lengths_all_nan) - np.testing.assert_array_almost_equal(max_length, np.nan) - - # test get_lateral_count function with canola def test_get_lateral_count(canola_h5): series = Series.load( @@ -288,12 +232,25 @@ def test_get_base_xs_canola(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - base_xs = get_base_xs(pts_lr, monocots) + lateral = plant[0][1] # first frame, lateral labels + lateral_pts = lateral.numpy() # lateral points as numpy array + base_xs = get_base_xs(lateral_pts, monocots) assert base_xs.shape[0] == 5 np.testing.assert_almost_equal(base_xs[1], 1112.5506591796875, decimal=3) +# test get_base_xs with rice +def test_get_base_xs_rice(rice_h5): + monocots = True + plant = Series.load( + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + lateral = plant[0][1] # first frame, lateral labels + lateral_pts = lateral.numpy() # lateral points as numpy array + base_xs = get_base_xs(lateral_pts, monocots) + assert np.isnan(base_xs) + + # test get_base_xs with pts_standard def test_get_base_xs_standard(pts_standard): base_xs = get_base_xs(pts_standard) @@ -315,15 +272,31 @@ def test_get_base_ys_canola(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - base_ys = get_base_ys(pts_lr, monocots) + lateral = plant[0][1] # first frame, lateral labels + lateral_pts = lateral.numpy() # lateral points as numpy array + base_pts = get_bases(lateral_pts) # get the bases of the lateral roots + base_ys = 
get_base_ys(base_pts, monocots) assert base_ys.shape[0] == 5 np.testing.assert_almost_equal(base_ys[1], 228.0966796875, decimal=3) +# test get_base_ys with rice +def test_get_base_ys_rice(rice_h5): + monocots = True + plant = Series.load( + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + lateral = plant[0][1] # first frame, lateral labels + lateral_pts = lateral.numpy() # lateral points as numpy array + base_pts = get_bases(lateral_pts, monocots) # get the bases of the lateral roots + base_ys = get_base_ys(base_pts, monocots) + assert np.isnan(base_ys) + + # test get_base_ys with pts_standard def test_get_base_ys_standard(pts_standard): - base_ys = get_base_ys(pts_standard) + bases = get_bases(pts_standard) + base_ys = get_base_ys(bases) assert base_ys.shape[0] == 2 np.testing.assert_almost_equal(base_ys[0], 2, decimal=3) np.testing.assert_almost_equal(base_ys[1], 6, decimal=3) @@ -331,7 +304,8 @@ def test_get_base_ys_standard(pts_standard): # test get_base_ys with pts_no_roots def test_get_base_ys_no_roots(pts_no_roots): - base_ys = get_base_ys(pts_no_roots) + bases = get_bases(pts_no_roots) + base_ys = get_base_ys(bases) assert base_ys.shape[0] == 2 np.testing.assert_almost_equal(base_ys[0], np.nan, decimal=3) @@ -341,26 +315,44 @@ def test_get_base_length_canola(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - base_length = get_base_length(pts_lr) + lateral = plant[0][1] # first frame, lateral labels + lateral_pts = lateral.numpy() # lateral points as numpy array + bases = get_bases(lateral_pts) # get bases of lateral roots + base_ys = get_base_ys(bases) # get y-coordinates of bases + base_length = get_base_length(base_ys) np.testing.assert_almost_equal(base_length, 83.69914245605469, decimal=3) +# test get_base_length with rice +def test_get_base_length_rice(rice_h5): + plant = Series.load( + rice_h5, 
primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + lateral = plant[0][1] # first frame, lateral labels + lateral_pts = lateral.numpy() # lateral points as numpy array + base_length = get_base_length(lateral_pts, monocots=True) + assert np.isnan(base_length) + + # test get_base_length with pts_standard def test_get_base_length_standard(pts_standard): - base_length = get_base_length(pts_standard) + bases = get_bases(pts_standard) # get bases of lateral roots + base_ys = get_base_ys(bases) # get y-coordinates of bases + base_length = get_base_length(base_ys) np.testing.assert_almost_equal(base_length, 4, decimal=3) # test get_base_length with pts_no_roots def test_get_base_length_no_roots(pts_no_roots): base_length = get_base_length(pts_no_roots) - np.testing.assert_almost_equal(base_length, np.nan, decimal=3) + assert np.isnan(base_length) # test get_base_ct_density function with defined primary and lateral points def test_get_base_ct_density(primary_pts, lateral_pts): - base_ct_density = get_base_ct_density(primary_pts, lateral_pts) + primary_length_max = get_root_lengths_max(primary_pts) + lateral_base_pts = get_bases(lateral_pts) + base_ct_density = get_base_ct_density(primary_length_max, lateral_base_pts) np.testing.assert_almost_equal(base_ct_density, 0.00334, decimal=5) @@ -373,25 +365,23 @@ def test_get_base_ct_density_canola(canola_h5): primary, lateral = series[0] primary_pts = primary.numpy() lateral_pts = lateral.numpy() - base_ct_density = get_base_ct_density(primary_pts, lateral_pts, monocots) + primary_length_max = get_root_lengths_max(primary_pts) + lateral_base_pts = get_bases(lateral_pts) + base_ct_density = get_base_ct_density(primary_length_max, lateral_base_pts) np.testing.assert_almost_equal(base_ct_density, 0.004119, decimal=5) -# test get_primary_depth function with defined primary_pts -def test_get_primary_depth(primary_pts): - primary_depth = get_primary_depth(primary_pts) - np.testing.assert_almost_equal(primary_depth, 
808.12585449, decimal=3) - - -# test get_primary_depth function with canola -def test_get_primary_depth(canola_h5): +# test get_base_ct_density function with rice example +def test_get_base_ct_density_rice(rice_h5): + monocots = True series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() - primary_depth = get_primary_depth(primary_pts) - np.testing.assert_almost_equal(primary_depth, 1020.9813842773438, decimal=3) + lateral_pts = lateral.numpy() + base_ct_density = get_base_ct_density(primary_pts, lateral_pts, monocots) + assert np.isnan(base_ct_density) # test get_base_length_ratio with canola @@ -402,19 +392,26 @@ def test_get_base_length_ratio(canola_h5): primary, lateral = series[0] primary_pts = primary.numpy() lateral_pts = lateral.numpy() - base_length_ratio = get_base_length_ratio(primary_pts, lateral_pts) + primary_length_max = get_root_lengths_max(primary_pts) + bases = get_bases(lateral_pts) + lateral_base_ys = get_base_ys(bases) + base_length = get_base_length(lateral_base_ys) + base_length_ratio = get_base_length_ratio(primary_length_max, base_length) np.testing.assert_almost_equal(base_length_ratio, 0.086, decimal=3) -def test_stem_width(canola_h5): +def test_root_width(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) lateral_pts = lateral.numpy() - assert primary_pts.shape == (1, 6, 2) + assert primary_max_length_pts.shape == (6, 2) assert lateral_pts.shape == (5, 3, 2) - stem_widths = get_root_pair_widths_projections(lateral_pts, primary_pts, 0.02) - np.testing.assert_array_almost_equal(stem_widths, [31.603239]) + root_widths = get_root_pair_widths_projections( + primary_max_length_pts, 
lateral_pts, 0.02 + ) + np.testing.assert_almost_equal(root_widths, np.array([31.60323909]), decimal=7) diff --git a/tests/test_convhull.py b/tests/test_convhull.py index ce4b907..90cb473 100644 --- a/tests/test_convhull.py +++ b/tests/test_convhull.py @@ -2,13 +2,13 @@ from sleap_roots import Series from sleap_roots.convhull import ( get_convhull, - get_convhull_features, get_chull_line_lengths, get_chull_area, get_chull_max_height, get_chull_max_width, get_chull_perimeter, ) +from sleap_roots.lengths import get_max_length_pts from sleap_roots.points import get_all_pts_array import numpy as np import pytest @@ -79,11 +79,11 @@ def test_get_convhull_canola(canola_h5): canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] - - primary_points = primary.numpy().reshape(-1, 2) - lateral_points = lateral.numpy().reshape(-1, 2) - convex_hull_points = np.concatenate((primary_points, lateral_points), axis=0) - convex_hull = get_convhull(convex_hull_points) + primary_pts = primary.numpy() + lateral_pts = lateral.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + pts = get_all_pts_array(primary_max_length_pts, lateral_pts) + convex_hull = get_convhull(pts) assert type(convex_hull) == ConvexHull @@ -93,76 +93,74 @@ def test_get_convhull_features_canola(canola_h5): canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] + primary_pts = primary.numpy() + lateral_pts = lateral.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + pts = get_all_pts_array(primary_max_length_pts, lateral_pts) + convex_hull = get_convhull(pts) - primary_points = primary.numpy().reshape(-1, 2) - lateral_points = lateral.numpy().reshape(-1, 2) - convex_hull_points = np.concatenate((primary_points, lateral_points), axis=0) - - ( - perimeters, - areas, - max_widths, - max_heights, - ) = get_convhull_features(convex_hull_points) + perimeter = 
get_chull_perimeter(convex_hull) + area = get_chull_area(convex_hull) + max_width = get_chull_max_width(convex_hull) + max_height = get_chull_max_height(convex_hull) - np.testing.assert_almost_equal(perimeters, 1910.0476127930017, decimal=3) - np.testing.assert_almost_equal(areas, 93255.32153574759, decimal=3) - np.testing.assert_almost_equal(max_widths, 211.279296875, decimal=3) - np.testing.assert_almost_equal(max_heights, 876.5622253417969, decimal=3) + np.testing.assert_almost_equal(perimeter, 1910.0476127930017, decimal=3) + np.testing.assert_almost_equal(area, 93255.32153574759, decimal=3) + np.testing.assert_almost_equal(max_width, 211.279296875, decimal=3) + np.testing.assert_almost_equal(max_height, 876.5622253417969, decimal=3) # test rice model def test_get_convhull_features_rice(rice_h5): series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] + primary_pts = primary.numpy() + lateral_pts = lateral.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + pts = get_all_pts_array(primary_max_length_pts, lateral_pts) + convex_hull = get_convhull(pts) - primary_points = primary.numpy().reshape(-1, 2) - lateral_points = lateral.numpy().reshape(-1, 2) - convex_hull_points = np.concatenate((primary_points, lateral_points), axis=0) + perimeter = get_chull_perimeter(convex_hull) + area = get_chull_area(convex_hull) + max_width = get_chull_max_width(convex_hull) + max_height = get_chull_max_height(convex_hull) - ( - perimeters, - areas, - max_widths, - max_heights, - ) = get_convhull_features(convex_hull_points) - - np.testing.assert_almost_equal(perimeters, 1458.8585933576614, decimal=3) - np.testing.assert_almost_equal(areas, 23878.72090798154, decimal=3) - np.testing.assert_almost_equal(max_widths, 64.4229736328125, decimal=3) - np.testing.assert_almost_equal(max_heights, 720.0375061035156, 
decimal=3) + np.testing.assert_almost_equal(perimeter, 1458.8585933576614, decimal=3) + np.testing.assert_almost_equal(area, 23878.72090798154, decimal=3) + np.testing.assert_almost_equal(max_width, 64.4229736328125, decimal=3) + np.testing.assert_almost_equal(max_height, 720.0375061035156, decimal=3) # test plant with 2 roots/instances with nan nodes def test_get_convhull_features_nan(pts_nan31_5node): - ( - perimeters, - areas, - max_widths, - max_heights, - ) = get_convhull_features(pts_nan31_5node) + convex_hull = get_convhull(pts_nan31_5node) + + perimeter = get_chull_perimeter(convex_hull) + area = get_chull_area(convex_hull) + max_width = get_chull_max_width(convex_hull) + max_height = get_chull_max_height(convex_hull) - np.testing.assert_almost_equal(perimeters, 1184.6684128638494, decimal=3) - np.testing.assert_almost_equal(areas, 2276.1159928281368, decimal=3) - np.testing.assert_almost_equal(max_widths, 35.46612548999997, decimal=3) - np.testing.assert_almost_equal(max_heights, 591.16937256, decimal=3) + np.testing.assert_almost_equal(perimeter, 1184.6684128638494, decimal=3) + np.testing.assert_almost_equal(area, 2276.1159928281368, decimal=3) + np.testing.assert_almost_equal(max_width, 35.46612548999997, decimal=3) + np.testing.assert_almost_equal(max_height, 591.16937256, decimal=3) # test plant with 1 root/instance with only 2 non-nan nodes def test_get_convhull_features_nanall(pts_nan_5node): - ( - perimeters, - areas, - max_widths, - max_heights, - ) = get_convhull_features(pts_nan_5node) + convex_hull = get_convhull(pts_nan_5node) - np.testing.assert_almost_equal(perimeters, np.nan, decimal=3) - np.testing.assert_almost_equal(areas, np.nan, decimal=3) - np.testing.assert_almost_equal(max_widths, np.nan, decimal=3) - np.testing.assert_almost_equal(max_heights, np.nan, decimal=3) + perimeter = get_chull_perimeter(convex_hull) + area = get_chull_area(convex_hull) + max_width = get_chull_max_width(convex_hull) + max_height = 
get_chull_max_height(convex_hull) + + np.testing.assert_almost_equal(perimeter, np.nan, decimal=3) + np.testing.assert_almost_equal(area, np.nan, decimal=3) + np.testing.assert_almost_equal(max_width, np.nan, decimal=3) + np.testing.assert_almost_equal(max_height, np.nan, decimal=3) # test get_chull_perimeter with defined lateral_pts @@ -171,52 +169,18 @@ def test_get_chull_perimeter(lateral_pts): np.testing.assert_almost_equal(perimeter, 1184.7141710619985, decimal=3) -# test get_chull_perimeter with canola -def test_get_chull_perimeter_canola(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts = get_all_pts_array(plant=plant, frame=0, monocots=False) - perimeter = get_chull_perimeter(pts) - np.testing.assert_almost_equal(perimeter, 1910.0476127930017, decimal=3) - - -# test get_chull_area with canola -def test_get_chull_area_canola(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts = get_all_pts_array(plant=plant, frame=0, monocots=False) - area = get_chull_area(pts) - np.testing.assert_almost_equal(area, 93255.32153574759, decimal=3) - - -# test get_chull_max_width with canola -def test_get_chull_max_width(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts = get_all_pts_array(plant=plant, frame=0, monocots=False) - max_width = get_chull_max_width(pts) - np.testing.assert_almost_equal(max_width, 211.279296875, decimal=3) - - -def test_get_chull_max_height(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts = get_all_pts_array(plant=plant, frame=0, monocots=False) - max_height = get_chull_max_height(pts) - np.testing.assert_almost_equal(max_height, 876.5622253417969, decimal=3) - - # test get_chull_line_lengths with canola def test_get_chull_line_lengths(canola_h5): - plant = 
Series.load( + series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts = get_all_pts_array(plant=plant, frame=0, monocots=False) - chull_line_lengths = get_chull_line_lengths(pts) + primary, lateral = series[0] + primary_pts = primary.numpy() + lateral_pts = lateral.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + pts = get_all_pts_array(primary_max_length_pts, lateral_pts) + convex_hull = get_convhull(pts) + chull_line_lengths = get_chull_line_lengths(convex_hull) assert chull_line_lengths.shape[0] == 10 np.testing.assert_almost_equal(chull_line_lengths[0], 227.553, decimal=3) diff --git a/tests/test_ellipse.py b/tests/test_ellipse.py index 564b1ec..6a8ef6d 100644 --- a/tests/test_ellipse.py +++ b/tests/test_ellipse.py @@ -6,6 +6,7 @@ get_ellipse_b, get_ellipse_ratio, ) +from sleap_roots.lengths import get_max_length_pts from sleap_roots.points import get_all_pts_array @@ -36,7 +37,13 @@ def test_get_ellipse_a(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) ellipse_a = get_ellipse_a(pts_all_array) np.testing.assert_almost_equal(ellipse_a, 398.1275346610801, decimal=3) @@ -45,7 +52,13 @@ def test_get_ellipse_b(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, 
monocots=False + ) ellipse_b = get_ellipse_b(pts_all_array) np.testing.assert_almost_equal(ellipse_b, 115.03734180292595, decimal=3) @@ -54,6 +67,28 @@ def test_get_ellipse_ratio(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) ellipse_ratio = get_ellipse_ratio(pts_all_array) np.testing.assert_almost_equal(ellipse_ratio, 3.460854783511295, decimal=3) + + +def test_get_ellipse_ratio_ellipse(canola_h5): + plant = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) + ellipse = fit_ellipse(pts_all_array) + ellipse_ratio = get_ellipse_ratio(ellipse) + np.testing.assert_almost_equal(ellipse_ratio, 3.460854783511295, decimal=3) diff --git a/tests/test_graphpipeline.py b/tests/test_graphpipeline.py deleted file mode 100644 index befc797..0000000 --- a/tests/test_graphpipeline.py +++ /dev/null @@ -1,139 +0,0 @@ -from sleap_roots.graphpipeline import ( - get_traits_value_frame, - get_traits_value_plant, - get_traits_value_plant_summary, - get_all_plants_traits, -) -import pytest -import numpy as np -import pandas as pd - - -@pytest.fixture -def primary_pts(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [850.17755127, 472.83520508], - [844.45300293, 472.83520508], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ] - ] - ) - - -@pytest.fixture -def 
lateral_pts(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - [ - [852.17755127, 216.95648193], - [844.45300293, 472.83520508], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - ] - ) - - -def test_get_traits_value_frame(primary_pts, lateral_pts): - monocots = False - pts_all_array = ( - primary_pts.reshape(-1, 2) - if monocots - else np.concatenate( - (primary_pts.reshape(-1, 2), lateral_pts.reshape(-1, 2)), axis=0 - ) - ) - pts_all_array = pts_all_array.reshape( - (1, pts_all_array.shape[0], pts_all_array.shape[1]) - ) - pts_all_list = ( - primary_pts if monocots else primary_pts.tolist() + lateral_pts.tolist() - ) - - data = get_traits_value_frame( - primary_pts, lateral_pts, pts_all_array, pts_all_list, monocots - ) - assert len(data) == 43 - - -def test_get_traits_value_plant(canola_h5): - monocots = False - - data_plant, data_plant_df, plant_name = get_traits_value_plant( - canola_h5, - monocots, - primary_name="primary_multi_day", - lateral_name="lateral_3_nodes", - stem_width_tolerance=0.02, - n_line=50, - network_fraction=2 / 3, - write_csv=False, - ) - assert len(data_plant) == 72 - assert data_plant_df.shape[1] == 45 - assert plant_name == "919QDUH" - - -def test_get_traits_value_plant_summary(canola_h5): - monocots = False - data_plant_summary = get_traits_value_plant_summary( - canola_h5, - monocots, - primary_name="primary_multi_day", - lateral_name="lateral_3_nodes", - stem_width_tolerance=0.02, - n_line=50, - network_fraction=2 / 3, - write_csv=False, - write_summary_csv=False, - ) - assert data_plant_summary.shape[0] == 1 - assert data_plant_summary.shape[1] == 1036 - np.testing.assert_almost_equal(data_plant_summary.iloc[0, 5], 16.643764612148875) - - -def test_get_all_plants_traits_dicot(canola_folder): - data_folders = [canola_folder] - primary_name = 
"primary_multi_day" - lateral_name = "lateral_3_nodes" - write_per_plant_details = True - write_per_plant_summary = True - all_traits_df = get_all_plants_traits( - data_folders=data_folders, - primary_name=primary_name, - lateral_name=lateral_name, - write_per_plant_details=write_per_plant_details, - write_per_plant_summary=write_per_plant_summary, - ) - assert all_traits_df.shape == (1, 1037) - np.testing.assert_almost_equal(all_traits_df.iloc[0, 5], 16.643764612148875) - - -def tests_get_all_plants_traits_monocot(rice_folder): - data_folders = [rice_folder] - primary_name = "longest_3do_6nodes" - lateral_name = "main_3do_6nodes" - write_per_plant_details = True - write_per_plant_summary = True - all_traits_df = get_all_plants_traits( - data_folders=data_folders, - primary_name=primary_name, - lateral_name=lateral_name, - write_per_plant_details=write_per_plant_details, - write_per_plant_summary=write_per_plant_summary, - ) - assert all_traits_df.shape == (1, 1037) - np.testing.assert_almost_equal(all_traits_df.iloc[0, 5], 3.716619501198254) diff --git a/tests/test_lengths.py b/tests/test_lengths.py new file mode 100644 index 0000000..ba05fc2 --- /dev/null +++ b/tests/test_lengths.py @@ -0,0 +1,229 @@ +from sleap_roots.lengths import ( + get_grav_index, + get_root_lengths, + get_root_lengths_max, + get_max_length_pts, +) +from sleap_roots.bases import get_base_tip_dist, get_bases +from sleap_roots.tips import get_tips +from sleap_roots import Series +import numpy as np +import pytest + + +@pytest.fixture +def pts_standard(): + return np.array( + [ + [ + [1, 2], + [3, 4], + ], + [ + [5, 6], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def pts_no_bases(): + return np.array( + [ + [ + [np.nan, np.nan], + [3, 4], + ], + [ + [np.nan, np.nan], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def pts_one_base(): + return np.array( + [ + [ + [1, 2], + [3, 4], + ], + [ + [np.nan, np.nan], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def pts_no_roots(): + return np.array( + [ + 
[ + [np.nan, np.nan], + [np.nan, np.nan], + ], + [ + [np.nan, np.nan], + [np.nan, np.nan], + ], + ] + ) + + +@pytest.fixture +def pts_not_contiguous(): + return np.array( + [ + [ + [1, 2], + [np.nan, np.nan], + [3, 4], + ], + [ + [5, 6], + [np.nan, np.nan], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def primary_pts(): + return np.array( + [ + [ + [852.17755127, 216.95648193], + [850.17755127, 472.83520508], + [844.45300293, 472.83520508], + [837.03405762, 588.5123291], + [828.87963867, 692.72009277], + [816.71142578, 808.12585449], + ] + ] + ) + + +@pytest.fixture +def lateral_pts(): + return np.array( + [ + [ + [852.17755127, 216.95648193], + [np.nan, np.nan], + [837.03405762, 588.5123291], + [828.87963867, 692.72009277], + [816.71142578, 808.12585449], + ], + [ + [852.17755127, 216.95648193], + [844.45300293, 472.83520508], + [837.03405762, 588.5123291], + [828.87963867, 692.72009277], + [816.71142578, 808.12585449], + ], + ] + ) + + +@pytest.fixture +def lengths_normal(): + return np.array([145, 234, 329.4]) + + +@pytest.fixture +def lengths_with_nan(): + return np.array([145, 234, 329.4, np.nan]) + + +@pytest.fixture +def lengths_all_nan(): + return np.array([np.nan, np.nan, np.nan]) + + +# test get_grav_index function +def test_get_grav_index(canola_h5): + series = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + primary = series[0][0] # first frame, primary labels + primary_pts = primary.numpy() # primary points as numpy array + primary_length = get_root_lengths_max(primary_pts) + max_length_pts = get_max_length_pts(primary_pts) + bases = get_bases(max_length_pts) + tips = get_tips(max_length_pts) + base_tip_dist = get_base_tip_dist(bases, tips) + grav_index = get_grav_index(primary_length, base_tip_dist) + np.testing.assert_almost_equal(grav_index, 0.08898137324716636) + + +def test_get_root_lengths(canola_h5): + series = Series.load( + canola_h5, primary_name="primary_multi_day", 
lateral_name="lateral_3_nodes" + ) + primary, lateral = series[0] + pts = primary.numpy() + assert pts.shape == (1, 6, 2) + + root_lengths = get_root_lengths(pts) + assert np.isscalar(root_lengths) + np.testing.assert_array_almost_equal(root_lengths, [971.050417]) + + pts = lateral.numpy() + assert pts.shape == (5, 3, 2) + + root_lengths = get_root_lengths(pts) + assert root_lengths.shape == (5,) + np.testing.assert_array_almost_equal( + root_lengths, [20.129579, 62.782368, 80.268003, 34.925591, 3.89724] + ) + + +def test_get_root_lengths_no_roots(pts_no_bases): + root_lengths = get_root_lengths(pts_no_bases) + assert root_lengths.shape == (2,) + np.testing.assert_array_almost_equal(root_lengths, np.array([np.nan, np.nan])) + + +def test_get_root_lengths_one_point(pts_one_base): + root_lengths = get_root_lengths(pts_one_base) + assert root_lengths.shape == (2,) + np.testing.assert_array_almost_equal( + root_lengths, np.array([2.82842712475, np.nan]) + ) + + +# test get_root_lengths_max function with lengths_normal +def test_get_root_lengths_max_normal(lengths_normal): + max_length = get_root_lengths_max(lengths_normal) + np.testing.assert_array_almost_equal(max_length, 329.4) + + +# test get_root_lengths_max function with lengths_with_nan +def test_get_root_lengths_max_with_nan(lengths_with_nan): + max_length = get_root_lengths_max(lengths_with_nan) + np.testing.assert_array_almost_equal(max_length, 329.4) + + +# test get_root_lengths_max function with lengths_all_nan +def test_get_root_lengths_max_all_nan(lengths_all_nan): + max_length = get_root_lengths_max(lengths_all_nan) + np.testing.assert_array_almost_equal(max_length, np.nan) + + +def test_get_max_length_pts(canola_h5): + series = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + primary = series[0][0] # first frame, primary labels + primary_pts = primary.numpy() # primary points as numpy array + max_length_pts = get_max_length_pts(primary_pts) + assert 
max_length_pts.shape == (6, 2) + np.testing.assert_almost_equal( + max_length_pts[0], np.array([1016.7844238, 144.4191589]) + ) diff --git a/tests/test_networklength.py b/tests/test_networklength.py index 6735d87..3126ceb 100644 --- a/tests/test_networklength.py +++ b/tests/test_networklength.py @@ -1,6 +1,8 @@ import pytest import numpy as np from sleap_roots import Series +from sleap_roots.convhull import get_chull_area, get_convhull +from sleap_roots.lengths import get_max_length_pts, get_root_lengths from sleap_roots.networklength import get_bbox from sleap_roots.networklength import get_network_distribution from sleap_roots.networklength import get_network_distribution_ratio @@ -81,115 +83,193 @@ def test_get_network_width_depth_ratio_nan(pts_nan3): np.testing.assert_almost_equal(ratio, np.nan, decimal=7) -def test_get_network_solidity(canola_h5): +def test_get_network_length(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=False) + lateral_lengths = get_root_lengths(lateral_pts) monocots = False - ratio = get_network_solidity(primary_pts, lateral_pts, pts_all_array, monocots) - np.testing.assert_almost_equal(ratio, 0.012578941125511587, decimal=7) + length = get_network_length(primary_length, lateral_lengths, monocots) + np.testing.assert_almost_equal(length, 1173.0531992388217, decimal=7) -def test_get_network_solidity_rice(rice_h5): +def test_get_network_length_rice(rice_h5): series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] 
primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=True) + lateral_lengths = get_root_lengths(lateral_pts) monocots = True - ratio = get_network_solidity(primary_pts, lateral_pts, pts_all_array, monocots) - np.testing.assert_almost_equal(ratio, 0.17930631242462894, decimal=7) + length = get_network_length(primary_length, lateral_lengths, monocots) + np.testing.assert_almost_equal(length, 798.5726441151357, decimal=7) -def test_get_network_distribution(canola_h5): +def test_get_network_solidity(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=False) - fraction = 2 / 3 + lateral_lengths = get_root_lengths(lateral_pts) monocots = False - root_length = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + network_length = get_network_length(primary_length, lateral_lengths, monocots) + + # get chull_area + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots ) - np.testing.assert_almost_equal(root_length, 589.4322131363684, decimal=7) + convex_hull = get_convhull(pts_all_array) + chull_area = get_chull_area(convex_hull) + + ratio = get_network_solidity(network_length, chull_area) + np.testing.assert_almost_equal(ratio, 0.012578941125511587, decimal=7) -def test_get_network_distribution_rice(rice_h5): +def test_get_network_solidity_rice(rice_h5): series = Series.load( - rice_h5, 
primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=True) - fraction = 2 / 3 + lateral_lengths = get_root_lengths(lateral_pts) monocots = True - root_length = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + network_length = get_network_length(primary_length, lateral_lengths, monocots) + + # get chull_area + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots ) - np.testing.assert_almost_equal(root_length, 475.89810040497025, decimal=7) + convex_hull = get_convhull(pts_all_array) + chull_area = get_chull_area(convex_hull) + ratio = get_network_solidity(network_length, chull_area) + np.testing.assert_almost_equal(ratio, 0.03366254601775008, decimal=7) -def test_get_network_length(canola_h5): + +def test_get_network_distribution(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) lateral_pts = lateral.numpy() monocots = False - length = get_network_length(primary_pts, lateral_pts, monocots) - np.testing.assert_almost_equal(length, 1173.0531992388217, decimal=7) + pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts, monocots) + bbox = get_bbox(pts_all_array) + fraction = 2 / 3 + monocots = False + root_length = get_network_distribution( + primary_max_length_pts, lateral_pts, bbox, fraction, monocots + ) + np.testing.assert_almost_equal(root_length, 589.4322131363684, decimal=7) -def 
test_get_network_length_rice(rice_h5): +def test_get_network_distribution_rice(rice_h5): series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) lateral_pts = lateral.numpy() monocots = True - length = get_network_length(primary_pts, lateral_pts, monocots) - np.testing.assert_almost_equal(length, 798.5726441151357, decimal=7) + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + bbox = get_bbox(pts_all_array) + fraction = 2 / 3 + root_length = get_network_distribution( + primary_max_length_pts, lateral_pts, bbox, fraction, monocots + ) + np.testing.assert_almost_equal(root_length, 477.77168597561507, decimal=7) def test_get_network_distribution_ratio(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) + monocots = False primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=False) + lateral_lengths = get_root_lengths(lateral_pts) + # get pts_all_array + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + bbox = get_bbox(pts_all_array) + # get network_length_lower + network_length_lower = get_network_distribution( + primary_max_length_pts, lateral_pts, bbox + ) fraction = 2 / 3 - monocots = False ratio = get_network_distribution_ratio( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + primary_length, + lateral_lengths, + network_length_lower, + fraction, + monocots, ) np.testing.assert_almost_equal(ratio, 
0.5024769665338648, decimal=7) def test_get_network_distribution_ratio_rice(rice_h5): series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) + monocots = True + fraction = 2 / 3 primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=True) - fraction = 2 / 3 - monocots = True + lateral_lengths = get_root_lengths(lateral_pts) + # get pts_all_array + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + bbox = get_bbox(pts_all_array) + # get network_length_lower + network_length_lower = get_network_distribution( + primary_max_length_pts, lateral_pts, bbox, fraction=fraction, monocots=monocots + ) ratio = get_network_distribution_ratio( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + primary_length, + lateral_lengths, + network_length_lower, + fraction, + monocots, ) - np.testing.assert_almost_equal(ratio, 0.5959358912579489, decimal=7) + + np.testing.assert_almost_equal(ratio, 0.5982820592421038, decimal=7) diff --git a/tests/test_points.py b/tests/test_points.py index 4944933..3fdf368 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -1,205 +1,41 @@ -import numpy as np -import pytest from sleap_roots import Series +from sleap_roots.lengths import get_max_length_pts, get_root_lengths from sleap_roots.points import ( - get_pt_ind, - get_primary_pts, - get_lateral_pts, - get_all_pts, get_all_pts_array, ) -@pytest.fixture -def pts_nan2(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, 472.83520508], - [844.45300293, 472.83520508], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], 
- [816.71142578, 808.12585449], - ] - ] - ) - - -@pytest.fixture -def pts_nan3(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ] - ] - ) - - -@pytest.fixture -def pts_nan6(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [np.nan, np.nan], - ] - ] - ) - - -@pytest.fixture -def pts_nanall(): - return np.array( - [ - [ - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - ] - ] - ) - - -@pytest.fixture -def pts_nan32(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [844.45300293, 472.83520508], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - ] - ) - - -@pytest.fixture -def pts_nan32_5node(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - ] - ) - - -# test get_pt_ind function -def test_get_pt_ind(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - proximal = True - node_ind = get_pt_ind(pts, proximal) - np.testing.assert_array_equal(node_ind, [1]) - - -# test get_pt_ind function using root that without second node -def test_get_pt_ind_nan2(pts_nan2): - proximal = True - node_ind = 
get_pt_ind(pts_nan2, proximal) - np.testing.assert_array_equal(node_ind, [2]) - - -# test get_pt_ind function using root that without second and third nodes -def test_get_pt_ind_nan3(pts_nan3): - proximal = True - node_ind = get_pt_ind(pts_nan3, proximal) - np.testing.assert_array_equal(node_ind, [3]) - - -# test get_pt_ind function using two roots/instances -def test_get_pt_ind_nan32(pts_nan32): - proximal = True - node_ind = get_pt_ind(pts_nan32, proximal) - np.testing.assert_array_equal(node_ind, [3, 2]) - - -# test get_pt_ind function using root that without last node -def test_get_pt_ind_nan6(pts_nan6): - proximal = False - node_ind = get_pt_ind(pts_nan6, proximal) - np.testing.assert_array_equal(node_ind, [4]) - - -# test get_pt_ind function using root with all nan node -def test_get_pt_ind_nanall(pts_nanall): - proximal = False - node_ind = get_pt_ind(pts_nanall, proximal) - np.testing.assert_array_equal(node_ind, [0]) - - -# test get_primary_pts function -def test_get_primary_pts(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts_pr = get_primary_pts(plant=plant, frame=0) - assert pts_pr.shape == (1, 6, 2) - - -# test get_lateral_pts function -def test_get_lateral_pts(canola_h5): +# test get_all_pts_array function +def test_get_all_pts_array(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - assert pts_lr.shape == (5, 3, 2) - - -# test get_all_pts function -def test_get_all_pts(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + primary, lateral = plant[0] + primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + # get lateral_lengths + lateral_pts = lateral.numpy() + monocots = False + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, 
monocots=monocots ) - pts_all = get_all_pts(plant=plant, frame=0, monocots=False) - assert len(pts_all) == 6 - assert len(pts_all[0]) == 6 - assert len(pts_all[1]) == 3 + assert pts_all_array.shape[0] == 21 # test get_all_pts_array function -def test_get_all_pts_array(canola_h5): +def test_get_all_pts_array_rice(rice_h5): plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) - assert pts_all_array.shape[0] == 21 + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + primary, lateral = plant[0] + primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + # get lateral_lengths + lateral_pts = lateral.numpy() + monocots = True + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + assert pts_all_array.shape[0] == 12 diff --git a/tests/test_scanline.py b/tests/test_scanline.py index fd5cfa2..a621b87 100644 --- a/tests/test_scanline.py +++ b/tests/test_scanline.py @@ -1,6 +1,7 @@ import pytest import numpy as np from sleap_roots import Series +from sleap_roots.lengths import get_max_length_pts from sleap_roots.scanline import ( get_scanline_first_ind, get_scanline_last_ind, @@ -60,6 +61,7 @@ def test_count_scanline_intersections_canola(canola_h5): primary, lateral = series[0] primary_pts = primary.numpy() lateral_pts = lateral.numpy() + primary_pts = get_max_length_pts(primary_pts) depth = 1080 width = 2048 n_line = 50 @@ -73,11 +75,12 @@ def test_count_scanline_intersections_canola(canola_h5): def test_count_scanline_intersections_rice(rice_h5): series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() lateral_pts = lateral.numpy() + primary_pts = 
get_max_length_pts(primary_pts) depth = 1080 width = 2048 n_line = 50 @@ -86,7 +89,7 @@ def test_count_scanline_intersections_rice(rice_h5): primary_pts, lateral_pts, depth, width, n_line, monocots ) assert n_inter.shape == (50,) - np.testing.assert_equal(n_inter[14], 1) + np.testing.assert_equal(n_inter[14], 2) # test get_scanline_first_ind with canola @@ -97,14 +100,16 @@ def test_get_scanline_first_ind(canola_h5): primary, lateral = plant[0] primary_pts = primary.numpy() lateral_pts = lateral.numpy() + primary_pts = get_max_length_pts(primary_pts) depth = 1080 width = 2048 n_line = 50 monocots = False - scanline_first_ind = get_scanline_first_ind( + scanline_intersection_counts = count_scanline_intersections( primary_pts, lateral_pts, depth, width, n_line, monocots ) - np.testing.assert_equal(scanline_first_ind, 6) + scanline_first_ind = get_scanline_first_ind(scanline_intersection_counts) + np.testing.assert_equal(scanline_first_ind, 7) # test get_scanline_last_ind with canola @@ -115,11 +120,13 @@ def test_get_scanline_last_ind(canola_h5): primary, lateral = plant[0] primary_pts = primary.numpy() lateral_pts = lateral.numpy() + primary_pts = get_max_length_pts(primary_pts) depth = 1080 width = 2048 n_line = 50 monocots = True - scanline_last_ind = get_scanline_last_ind( + scanline_intersection_counts = count_scanline_intersections( primary_pts, lateral_pts, depth, width, n_line, monocots ) - np.testing.assert_equal(scanline_last_ind, 15) + scanline_last_ind = get_scanline_last_ind(scanline_intersection_counts) + np.testing.assert_equal(scanline_last_ind, 12) diff --git a/tests/test_summary.py b/tests/test_summary.py index 34af557..6484df1 100644 --- a/tests/test_summary.py +++ b/tests/test_summary.py @@ -1,27 +1,41 @@ import numpy as np -import pytest from sleap_roots.summary import get_summary -@pytest.fixture -def array_random(): - np.random.seed(0) - return np.random.rand(100) +def test_get_summary(): + summary = get_summary(np.array([-1, 0, 1])) + assert 
summary["min"] == -1 + assert summary["max"] == 1 + assert summary["mean"] == 0 + assert summary["median"] == 0 + assert summary["std"] == np.std([-1, 0, 1]) + assert summary["p5"] == np.percentile([-1, 0, 1], 5) + assert summary["p25"] == np.percentile([-1, 0, 1], 25) + assert summary["p75"] == np.percentile([-1, 0, 1], 75) + assert summary["p95"] == np.percentile([-1, 0, 1], 95) -# test get_summary function with random array -def test_get_summary(array_random): - [ - trait_min, - trait_max, - trait_mean, - trait_median, - trait_std, - trait_prc5, - trait_prc25, - trait_prc75, - trait_prc95, - ] = get_summary(array_random) - np.testing.assert_almost_equal(trait_min, 0.004695476192547066, decimal=3) - np.testing.assert_almost_equal(trait_mean, 0.4727938395125177, decimal=3) - np.testing.assert_almost_equal(trait_prc95, 0.9456186092221561, decimal=3) +def test_get_summary_empty(): + summary = get_summary([]) + np.testing.assert_array_equal(summary["min"], np.nan) + np.testing.assert_array_equal(summary["max"], np.nan) + np.testing.assert_array_equal(summary["mean"], np.nan) + np.testing.assert_array_equal(summary["median"], np.nan) + np.testing.assert_array_equal(summary["std"], np.nan) + np.testing.assert_array_equal(summary["p5"], np.nan) + np.testing.assert_array_equal(summary["p25"], np.nan) + np.testing.assert_array_equal(summary["p75"], np.nan) + np.testing.assert_array_equal(summary["p95"], np.nan) + + +def test_get_summary_prefix(): + summary = get_summary([], prefix="test_") + assert "test_min" in summary + assert "test_max" in summary + assert "test_mean" in summary + assert "test_median" in summary + assert "test_std" in summary + assert "test_p5" in summary + assert "test_p25" in summary + assert "test_p75" in summary + assert "test_p95" in summary diff --git a/tests/test_tips.py b/tests/test_tips.py index 19c299a..35c853e 100644 --- a/tests/test_tips.py +++ b/tests/test_tips.py @@ -1,5 +1,4 @@ -from sleap_roots.points import get_lateral_pts -from 
sleap_roots.tips import get_tips, get_primary_depth, get_tip_xs, get_tip_ys +from sleap_roots.tips import get_tips, get_tip_xs, get_tip_ys from sleap_roots import Series import numpy as np import pytest @@ -81,25 +80,14 @@ def test_tips_one_tip(pts_one_tip): np.testing.assert_array_equal(tips, [[3, 4], [np.nan, np.nan]]) -# test get_primary_depth with standard points -def test_get_primary_depth_standard(pt_standard): - primary_depth = get_primary_depth(pt_standard) - np.testing.assert_array_almost_equal(primary_depth, 4) - - -# test get_primary_depth with nan tip points -def test_get_primary_depth_nan(pt_nan_tip): - primary_depth = get_primary_depth(pt_nan_tip) - np.testing.assert_array_almost_equal(primary_depth, np.nan) - - # test get_tip_xs with canola def test_get_tip_xs_canola(canola_h5): - plant = Series.load( + series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - tip_xs = get_tip_xs(pts_lr) + lateral = series[0][1] # LabeledFrame + lateral_pts = lateral.numpy() # Lateral roots as a numpy array + tip_xs = get_tip_xs(lateral_pts) assert tip_xs.shape[0] == 5 np.testing.assert_almost_equal(tip_xs[1], 1072.6610107421875, decimal=3) @@ -121,18 +109,21 @@ def test_get_tip_xs_no_tip(pts_no_tips): # test get_tip_ys with canola def test_get_tip_ys_canola(canola_h5): - plant = Series.load( + series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - tip_ys = get_tip_ys(pts_lr) + lateral = series[0][1] # LabeledFrame + lateral_pts = lateral.numpy() # Lateral roots as a numpy array + tips = get_tips(lateral_pts) + tip_ys = get_tip_ys(tips) assert tip_ys.shape[0] == 5 np.testing.assert_almost_equal(tip_ys[1], 276.51275634765625, decimal=3) # test get_tip_ys with standard points def test_get_tip_ys_standard(pts_standard): - tip_ys = get_tip_ys(pts_standard) + tips = 
get_tips(pts_standard) + tip_ys = get_tip_ys(tips) assert tip_ys.shape[0] == 2 np.testing.assert_almost_equal(tip_ys[0], 4, decimal=3) np.testing.assert_almost_equal(tip_ys[1], 8, decimal=3) @@ -140,6 +131,7 @@ def test_get_tip_ys_standard(pts_standard): # test get_tip_ys with no tips def test_get_tip_ys_no_tip(pts_no_tips): - tip_ys = get_tip_ys(pts_no_tips) + tips = get_tips(pts_no_tips) + tip_ys = get_tip_ys(tips) assert tip_ys.shape[0] == 2 np.testing.assert_almost_equal(tip_ys[1], np.nan, decimal=3) diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py new file mode 100644 index 0000000..0d4eac5 --- /dev/null +++ b/tests/test_trait_pipelines.py @@ -0,0 +1,20 @@ +from sleap_roots.trait_pipelines import DicotPipeline +from sleap_roots.series import Series + + +def test_dicot_pipeline(canola_h5, soy_h5): + canola = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + soy = Series.load( + soy_h5, primary_name="primary_multi_day", lateral_name="lateral__nodes" + ) + + pipeline = DicotPipeline() + canola_traits = pipeline.compute_plant_traits(canola) + soy_traits = pipeline.compute_plant_traits(soy) + all_traits = pipeline.compute_batch_traits([canola, soy]) + + assert canola_traits.shape == (72, 115) + assert soy_traits.shape == (72, 115) + assert all_traits.shape == (2, 1018) diff --git a/tests/test_traitsgraph.py b/tests/test_traitsgraph.py deleted file mode 100644 index b4faa53..0000000 --- a/tests/test_traitsgraph.py +++ /dev/null @@ -1,7 +0,0 @@ -from sleap_roots.traitsgraph import get_traits_graph - - -def test_get_traits_graph(): - dts = get_traits_graph() - assert len(dts) == 43 - assert dts[0] == "primary_base_pt"