diff --git a/sleap_roots/__init__.py b/sleap_roots/__init__.py index 3f0c719..7a9ec45 100644 --- a/sleap_roots/__init__.py +++ b/sleap_roots/__init__.py @@ -10,9 +10,8 @@ import sleap_roots.scanline import sleap_roots.series import sleap_roots.summary -import sleap_roots.traitsgraph -import sleap_roots.graphpipeline -from sleap_roots.graphpipeline import get_all_plants_traits +import sleap_roots.trait_pipeline +from sleap_roots.trait_pipeline import get_all_plants_traits from sleap_roots.series import Series # Define package version. diff --git a/sleap_roots/angle.py b/sleap_roots/angle.py index b50ed10..29985c4 100644 --- a/sleap_roots/angle.py +++ b/sleap_roots/angle.py @@ -4,51 +4,109 @@ import math -def get_node_ind(pts: np.ndarray, proximal=True) -> np.ndarray: - """Find nproximal/distal node index. +def get_node_ind(pts: np.ndarray, proximal: bool = True) -> np.ndarray: + """Find proximal/distal node index. Args: - pts: Numpy array of points of shape (instances, nodes, 2). + pts: Numpy array of points of shape (instances, nodes, 2) or (nodes, 2). proximal: Boolean value, where true is proximal (default), false is distal. Returns: An array of shape (instances,) of proximal or distal node index. """ - node_ind = [] - for i in range(pts.shape[0]): - ind = 1 if proximal else pts.shape[1] - 1 # set initial proximal/distal node - while np.isnan(pts[i, ind]).any(): - ind += 1 if proximal else -1 - if (ind == pts.shape[1] and proximal) or (ind == 0 and not proximal): - break - node_ind.append(ind) + # Check if pts is a numpy array + if not isinstance(pts, np.ndarray): + raise TypeError("Input pts should be a numpy array.") + + # Check if pts has 2 or 3 dimensions + if pts.ndim not in [2, 3]: + raise ValueError("Input pts should have 2 or 3 dimensions.") + + # Check if the last dimension of pts has size 2 + if pts.shape[-1] != 2: + raise ValueError( + "The last dimension of the input pts should have size 2," + "representing x and y coordinates." + ) + + # Check if pts is 2D, if so, reshape to 3D + if pts.ndim == 2: + pts = pts[np.newaxis, ...] 
+
+    # Identify where NaN values exist
+    nan_mask = np.isnan(pts).any(axis=-1)
+
+    # If only NaN values, return NaN
+    if nan_mask.all():
+        return np.nan
+
+    if proximal:
+        # For proximal, we want the first non-NaN node in the first half of the root
+        # get the first half nan mask (exclude the base node)
+        node_proximal = nan_mask[:, 1 : int((nan_mask.shape[1] + 1) / 2)]
+        # get the nearest non-NaN node index
+        node_ind = np.argmax(~node_proximal, axis=-1)
+        # if there is no non-NaN node, set value of 99
+        node_ind[node_proximal.all(axis=1)] = 99
+        node_ind = node_ind + 1  # adjust indices by adding one (base node)
+    else:
+        # For distal, we want the last non-NaN node in the last half of the root
+        # get the last half nan mask
+        node_distal = nan_mask[:, int(nan_mask.shape[1] / 2) :]
+        # get the farthest non-NaN node
+        node_ind = (node_distal[:, ::-1] == False).argmax(axis=1)
+        node_ind[node_distal.all(axis=1)] = -95  # set value if no non-NaN node
+        node_ind = pts.shape[1] - node_ind - 1  # adjust indices by reversing
+
+    # reset index to 0 (base node) if there is no non-NaN node
+    node_ind[node_ind == 100] = 0
+
+    # If pts was originally 2D, return a scalar instead of a single-element array
+    if pts.shape[0] == 1:
+        return node_ind[0]
+
+    # If only one root, return a scalar instead of a single-element array
+    if node_ind.shape[0] == 1:
+        return node_ind[0]
+
     return node_ind
 
 
-def get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray:
+def get_root_angle(
+    pts: np.ndarray, node_ind: np.ndarray, proximal: bool = True, base_ind=0
+) -> np.ndarray:
     """Find angles for each root.
 
     Args:
         pts: Numpy array of points of shape (instances, nodes, 2).
+        node_ind: Primary or lateral root node index.
         proximal: Boolean value, where true is proximal (default), false is distal.
         base_ind: Index of base node in the skeleton (default: 0).
 
     Returns:
         An array of shape (instances,) of angles in degrees, modulo 360.
""" - node_ind = get_node_ind(pts, proximal) # get proximal or distal node index + # if node_ind is a single int value, make it as array to keep consistent + if not isinstance(node_ind, np.ndarray): + node_ind = [node_ind] + + if np.isnan(node_ind).all(): + return np.nan + angs_root = [] for i in range(len(node_ind)): - # filter out the cases if all nan nodes in last/first half part - # to calculate proximal/distal angle - if (node_ind[i] < math.ceil(pts.shape[1] / 2) and proximal) or ( - node_ind[i] >= math.floor(pts.shape[1] / 2) and not (proximal) - ): + # if the node_ind is 0, do NOT calculate angs + if node_ind[i] == 0: + angs = np.nan + else: xy = pts[i, node_ind[i], :] - pts[i, base_ind, :] # center on base node # calculate the angle and convert to the start with gravity direction ang = np.arctan2(-xy[1], xy[0]) * 180 / np.pi angs = abs(ang + 90) if ang < 90 else abs(-(360 - 90 - ang)) - else: - angs = np.nan angs_root.append(angs) - return np.array(angs_root) + angs_root = np.array(angs_root) + + # If only one root, return a scalar instead of a single-element array + if angs_root.shape[0] == 1: + return angs_root[0] + return angs_root diff --git a/sleap_roots/bases.py b/sleap_roots/bases.py index a834b97..25d551a 100644 --- a/sleap_roots/bases.py +++ b/sleap_roots/bases.py @@ -1,110 +1,72 @@ """Trait calculations that rely on bases (i.e., dicot-only).""" import numpy as np -import shapely from shapely.geometry import LineString, Point from shapely.ops import nearest_points +from typing import Optional def get_bases(pts: np.ndarray, monocots: bool = False) -> np.ndarray: """Return bases (r1) from each lateral root. Args: - pts: Root landmarks as array of shape (instances, nodes, 2) + pts: Root landmarks as array of shape `(instances, nodes, 2)`. monocots: Boolean value, where false is dicot (default), true is rice. Returns: - Array of bases (instances, (x, y)). + Array of bases `(instances, (x, y))`. """ - # Get the first point of each instance. Shape is (instances, 2) if monocots: return np.nan else: - base_pts = pts[:, 0] + # Get the first point of each instance. + base_pts = pts[:, 0] # Shape is `(instances, 2)`. return base_pts -def get_root_lengths(pts: np.ndarray) -> np.ndarray: - """Return root lengths for all roots in a frame. - - Args: - pts: Root landmarks as array of shape (instances, nodes, 2). - - Returns: - Array of root lengths of shape (instances,). - If there is no root, or the roots is one point only (all of the rest of the - points are NaNs), an array of NaNs with shape (len(pts),) is returned. - This is also the case for non-contiguous points at the moment. - """ - # Get the (x,y) differences of segments for each instance. - segment_diffs = np.diff(pts, axis=1) - # Get the lengths of each segment by taking the norm. - segment_lengths = np.linalg.norm(segment_diffs, axis=-1) - # Add the segments together to get the total length using nansum. - total_lengths = np.nansum(segment_lengths, axis=-1) - # Find the NaN segment lengths and record NaN in place of 0 when finding the total - # length. - total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan - return total_lengths - - -def get_root_lengths_max(lengths: np.ndarray) -> np.ndarray: - """Return maximum root length for all roots in a frame. - - Args: - lengths: root lengths with shape of (instances,). - - Returns: - Scalar of the maximum root length. 
- """ - max_length = np.nanmax(lengths) - return max_length - - -def get_base_tip_dist(pts: np.ndarray) -> np.ndarray: +def get_base_tip_dist( + base_pts: np.ndarray = None, + tip_pts: np.ndarray = None, + pts: Optional[np.ndarray] = None, +) -> np.ndarray: """Return distance from root base to tip. Args: - pts: Root landmarks as array of shape (instances, nodes, 2) - - Returns: - Array of distances from base to tip of shape (instances,). - """ - base_pt = pts[:, 0] - tip_pt = pts[:, -1] - distance = np.linalg.norm(base_pt - tip_pt, axis=-1) - return distance - - -def get_grav_index(pts: np.ndarray): - """Get gravitropism index based on primary_length_max and primary_base_tip_dist. - - Args: - pts: primary root landmarks as array of shape (1, node, 2) + base_pts: Base coordinates of roots as an array of `(instances, 2)`. Not used if + `pts` is specified. + tip_pts: tips of roots `(instances, 2)`. Not used if `pts` is specified. + pts: Optional, Root landmarks as array of shape `(instances, nodes, 2)`. Returns: - Scalar of primary root gravity index. + Array of distances from base to tip of shape `(instances,)`. """ - # get primary root length, if predicted >1 primary roots, use the longest one - primary_length = get_root_lengths(pts) - primary_length_max = get_root_lengths_max(primary_length) - - # get the distance between base and tip in y axis - primary_base_tip_dist = get_base_tip_dist(pts) + if base_pts is not None and tip_pts is not None: + # If base_pts and tip_pts are provided, but they are NaN, return NaN array of + # same shape + if np.isnan(base_pts).all() or np.isnan(tip_pts).all(): + return np.full(base_pts.shape, np.nan) + # Calculate distance based on them + distance = np.linalg.norm(base_pts - tip_pts, axis=-1) + elif pts is not None: + # If pts is provided, but it is NaN, return NaN array of same shape + if np.isnan(pts).all(): + return np.full((pts.shape[0],), np.nan) + # Calculate distance based on it + base_pt = pts[:, 0] + tip_pt = pts[:, -1] + distance = np.linalg.norm(base_pt - tip_pt, axis=-1) + else: + # If neither base_pts and tip_pts nor pts is provided, raise an exception + raise ValueError("Either both base_pts and tip_pts, or pts must be provided.") - # calculate gravitropism index - pl_max = np.nanmax(primary_length_max) - if pl_max == 0: - return np.nan - grav_index = (pl_max - np.nanmax(primary_base_tip_dist)) / pl_max - return grav_index + return distance def get_lateral_count(pts: np.ndarray): """Get number of lateral roots. Args: - pts: lateral root landmarks as array of shape (instance, node, 2) + pts: lateral root landmarks as array of shape `(instance, node, 2)`. Return: Scalar of number of lateral roots. @@ -114,115 +76,240 @@ def get_lateral_count(pts: np.ndarray): def get_base_xs(pts: np.ndarray, monocots: bool = False) -> np.ndarray: - """Get x coordinations of base points. + """Get x coordinates of the base of each lateral root. Args: - pts: root landmarks as array of shape (instance, point, 2) + pts: root landmarks as array of shape `(instances, point, 2)` or bases + `(instances, 2)`. monocots: Boolean value, where false is dicot (default), true is rice. Return: - An array of bases in x axis (instance,). + An array of the x-coordinates of bases `(instance,)`. 
""" - _base_pts = get_bases(pts, monocots) + # If the input is a single number (float or integer), return np.nan + if isinstance(pts, (np.floating, float, np.integer, int)): + return np.nan + + # If the input array doesn't have 2 or 3 dimensions, raise an error + if pts.ndim not in (2, 3): + raise ValueError( + "Input array must be 2-dimensional (n_bases, 2) or " + "3-dimensional (n_roots, n_nodes, 2)." + ) + + # If the input array has 3 dimensions, calculate the base points, + # otherwise, assume the input array already contains the base points + if pts.ndim == 3: + _base_pts = get_bases( + pts, monocots + ) # Assuming get_bases returns an array of shape (instance, 2) + else: + _base_pts = pts + + # If _base_pts is a single number (float or integer), return np.nan if isinstance(_base_pts, (np.floating, float, np.integer, int)): return np.nan + + # If the base points array doesn't have exactly 2 dimensions or + # the second dimension is not of size 2, raise an error + if _base_pts.ndim != 2 or _base_pts.shape[1] != 2: + raise ValueError( + "Array of base points must be 2-dimensional with shape (instance, 2)." + ) + + # If everything is fine, extract and return the x-coordinates of the base points else: base_xs = _base_pts[:, 0] return base_xs def get_base_ys(pts: np.ndarray, monocots: bool = False) -> np.ndarray: - """Get y coordinations of base points. + """Get y coordinates of the base of each lateral root. Args: - pts: root landmarks as array of shape (instance, point, 2) + pts: root landmarks as array of shape `(instances, point, 2)` or bases + `(instances, 2)`. monocots: Boolean value, where false is dicot (default), true is rice. Return: - An array of bases in y axis (instance,). + An array of the y-coordinates of bases (instance,). """ - _base_pts = get_bases(pts, monocots) + # If the input is a single number (float or integer), return np.nan + if isinstance(pts, (np.floating, float, np.integer, int)): + return np.nan + + # If the input array doesn't have 2 or 3 dimensions, raise an error + if pts.ndim not in (2, 3): + raise ValueError( + "Input array must be 2-dimensional (n_bases, 2) or " + "3-dimensional (n_roots, n_nodes, 2)." + ) + + # If the input array has 3 dimensions, calculate the base points, + # otherwise, assume the input array already contains the base points + if pts.ndim == 3: + _base_pts = get_bases( + pts, monocots + ) # Assuming get_bases returns an array of shape (instance, 2) + else: + _base_pts = pts + + # If _base_pts is a single number (float or integer), return np.nan if isinstance(_base_pts, (np.floating, float, np.integer, int)): return np.nan + + # If the base points array doesn't have exactly 2 dimensions or + # the second dimension is not of size 2, raise an error + if _base_pts.ndim != 2 or _base_pts.shape[1] != 2: + raise ValueError( + "Array of base points must be 2-dimensional with shape (instance, 2)." + ) else: base_ys = _base_pts[:, 1] return base_ys def get_base_length(pts: np.ndarray, monocots: bool = False): - """Get lateral roots top and deepest bases distance in y axis. + """Get the y-axis difference from the top lateral base to the bottom lateral base. Args: - pts: lateral root landmarks as array of shape (instance, point, 2) + pts: root landmarks as array of shape `(instances, point, 2)` or base_ys + `(instances)`. monocots: Boolean value, where false is dicot (default), true is rice. Return: - Top and deepest bases distance y-axis. + The distance between the top base y-coordinate and the deepest + base y-coordinate. 
""" - base_ys = get_base_ys(pts, monocots) + # If the input is a single number (float or integer), return np.nan + if isinstance(pts, (np.floating, float, np.integer, int)): + return np.nan + + if pts.ndim not in (1, 3): + raise ValueError( + "Input array must be 1-dimensional (n_base_ys) or " + "3-dimensional (n_roots, n_nodes, 2)." + ) + + if pts.ndim == 3: + base_ys = get_base_ys( + pts, monocots + ) # Assuming get_base_ys returns an array of shape (instances) + else: + base_ys = pts + + # If base_ys is a single number (float or integer), return np.nan + if isinstance(base_ys, (np.floating, float, np.integer, int)): + return np.nan + + if base_ys.ndim != 1: + raise ValueError( + "Array of base y-coordinates must be 1-dimensional with shape (instances)." + ) + base_length = np.nanmax(base_ys) - np.nanmin(base_ys) return base_length -def get_base_ct_density(primary_pts, lateral_pts, monocots: bool = False): - """Get number of base points to maximum primary root length. +def get_base_ct_density(primary_pts, lateral_pts: np.ndarray, monocots: bool = False): + """Get a ratio of the number of base points to maximum primary root length. Args: - primary_pts: primary root points - lateral_pts: lateral root points + primary_pts: primary root points of shape `(instances, nodes, 2)` or scalar + maximum primary root length. + lateral_pts: lateral root points of shape `(instances, nodes, 2)` or base points + of lateral roots of shape (instances, 2). monocots: Boolean value, where false is dicot (default), true is rice. Return: Scalar of base count density. """ - # get number of base points of lateral roots - _base_pts = get_bases(lateral_pts, monocots) - if isinstance(_base_pts, (np.floating, float, np.integer, int)): + # If the inputs are single numbers (float or integer), return np.nan + if ( + isinstance(lateral_pts, (np.floating, float, np.integer, int)) + or np.isnan(primary_pts).all() + ): return np.nan + + # If the input array has 3 dimensions, calculate the base points, + # otherwise, assume the input array already contains the base points + if lateral_pts.ndim == 3: + _base_pts = get_bases( + lateral_pts, monocots + ) # Assuming get_bases returns an array of shape (instances, 2) else: - base_ct = len(_base_pts[~np.isnan(_base_pts[:, 0])]) - # get primary root length - lengths_primary = get_root_lengths(primary_pts) - base_ct_density = base_ct / np.nanmax(lengths_primary) - return base_ct_density + _base_pts = lateral_pts + # If _base_pts is a single number (float or integer), return np.nan + if isinstance(_base_pts, (np.floating, float, np.integer, int)): + return np.nan -def get_primary_depth(primary_pts): - """Get primary root tip depth. + # Get number of base points of lateral roots + base_ct = len(_base_pts[~np.isnan(_base_pts[:, 0])]) - Args: - primary_pts: primary root points. + # Assuming primary_pts is a scalar + primary_length_max = primary_pts - Return: - Scalar of primary root tip depth. 
- """ - primary_depth = np.nanmax(primary_pts[:, :, 1]) - return primary_depth + # Handle case where maximum primary length is zero to avoid division by zero + if primary_length_max == 0: + return np.nan + # Handle case where primary lengths are all NaN or empty + if np.isnan(primary_length_max): + return np.nan + # Get base_ct_density + base_ct_density = base_ct / primary_length_max + return base_ct_density -def get_base_length_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray): + +def get_base_length_ratio( + primary_pts: np.ndarray, lateral_pts: np.ndarray, monocots: bool = False +): """Get ratio of top-deep base length to primary root length. Args: - primary_pts: primary root points. - lateral_pts: lateral root points. + primary_pts: primary root points of shape `(instances, nodes, 2)` or scalar + maximum primary root length. + lateral_pts: lateral root points of shape `(instances, nodes, 2)`. + or scalar of base_length. + monocots: Boolean value, where false is dicot (default), true is rice. Return: Scalar of base length ratio. """ - base_length = get_base_length(lateral_pts) - primary_length = get_root_lengths(primary_pts) - primary_length_max = get_root_lengths_max(primary_length) - if primary_length_max == 0: + # If lateral_pts is a single number (float or integer) or primary_pts is NaN, return np.nan + if ( + isinstance(lateral_pts, (np.floating, float, np.integer, int)) + or np.isnan(primary_pts).all() + ): return np.nan + + # If the input array has 3 dimensions, calculate the base length, + # otherwise, assume the input array already contains the base length + if lateral_pts.ndim == 3: + base_length = get_base_length( + lateral_pts + ) # Assuming get_base_length returns an array of shape (instances) else: - base_length_ratio = base_length / primary_length_max - return base_length_ratio + base_length = lateral_pts # Assuming lateral_pts is a scalar + primary_length_max = ( + primary_pts # Assuming primary_pts is the maximum primary root length + ) -def get_base_median_ratio( - primary_pts: np.ndarray, lateral_pts: np.ndarray, monocots: bool = False -): + # Handle case where maximum primary length is zero to avoid division by zero + if primary_length_max == 0: + return np.nan + # Handle case where primary lengths are all NaN or empty + if np.isnan(primary_length_max): + return np.nan + + # Compute and return base length ratio + base_length_ratio = base_length / primary_length_max + return base_length_ratio + + +def get_base_median_ratio(lateral_base_ys, primary_tip_pt_y, monocots: bool = False): """Get ratio of median value in all base points to tip of primary root in y axis. Args: @@ -233,114 +320,84 @@ def get_base_median_ratio( Return: Scalar of base median ratio. """ - _base_pts = get_bases(lateral_pts, monocots) - pr_tip_depth = np.nanmax(primary_pts[:, :, 1]) - if np.isnan(_base_pts).all(): + if np.isnan(lateral_base_ys).all(): return np.nan else: - base_median_ratio = np.nanmedian(_base_pts[:, 1]) / pr_tip_depth + base_median_ratio = np.nanmedian(lateral_base_ys) / primary_tip_pt_y return base_median_ratio def get_root_pair_widths_projections( - lateral_pts, primary_pts, tolerance, monocots: bool = False -): - """Return estimation of stem width using bases of lateral roots. + primary_pts: np.ndarray, + lateral_pts: np.ndarray, + tolerance: float, + monocots: bool = False, +) -> float: + """Return estimation of root width using bases of lateral roots. Args: - lateral_pts: Lateral roots as arrays of shape (n, nodes, 2). 
- primary_pts: longest primary root as arrays of shape (n, nodes, 2). - tolerance: difference in projection norm between the right and left side (~0.02). - monocots: Boolean value, where false is dicot (default), true is rice. + primary_pts: Longest primary root as an array of shape (n, nodes, 2). + lateral_pts: Lateral roots as an array of shape (n, nodes, 2). + tolerance: Difference in projection norm between the right and left side (~0.02). + monocots: Boolean value, where False is dicot (default), True is rice. Returns: - A match_dists is the distance in pixels between the bases of matched - roots as a vector of size (n_matches,). + float: The distance in pixels between the bases of matched roots, or NaN + if no matches were found or all input points were NaN. + Raises: + ValueError: If the input arrays are of incorrect shape. """ - if monocots: + + if primary_pts.ndim != 3 or lateral_pts.ndim != 3: + raise ValueError("Input arrays should be 3-dimensional") + + if monocots or np.isnan(primary_pts).all() or np.isnan(lateral_pts).all(): return np.nan - else: - if np.isnan(primary_pts).all(): - return np.nan - else: - primary_pts_filtered = primary_pts[~np.isnan(primary_pts).any(axis=2)] - primary_line = LineString(primary_pts_filtered) - - # Make a line of the primary points - primary_line = LineString(primary_pts_filtered) - - # Filter by whether the base node is present. - has_base = ~np.isnan(lateral_pts[:, 0, 0]) - valid_inds = np.argwhere(has_base).squeeze() - lateral_pts = lateral_pts[has_base] - - # Find roots facing left based on whether the base x-coord - # is larger than the tip x-coord. - is_left = lateral_pts[:, 0, 0] > np.nanmin(lateral_pts[:, 1:, 0], axis=1) - - # Edge Case: Only found roots on one side. - if is_left.all() or (~is_left).all(): - return np.nan - - # Get left and right base points. - left_bases, right_bases = lateral_pts[is_left, 0], lateral_pts[~is_left, 0] - - # Find the nearest point to each right lateral base on the primary root line - nearest_primary_right = [ - nearest_points(primary_line, Point(right_base))[0] - for right_base in right_bases - ] - - # Find the nearest point to each left lateral base on the primary root line - nearest_primary_left = [ - nearest_points(primary_line, Point(left_base))[0] - for left_base in left_bases - ] - - # Returns the distance along the primary line of point in nearest_primary_right, normalized to the length of the object. - nearest_primary_norm_right = np.array( - [ - primary_line.project(pt, normalized=True) - for pt in nearest_primary_right - ] - ) - # Returns the distance along the primary line of point in nearest_primary_left, normalized to the length of the object. - nearest_primary_norm_left = np.array( - [ - primary_line.project(pt, normalized=True) - for pt in nearest_primary_left - ] - ) - - # get all possible differences in projections from all base pairs - projection_diffs = np.abs( - nearest_primary_norm_left.reshape(-1, 1) - - nearest_primary_norm_right.reshape(1, -1) - ) - - # shape is [# of valid base pairs, 2 [left right]] - indices = np.argwhere(projection_diffs <= tolerance) - - left_inds = indices[:, 0] - right_inds = indices[:, 1] - - # Find pairwise distances. (shape is (# of left bases, # of right bases)) - dists = np.linalg.norm( - np.expand_dims(left_bases, axis=1) - - np.expand_dims(right_bases, axis=0), - axis=-1, - ) - - # Pull out match distances. - match_dists = np.array([dists[l, r] for l, r in zip(left_inds, right_inds)]) - - # Convert matches to indices before splitting by side. 
- left_inds = np.argwhere(is_left).reshape(-1)[left_inds] - right_inds = np.argwhere(~is_left).reshape(-1)[right_inds] - - # Convert matches to indices before filtering. - left_inds = valid_inds[left_inds] - right_inds = valid_inds[right_inds] - - return match_dists + + primary_pts_filtered = primary_pts[~np.isnan(primary_pts).any(axis=2)] + primary_line = LineString(primary_pts_filtered) + + has_base = ~np.isnan(lateral_pts[:, 0, 0]) + valid_inds = np.argwhere(has_base).squeeze() + lateral_pts = lateral_pts[has_base] + + is_left = lateral_pts[:, 0, 0] > np.nanmin(lateral_pts[:, 1:, 0], axis=1) + + if is_left.all() or (~is_left).all(): + return np.nan + + left_bases, right_bases = lateral_pts[is_left, 0], lateral_pts[~is_left, 0] + + nearest_primary_right = [ + nearest_points(primary_line, Point(right_base))[0] for right_base in right_bases + ] + + nearest_primary_left = [ + nearest_points(primary_line, Point(left_base))[0] for left_base in left_bases + ] + + nearest_primary_norm_right = np.array( + [primary_line.project(pt, normalized=True) for pt in nearest_primary_right] + ) + + nearest_primary_norm_left = np.array( + [primary_line.project(pt, normalized=True) for pt in nearest_primary_left] + ) + + projection_diffs = np.abs( + nearest_primary_norm_left.reshape(-1, 1) + - nearest_primary_norm_right.reshape(1, -1) + ) + + indices = np.argwhere(projection_diffs <= tolerance) + + left_inds = indices[:, 0] + right_inds = indices[:, 1] + + match_dists = np.linalg.norm( + left_bases[left_inds] - right_bases[right_inds], + axis=-1, + ) + + return match_dists diff --git a/sleap_roots/graphpipeline.py b/sleap_roots/graphpipeline.py deleted file mode 100644 index a6e3200..0000000 --- a/sleap_roots/graphpipeline.py +++ /dev/null @@ -1,712 +0,0 @@ -"""Extract traits based on the networkx graph.""" - -import numpy as np -import pandas as pd -import os -from typing import List -from fractions import Fraction -from pathlib import Path -from sleap_roots.traitsgraph import get_traits_graph -from sleap_roots.angle import get_root_angle -from sleap_roots.bases import ( - get_bases, - get_base_ct_density, - get_base_length, - get_base_length_ratio, - get_base_median_ratio, - get_base_tip_dist, - get_base_xs, - get_base_ys, - get_grav_index, - get_lateral_count, - get_primary_depth, - get_root_lengths, - get_root_pair_widths_projections, -) -from sleap_roots.convhull import ( - get_chull_area, - get_chull_line_lengths, - get_chull_max_width, - get_chull_max_height, - get_chull_perimeter, - get_convhull_features, -) -from sleap_roots.ellipse import ( - fit_ellipse, - get_ellipse_a, - get_ellipse_b, - get_ellipse_ratio, -) -from sleap_roots.networklength import ( - get_bbox, - get_network_distribution_ratio, - get_network_distribution, - get_network_solidity, - get_network_width_depth_ratio, -) -from sleap_roots.points import get_all_pts_array, get_all_pts -from sleap_roots.scanline import ( - count_scanline_intersections, - get_scanline_first_ind, - get_scanline_last_ind, -) -from sleap_roots.series import Series, find_all_series -from sleap_roots.summary import get_summary -from sleap_roots.tips import get_tips, get_tip_xs, get_tip_ys -from typing import Dict, Tuple -import warnings - - -SCALAR_TRAITS = ( - "primary_angle_proximal", - "primary_angle_distal", - "primary_length", - "primary_base_tip_dist", - "primary_depth", - "lateral_count", - "grav_index", - "base_length", - "base_length_ratio", - "primary_tip_pt_y", - "base_median_ratio", - "base_ct_density", - "chull_perimeter", - "chull_area", - 
"chull_max_width", - "chull_max_height", - "ellipse_a", - "ellipse_b", - "ellipse_ratio", - "network_width_depth_ratio", - "network_solidity", - "network_length_lower", - "network_distribution_ratio", - "scanline_first_ind", - "scanline_last_ind", -) - -NON_SCALAR_TRAITS = ( - "lateral_angles_proximal", - "lateral_angles_distal", - "lateral_lengths", - "stem_widths", - "lateral_base_xs", - "lateral_base_ys", - "lateral_tip_xs", - "lateral_tip_ys", - "chull_line_lengths", - "scanline_intersection_counts", -) - - -warnings.filterwarnings( - "ignore", - message="invalid value encountered in intersection", - category=RuntimeWarning, - module="shapely", -) -warnings.filterwarnings( - "ignore", message="All-NaN slice encountered", category=RuntimeWarning -) -warnings.filterwarnings( - "ignore", message="All-NaN axis encountered", category=RuntimeWarning -) -warnings.filterwarnings( - "ignore", - message="Degrees of freedom <= 0 for slice.", - category=RuntimeWarning, - module="numpy", -) -warnings.filterwarnings( - "ignore", message="Mean of empty slice", category=RuntimeWarning -) -warnings.filterwarnings( - "ignore", - message="invalid value encountered in sqrt", - category=RuntimeWarning, - module="skimage", -) -warnings.filterwarnings( - "ignore", - message="invalid value encountered in double_scalars", - category=RuntimeWarning, -) -warnings.filterwarnings( - "ignore", - message="invalid value encountered in scalar divide", - category=RuntimeWarning, - module="ellipse", -) - - -def get_traits_value_frame( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - pts_all_array: np.ndarray, - pts_all_list: list, - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: float = 2 / 3, - monocots: bool = False, -) -> Dict: - """Get SLEAP traits per frame based on graph. - - Args: - primary_pts: primary points - lateral_pts: lateral points - pts_all_array: all points in array format - pts_all_list: all points in list format - stem_width_tolerance: difference in projection norm between right and left side. - n_line: number of scan lines, np.nan for no interaction. - network_fraction: length found in the lower fration value of the network. - monocots: Boolean value, where false is dicot (default), true is rice. - - Return: - A dictionary with all traits per frame. 
- """ - trait_map = { - # get_bases(pts: np.ndarray,monocots) -> np.ndarray - "primary_base_pt": (get_bases, [primary_pts, monocots]), - # get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray - "primary_angle_proximal": (get_root_angle, [primary_pts, True, 0]), - "primary_angle_distal": (get_root_angle, [primary_pts, False, 0]), - # get_root_lengths(pts: np.ndarray) -> np.ndarray - "primary_length": (get_root_lengths, [primary_pts]), - # get_tips(pts) - "primary_tip_pt": (get_tips, [primary_pts]), - # fit_ellipse(pts: np.ndarray) -> Tuple[float, float, float] - "ellipse": (fit_ellipse, [pts_all_array]), - # get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float] - "bounding_box": (get_bbox, [pts_all_array]), - # get_root_pair_widths_projections(lateral_pts, primary_pts, tolerance,monocots) - "stem_widths": ( - get_root_pair_widths_projections, - [lateral_pts, primary_pts, stem_width_tolerance, monocots], - ), - # get_convhull_features(pts: Union[np.ndarray, ConvexHull]) -> Tuple[float, float, float, float] - "convex_hull": (get_convhull_features, [pts_all_array]), - # get_lateral_count(pts: np.ndarray) - "lateral_count": (get_lateral_count, [lateral_pts]), - # # get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray - "lateral_angles_proximal": (get_root_angle, [lateral_pts, True, 0]), - "lateral_angles_distal": (get_root_angle, [lateral_pts, False, 0]), - # get_root_lengths(pts: np.ndarray) -> np.ndarray - "lateral_lengths": (get_root_lengths, [lateral_pts]), - # get_bases(pts: np.ndarray,monocots) -> np.ndarray - "lateral_base_pts": (get_bases, [lateral_pts, monocots]), - # get_tips(pts) - "lateral_tip_pts": (get_tips, [lateral_pts]), - # get_base_ys(pts: np.ndarray) -> np.ndarray - # or just based on primary_base_pt, but the primary_base_pt trait must generate before - # "primary_base_pt_y": (get_pt_ys, [data["primary_base_pt"]]), - "primary_base_pt_y": (get_base_ys, [primary_pts]), - # get_base_ct_density(primary_pts, lateral_pts) - "base_ct_density": (get_base_ct_density, [primary_pts, lateral_pts, monocots]), - # get_network_solidity(primary_pts: np.ndarray, lateral_pts: np.ndarray, pts_all_array: np.ndarray, monocots: bool = False,) -> float - "network_solidity": ( - get_network_solidity, - [primary_pts, lateral_pts, pts_all_array, monocots], - ), - # get_network_distribution_ratio(primary_pts: np.ndarray,lateral_pts: np.ndarray,pts_all_array: np.ndarray,fraction: float = 2 / 3, monocots: bool = False) -> float: - "network_distribution_ratio": ( - get_network_distribution_ratio, - [primary_pts, lateral_pts, pts_all_array, network_fraction, monocots], - ), - # get_network_distribution(primary_pts: np.ndarray,lateral_pts: np.ndarray,pts_all_array: np.ndarray,fraction: float = 2 / 3, monocots: bool = False) -> float: - "network_length_lower": ( - get_network_distribution, - [primary_pts, lateral_pts, pts_all_array, network_fraction, monocots], - ), - # get_tip_ys(pts: np.ndarray) -> np.ndarray - "primary_tip_pt_y": (get_tip_ys, [primary_pts]), - # get_ellipse_a(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) - "ellipse_a": (get_ellipse_a, [pts_all_array]), - # get_ellipse_b(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) - "ellipse_b": (get_ellipse_b, [pts_all_array]), - # get_network_width_depth_ratio(pts: np.ndarray) -> float - "network_width_depth_ratio": (get_network_width_depth_ratio, [pts_all_array]), - # get_chull_perimeter(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - 
"chull_perimeter": (get_chull_perimeter, [pts_all_array]), - # get_chull_area(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_area": (get_chull_area, [pts_all_array]), - # get_chull_max_width(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_max_width": (get_chull_max_width, [pts_all_array]), - # get_chull_max_height(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) - "chull_max_height": (get_chull_max_height, [pts_all_array]), - # get_chull_line_lengths(pts: Union[np.ndarray, ConvexHull]) -> np.ndarray - "chull_line_lengths": (get_chull_line_lengths, [pts_all_array]), - # count_scanline_intersections(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080,width: int = 2048,n_line: int = 50,monocots: bool = False,) -> np.ndarray - "scanline_intersection_counts": ( - count_scanline_intersections, - [primary_pts, lateral_pts, 1080, 2048, 50, monocots], - ), - # get_base_xs(pts: np.ndarray) -> np.ndarray - "lateral_base_xs": (get_base_xs, [lateral_pts, monocots]), - # get_base_ys(pts: np.ndarray) -> np.ndarray - "lateral_base_ys": (get_base_ys, [lateral_pts, monocots]), - # get_tip_xs(pts: np.ndarray) -> np.ndarray - "lateral_tip_xs": (get_tip_xs, [lateral_pts]), - # get_tip_ys(pts: np.ndarray) -> np.ndarray - "lateral_tip_ys": (get_tip_ys, [lateral_pts]), - # get_base_tip_dist(pts: np.ndarray) -> np.ndarray - "primary_base_tip_dist": (get_base_tip_dist, [primary_pts]), - # get_primary_depth(primary_pts) - "primary_depth": (get_primary_depth, [primary_pts]), - # get_base_median_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray) - "base_median_ratio": ( - get_base_median_ratio, - [primary_pts, lateral_pts, monocots], - ), - # get_ellipse_ratio(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) - "ellipse_ratio": (get_ellipse_ratio, [pts_all_array]), - # get_scanline_last_ind(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False) - "scanline_last_ind": ( - get_scanline_last_ind, - [primary_pts, lateral_pts, 1080, 2048, n_line, monocots], - ), - # get_scanline_first_ind(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False) - "scanline_first_ind": ( - get_scanline_first_ind, - [primary_pts, lateral_pts, 1080, 2048, n_line, monocots], - ), - # get_base_length(pts: np.ndarray) - "base_length": (get_base_length, [lateral_pts, monocots]), - # get_grav_index(pts: np.ndarray) - "grav_index": (get_grav_index, [primary_pts]), - # get_base_length_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray) - "base_length_ratio": (get_base_length_ratio, [primary_pts, lateral_pts]), - } - - dts = get_traits_graph() - - data = {} - for trait_name in dts: - fn, inputs = trait_map[trait_name] - fn_outputs = fn(*[input_trait for input_trait in inputs]) - if type(fn_outputs) == tuple: - fn_outputs = np.array(fn_outputs).reshape((1, -1)) - if isinstance(fn_outputs, (np.floating, float)) or isinstance( - fn_outputs, (np.integer, int) - ): - fn_outputs = np.array(fn_outputs)[np.newaxis] - data[trait_name] = fn_outputs - return data - - -def get_traits_value_plant( - h5, - monocots: bool = False, - primary_name: str = "primary_multi_day", - lateral_name: str = "lateral_3_nodes", - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: float = 2 / 3, - write_csv: bool = False, - csv_suffix: str = ".traits.csv", -) -> Tuple[Dict, pd.DataFrame, str]: - 
"""Get detailed SLEAP traits for every frame of a plant, based on the graph. - - Args: - h5: The h5 file representing the plant image series. - monocots: A boolean value indicating whether the plant is a monocot (True) - or a dicot (False) (default). - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. - stem_width_tolerance: The difference in the projection norm between - the right and left side of the stem. - n_line: The number of scan lines. Use np.nan for no interaction. - network_fraction: The length found in the lower fraction value of the network. - write_csv: A boolean value. If True, it writes per plant detailed - CSVs with traits for every instance on every frame. - csv_suffix: If write_csv=True, the CSV file will be saved with the - h5 path + csv_suffix. - - Returns: - A tuple containing a dictionary and a DataFrame with all traits per plant, - and the plant name. The Dataframe has root traits per instance and frame - where each row corresponds to a frame in the H5 file. The plant_name is - given by the h5 file. - """ - plant = Series.load(h5, primary_name=primary_name, lateral_name=lateral_name) - plant_name = plant.series_name - # get number of frames per plant - n_frame = len(plant) - - data_plant = [] - # get traits for each frames in a row - for frame in range(n_frame): - primary, lateral = plant[frame] - - gt_instances_pr = primary.user_instances + primary.unused_predictions - gt_instances_lr = lateral.user_instances + lateral.unused_predictions - - if len(gt_instances_lr) == 0: - lateral_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - lateral_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) - - if len(gt_instances_pr) == 0: - primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) - - pts_all_array = get_all_pts_array(plant=plant, frame=frame, monocots=False) - if len(pts_all_array) == 0: - pts_all_array = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - pts_all_list = [] - - if get_root_lengths(primary_pts).shape[0] > 0 and not len(gt_instances_pr) == 0: - max_length_idx = np.nanargmax(get_root_lengths(primary_pts)) - long_primary_pts = primary_pts[max_length_idx] - primary_pts = np.reshape( - long_primary_pts, - (1, long_primary_pts.shape[0], long_primary_pts.shape[1]), - ) - else: - # if no primary root, just give two nan points - primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - - data = get_traits_value_frame( - primary_pts, - lateral_pts, - pts_all_array, - pts_all_list, - stem_width_tolerance, - n_line, - network_fraction, - monocots, - ) - - data["plant_name"] = plant_name - data["frame_idx"] = frame - data_plant.append(data) - data_plant_df = pd.DataFrame(data_plant) - - # reorganize the column position - column_names = data_plant_df.columns.tolist() - column_names = [column_names[-2]] + [column_names[-1]] + column_names[:-2] - data_plant_df = data_plant_df[column_names] - - # convert the data in scalar column to the value without [] - columns_to_convert = data_plant_df.columns[ - data_plant_df.apply( - lambda x: all( - isinstance(val, np.ndarray) and val.shape == (1,) for val in x - ) - ) - ] - data_plant_df[columns_to_convert] = data_plant_df[columns_to_convert].apply( - lambda x: 
x.apply(lambda val: val[0]) - ) - - if write_csv: - csv_name = Path(h5).with_suffix(f"{csv_suffix}") - data_plant_df.to_csv(csv_name, index=False) - return data_plant, data_plant_df, plant_name - - -def get_traits_value_plant_summary( - h5, - monocots: bool = False, - primary_name: str = "longest_3do_6nodes", - lateral_name: str = "main_3do_6nodes", - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: float = 2 / 3, - write_csv: bool = False, - csv_suffix: str = ".traits.csv", - write_summary_csv: bool = False, - summary_csv_suffix: str = ".summary_traits.csv", -) -> pd.DataFrame: - """Get summary statistics of SLEAP traits per plant based on graph. - - Args: - h5: The h5 file representing the plant image series. - monocots: A boolean value indicating whether the plant is a monocot (True) - or a dicot (False) (default). - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. - stem_width_tolerance: The difference in the projection norm between - the right and left side of the stem. - n_line: The number of scan lines. Use np.nan for no interaction. - network_fraction: The length found in the lower fraction value of the network. - write_csv: A boolean value. If True, it writes per plant detailed - CSVs with traits for every instance on every frame. - csv_suffix: If write_csv=True, the CSV file will be saved with the name - h5 path + csv_suffix. - write_summary_csv: Boolean value, where true is write summarized csv file. - summary_csv_suffix: If write_summary_csv=True, the CSV file with the summary - statistics per plant will be saved with the name - h5 path + summary_csv_suffix. - - Return: - A DataFrame with summary statistics of all traits per plant. 
- """ - data_plant, data_plant_df, plant_name = get_traits_value_plant( - h5, - monocots, - primary_name, - lateral_name, - stem_width_tolerance, - n_line, - network_fraction, - write_csv, - csv_suffix, - ) - - # get summarized non-scalar traits per frame - data_plant_frame_summary = [] - data_plant_frame_summary_non_scalar = {} - - for i in range(len(NON_SCALAR_TRAITS)): - trait = data_plant_df[NON_SCALAR_TRAITS[i]] - - if not trait.isna().all(): - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmin" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmin(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmax" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmax(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmean" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmean(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmedian" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanmedian(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fstd" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nanstd(x) if len(x) > 0 else np.nan) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc5" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else (np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 5)) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc25" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else ( - np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 25) - ) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc75" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else ( - np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 75) - ) - ) - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc95" - ] = trait.apply( - lambda x: x - if isinstance(x, (np.floating, float, np.integer, int)) - else ( - np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 95) - ) - ) - else: - data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fmin"] = np.nan - data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fmax"] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmean" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fmedian" - ] = np.nan - data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fstd"] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc5" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc25" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc75" - ] = np.nan - data_plant_frame_summary_non_scalar[ - NON_SCALAR_TRAITS[i] + "_fprc95" - ] = np.nan - - # get summarized scalar traits per plant - column_names = data_plant_df.columns.tolist() - data_plant_frame_summary = {} - for i in range(len(SCALAR_TRAITS)): - if SCALAR_TRAITS[i] in column_names: - trait = data_plant_df[SCALAR_TRAITS[i]] - if 
trait.shape[0] > 0: - if not ( - isinstance(trait[0], (np.floating, float)) - or isinstance(trait[0], (np.integer, int)) - ): - values = np.array([element[0] for element in trait]) - trait = values - trait = trait.astype(float) - trait = np.reshape(trait, (len(trait), 1)) - ( - trait_min, - trait_max, - trait_mean, - trait_median, - trait_std, - trait_prc5, - trait_prc25, - trait_prc75, - trait_prc95, - ) = get_summary(trait) - - data_plant_frame_summary[SCALAR_TRAITS[i] + "_min"] = trait_min - data_plant_frame_summary[SCALAR_TRAITS[i] + "_max"] = trait_max - data_plant_frame_summary[SCALAR_TRAITS[i] + "_mean"] = trait_mean - data_plant_frame_summary[SCALAR_TRAITS[i] + "_median"] = trait_median - data_plant_frame_summary[SCALAR_TRAITS[i] + "_std"] = trait_std - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc5"] = trait_prc5 - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc25"] = trait_prc25 - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc75"] = trait_prc75 - data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc95"] = trait_prc95 - - # append the summarized non-scalar traits per plant - data_plant_frame_summary_key = list(data_plant_frame_summary_non_scalar.keys()) - for j in range(len(data_plant_frame_summary_non_scalar)): - trait = data_plant_frame_summary_non_scalar[data_plant_frame_summary_key[j]] - ( - trait_min, - trait_max, - trait_mean, - trait_median, - trait_std, - trait_prc5, - trait_prc25, - trait_prc75, - trait_prc95, - ) = get_summary(trait) - - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_min"] = trait_min - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_max"] = trait_max - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_mean"] = trait_mean - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_median" - ] = trait_median - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_std"] = trait_std - data_plant_frame_summary[data_plant_frame_summary_key[j] + "_prc5"] = trait_prc5 - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_prc25" - ] = trait_prc25 - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_prc75" - ] = trait_prc75 - data_plant_frame_summary[ - data_plant_frame_summary_key[j] + "_prc95" - ] = trait_prc95 - data_plant_frame_summary["plant_name"] = [plant_name] - data_plant_frame_summary_df = pd.DataFrame(data_plant_frame_summary) - - # reorganize the column position - column_names = data_plant_frame_summary_df.columns.tolist() - column_names = [column_names[-1]] + column_names[:-1] - data_plant_frame_summary_df = data_plant_frame_summary_df[column_names] - - if write_summary_csv: - summary_csv_name = Path(h5).with_suffix(f"{summary_csv_suffix}") - data_plant_frame_summary_df.to_csv(summary_csv_name, index=False) - return data_plant_frame_summary_df - - -def get_all_plants_traits( - data_folders: List[str], - primary_name: str, - lateral_name: str, - stem_width_tolerance: float = 0.02, - n_line: int = 50, - network_fraction: Fraction = Fraction(2, 3), - write_per_plant_details: bool = False, - per_plant_details_csv_suffix: str = ".traits.csv", - write_per_plant_summary: bool = False, - per_plant_summary_csv_suffix: str = ".summary_traits.csv", - monocots: bool = False, - all_plants_csv_name: str = "all_plants_traits.csv", -) -> pd.DataFrame: - """Get a DataFrame with summary traits from all plants in the given data folders. - - Args: - h5: The h5 file representing the plant image series. 
- monocots: A boolean value indicating whether the plant is a monocot (True) - or a dicot (False) (default). - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. - stem_width_tolerance: The difference in the projection norm between - the right and left side of the stem. - n_line: The number of scan lines. Use np.nan for no interaction. - network_fraction: The length found in the lower fraction value of the network. - write_per_plant_details: A boolean value. If True, it writes per plant detailed - CSVs with traits for every instance. - per_plant_details_csv_suffix: If write_csv=True, the CSV file will be saved - with the name h5 path + csv_suffix. - write_per_plant_summary: A boolean value. If True, it writes per plant summary - CSVs. - per_plant_summary_csv_suffix: If write_summary_csv=True, the CSV file with the - summary statistics per plant will be saved with the name - h5 path + summary_csv_suffix. - all_plants_csv_name: The name of the output CSV file containing all plants' - summary traits. - - Returns: - A pandas DataFrame with summary root traits for all plants in the data folders. - Each row is a sample. - """ - h5_series = find_all_series(data_folders) - - all_traits = [] - for h5 in h5_series: - plant_traits = get_traits_value_plant_summary( - h5, - monocots=monocots, - primary_name=primary_name, - lateral_name=lateral_name, - stem_width_tolerance=stem_width_tolerance, - n_line=n_line, - network_fraction=network_fraction, - write_csv=write_per_plant_details, - csv_suffix=per_plant_details_csv_suffix, - write_summary_csv=write_per_plant_summary, - summary_csv_suffix=per_plant_summary_csv_suffix, - ) - plant_traits["path"] = h5 - all_traits.append(plant_traits) - - all_traits_df = pd.concat(all_traits, ignore_index=True) - - all_traits_df.to_csv(all_plants_csv_name, index=False) - return all_traits_df diff --git a/sleap_roots/lengths.py b/sleap_roots/lengths.py new file mode 100644 index 0000000..a5f2f91 --- /dev/null +++ b/sleap_roots/lengths.py @@ -0,0 +1,152 @@ +"""Get length-related traits""" +import numpy as np +from sleap_roots.bases import get_base_tip_dist +from typing import Optional + + +def get_max_length_pts(pts: np.ndarray) -> np.ndarray: + """ + Points of the root with maximum length (intended for primary root traits). + + Args: + pts (np.ndarray): Root landmarks as array of shape `(instances, nodes, 2)`. + + Returns: + np.ndarray: Array of points with shape `(nodes, 2)` from the root with maximum + length. 
+ """ + # Return NaN points if the input array is empty + if len(pts) == 0: + return np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + + # Check if pts has the correct shape, raise error if it does not + if pts.ndim != 3 or pts.shape[2] != 2: + raise ValueError("Input array should have shape (instances, nodes, 2)") + + # Calculate the differences between consecutive points in each root + segment_diffs = np.diff(pts, axis=1) + + # Calculate the length of each segment (the Euclidean distance between consecutive + # points) + segment_lengths = np.linalg.norm(segment_diffs, axis=-1) + + # Sum the lengths of the segments for each root + total_lengths = np.nansum(segment_lengths, axis=-1) + + # Handle roots where all segment lengths are NaN, recording NaN in place of the + # total length for these roots + total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan + + # Return NaN points if all total lengths are NaN + if np.isnan(total_lengths).all(): + return np.array([[np.nan, np.nan]]) + + # Find the index of the root with the maximum total length + max_length_idx = np.nanargmax(total_lengths) + + # Return the points of the root with this index + return pts[max_length_idx] + + +def get_root_lengths(pts: np.ndarray) -> np.ndarray: + """Return root lengths for all roots in a frame. + + Args: + pts: Root landmarks as array of shape `(instances, nodes, 2)`. + + Returns: + Array of root lengths of shape `(instances,)`. If there is no root, or the root + is one point only (all of the rest of the points are NaNs), an array of NaNs + with shape (len(pts),) is returned. This is also the case for non-contiguous + points. + """ + # Get the (x,y) differences of segments for each instance. + segment_diffs = np.diff(pts, axis=1) + # Get the lengths of each segment by taking the norm. + segment_lengths = np.linalg.norm(segment_diffs, axis=-1) + # Add the segments together to get the total length using nansum. + total_lengths = np.nansum(segment_lengths, axis=-1) + # Find the NaN segment lengths and record NaN in place of 0 when finding the total + # length. + total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan + return total_lengths + + +def get_root_lengths_max(pts: np.ndarray) -> np.ndarray: + """Return maximum root length for all roots in a frame. + + Args: + pts: root landmarks as array of shape `(instance, nodes, 2)` or lengths + `(instances)`. + + Returns: + Scalar of the maximum root length. + """ + # If the pts are NaNs, return NaN + if np.isnan(pts).all(): + return np.nan + + if pts.ndim not in (1, 3): + raise ValueError( + "Input array must be 1-dimensional (n_lengths) or " + "3-dimensional (n_roots, n_nodes, 2)." + ) + + # If the input array has 3 dimensions, calculate the root lengths, + # otherwise, assume the input array already contains the root lengths + if pts.ndim == 3: + root_lengths = get_root_lengths( + pts + ) # Assuming get_root_lengths returns an array of shape (instances) + max_length = np.nanmax(root_lengths) + else: + max_length = np.nanmax(pts) + + return max_length + + +def get_grav_index( + primary_length: float = None, + primary_base_tip_dist: float = None, + pts: Optional[np.ndarray] = None, +): + """Get gravitropism index based on primary_length_max and primary_base_tip_dist. + + Args: + primary_base_tip_dist: scalar of distance from base to tip of primary root + (longest primary root prediction used if there is more than one). Not used + if `pts` is specified. 
+        primary_length: scalar of length of primary root (longest primary root
+            prediction used if there is more than one). Not used if `pts` is specified.
+        pts: Optional, primary root landmarks as array of shape `(instances, nodes, 2)`.
+
+    Returns:
+        Scalar of primary root gravity index.
+    """
+    if primary_length is not None and primary_base_tip_dist is not None:
+        # If primary_length and primary_base_tip_dist are provided, use them
+        if np.isnan(primary_length) or np.isnan(primary_base_tip_dist):
+            return np.nan
+        pl_max = primary_length
+        primary_base_tip_dist_max = primary_base_tip_dist
+    elif pts is not None:
+        # If pts is provided, calculate lengths and base-tip distances based on it
+        if np.isnan(pts).all():
+            return np.nan
+        primary_length_max = get_root_lengths_max(pts=pts)
+        primary_base_tip_dist = get_base_tip_dist(pts=pts)
+        pl_max = np.nanmax(primary_length_max)
+        primary_base_tip_dist_max = np.nanmax(primary_base_tip_dist)
+    else:
+        # If neither primary_length and primary_base_tip_dist nor pts is provided, raise an exception
+        raise ValueError(
+            "Either both primary_length and primary_base_tip_dist, or pts must be provided."
+        )
+    # Check if pl_max or primary_base_tip_dist_max is NaN, if so, return NaN
+    if np.isnan(pl_max) or np.isnan(primary_base_tip_dist_max):
+        return np.nan
+    # calculate gravitropism index
+    if pl_max == 0:
+        return np.nan
+    grav_index = (pl_max - primary_base_tip_dist_max) / pl_max
+    return grav_index
diff --git a/sleap_roots/networklength.py b/sleap_roots/networklength.py
index 1bda1dc..6b37a31 100644
--- a/sleap_roots/networklength.py
+++ b/sleap_roots/networklength.py
@@ -2,9 +2,9 @@
 
 import numpy as np
 from shapely import LineString, Polygon
-from sleap_roots.bases import get_root_lengths
+from sleap_roots.lengths import get_root_lengths
 from sleap_roots.convhull import get_convhull_features
-from typing import Tuple
+from typing import Optional, Tuple, Union
 
 
 def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]:
@@ -36,17 +36,23 @@ def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]:
     return bbox
 
 
-def get_network_width_depth_ratio(pts: np.ndarray) -> float:
+def get_network_width_depth_ratio(
+    pts: Union[np.ndarray, Tuple[float, float, float, float]]
+) -> float:
     """Return width to depth ratio of bounding box for root network.
 
     Args:
-        pts: Root landmarks as array of shape (..., 2).
+        pts: Root landmarks as array of shape (..., 2), or the bounding box
+            (left_x, top_y, width, height) of all root landmarks.
 
     Returns:
         Float of bounding box width to depth ratio of root network.
     """
     # get the bounding box
-    bbox = get_bbox(pts)
+    if type(pts) == tuple:
+        bbox = pts
+    else:
+        bbox = get_bbox(pts)
     width, height = bbox[2], bbox[3]
     if width > 0 and height > 0:
         ratio = width / height
@@ -55,32 +61,84 @@ def get_network_width_depth_ratio(
         return np.nan
 
 
-def get_network_solidity(
-    primary_pts: np.ndarray,
-    lateral_pts: np.ndarray,
-    pts_all_array: np.ndarray,
+def get_network_length(
+    primary_length: Union[float, np.ndarray],
+    lateral_lengths: Union[float, np.ndarray],
     monocots: bool = False,
 ) -> float:
-    """Return the total network length divided by the network convex area.
+    """Return the total root network length for one frame.
 
     Args:
-        primary_pts: primary root landmarks as array of shape (..., 2).
-        lateral_pts: lateral root landmarks as array of shape (..., 2).
-        pts_all_array: primary and lateral root landmarks.
+        primary_length: primary root length, or landmarks of the longest primary root
+            as an array of shape `(nodes, 2)`.
+        lateral_lengths: lateral root lengths, or lateral root landmarks as an array
+            of shape `(instances, nodes, 2)`.
         monocots: a boolean value, where True is rice.
 
     Returns:
-        Float of the total network length divided by the network convex area.
+        Float of the total root network length.
     """
-    # get the total network length
-    network_length = get_network_length(primary_pts, lateral_pts, monocots)
+    # check whether primary_length is a precomputed length or primary root landmarks
+    if not (
+        isinstance(primary_length, (float, np.float64)) or primary_length.ndim == 2
+    ):
+        raise ValueError(
+            "Input primary_length should be the maximum primary root "
+            "length or an array of shape (nodes, 2)."
+        )
+    # get primary_root_length
+    primary_root_length = (
+        primary_length
+        if not isinstance(primary_length, np.ndarray)
+        else get_root_lengths(primary_length)
+    )
+
+    # check whether lateral_lengths holds lengths or lateral root nodes
+    if not (
+        isinstance(lateral_lengths, (float, np.float64))  # length of only one root
+        or lateral_lengths.ndim == 1  # lengths of more than one lateral root
+        or lateral_lengths.ndim == 3  # lateral root nodes
+    ):
+        raise ValueError(
+            "Input lateral_lengths should be the lateral root lengths or an array of "
+            "shape (instances, nodes, 2)."
+        )
+
+    # get lateral_root_length
+    if isinstance(lateral_lengths, (float, np.float64)):  # length of only one root
+        lateral_root_length = lateral_lengths
+    elif lateral_lengths.ndim == 3:  # lateral root nodes
+        lateral_root_length = np.sum(get_root_lengths(lateral_lengths))
+    else:  # lengths of more than one lateral root
+        lateral_root_length = np.sum(lateral_lengths)
+
+    # return NaN if the total length is negative
+    if primary_root_length + lateral_root_length < 0:
+        return np.nan
+
+    if monocots:
+        length = lateral_root_length
+    else:
+        length = primary_root_length + lateral_root_length
+
+    return length
+
+
+def get_network_solidity(
+    network_length: float,
+    chull_area: float,
+) -> float:
+    """Return the total network length divided by the network convex area.
 
-    # get the convex hull area
-    convhull_features = get_convhull_features(pts_all_array)
-    conv_area = convhull_features[1]
+    Args:
+        network_length: total root network length.
+        chull_area: convex hull area of the root network.
 
-    if network_length > 0 and conv_area > 0:
-        ratio = network_length / conv_area
+    Returns:
+        Float of the total network length divided by the network convex area.
+    """
+    if network_length > 0 and chull_area > 0:
+        ratio = network_length / chull_area
         return ratio
     else:
         return np.nan
@@ -89,7 +147,7 @@ def get_network_solidity(
 def get_network_distribution(
     primary_pts: np.ndarray,
     lateral_pts: np.ndarray,
-    pts_all_array: np.ndarray,
+    pts_all_array: Union[np.ndarray, Tuple[float, float, float, float]],
     fraction: float = 2 / 3,
     monocots: bool = False,
 ) -> float:
@@ -98,7 +156,7 @@ def get_network_distribution(
     Args:
         primary_pts: primary root landmarks as array of shape (..., 2).
         lateral_pts: lateral root landmarks as array of shape (..., 2).
-        pts_all_array: primary and lateral root landmarks.
+        pts_all_array: primary and lateral root landmarks or the bounding box.
         fraction: the network length found in the lower fration value of the network.
         monocots: a boolean value, where True is rice.
 
@@ -106,7 +164,10 @@ def get_network_distribution(
         Float of the root network length in the lower fraction of the plant.
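Given the input handling above, the two refactored helpers compose like this; a small sketch with made-up lengths and hull area (scalar primary length plus a 1D array of lateral lengths):

```
import numpy as np
from sleap_roots.networklength import get_network_length, get_network_solidity

# Scalar primary length plus an array of lateral lengths (dicot case).
network_length = get_network_length(120.0, np.array([30.0, 50.0]))  # 200.0
print(get_network_solidity(network_length, chull_area=1000.0))      # 0.2
```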
""" # get the bounding box - bbox = get_bbox(pts_all_array) + if type(pts_all_array) == tuple: + bbox = pts_all_array + else: + bbox = get_bbox(pts_all_array) left_x, top_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3] # get the bounding box of the lower fraction @@ -115,7 +176,7 @@ def get_network_distribution( return np.nan lower_bbox = (bbox[0], bbox[1] + (bbox[3] - lower_height), bbox[2], lower_height) - # convert bounding box to polygon + # convert lower bounding box to polygon polygon = Polygon( [ [bbox[0], bbox[1] + (bbox[3] - lower_height)], @@ -127,13 +188,9 @@ def get_network_distribution( # filter out the nan nodes if monocots: - points = list(primary_pts) + points = list(lateral_pts) else: points = list(primary_pts) + list(lateral_pts) - pts_nnan = [] - for j in range(len(points)): - pts_j = points[j][~np.isnan(points[j]).any(axis=1)] - pts_nnan.append(pts_j) # get length of lines within the lower bounding box root_length = 0 @@ -151,49 +208,22 @@ def get_network_distribution( return root_length -def get_network_length( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - monocots: bool = False, -) -> float: - """Return all primary or lateral root length one frame. - - Args: - primary_pts: primary root landmarks as array of shape (..., 2). - lateral_pts: lateral root landmarks as array of shape (..., 2). - monocots: a boolean value, where True is rice. - - Returns: - Float of primary or lateral root network length. - """ - if ( - np.sum(get_root_lengths(primary_pts)) > 0 - or np.sum(get_root_lengths(lateral_pts)) > 0 - ): - if monocots: - length = np.nansum(get_root_lengths(primary_pts)) - else: - length = np.nansum(get_root_lengths(primary_pts)) + np.nansum( - get_root_lengths(lateral_pts) - ) - return length - else: - return 0 - - def get_network_distribution_ratio( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - pts_all_array: np.ndarray, + primary_length: Union[float, np.ndarray], + lateral_lengths: Union[float, np.ndarray], + network_length_lower: Union[float, np.ndarray], fraction: float = 2 / 3, monocots: bool = False, ) -> float: """Return ratio of the root length in the lower fraction over all root length. Args: - primary_pts: primary root landmarks as array of shape (..., 2). - lateral_pts: lateral root landmarks as array of shape (..., 2). - pts_all_array: primary and lateral root landmarks. + primary_length: primary root length or maximum length primary root landmarks as + array of shape `(node, 2)`. + lateral_lengths: lateral root length or lateral root landmarks as array of shape + `(instance, node, 2)`. + network_length_lower: the root length in lower network or primary and lateral + root landmarks. fraction: the network length found in the lower fration value of the network. monocots: a boolean value, where True is rice. @@ -201,21 +231,64 @@ def get_network_distribution_ratio( Float of ratio of the root network length in the lower fraction of the plant over all root length. 
""" - if ( - np.sum(get_root_lengths(primary_pts)) + np.sum(get_root_lengths(lateral_pts)) - > 0 + # check whether primary_length is the maximum length or maximum primary root + if not ( + isinstance(primary_length, (float, np.float64)) or primary_length.ndim != 2 ): - if monocots: - ratio = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction, monocots - ) / (np.sum(get_root_lengths(primary_pts))) - else: - ratio = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction - ) / ( - np.sum(get_root_lengths(primary_pts)) - + np.sum(get_root_lengths(lateral_pts)) - ) - return ratio + raise ValueError( + "Input primary_length should be the maximum primary root " + "length or array have shape (nodes, 2)." + ) + # get primary_root_length + primary_root_length = ( + primary_length + if not isinstance(primary_length, np.ndarray) + else get_root_lengths(primary_length) + ) + + # check whether lateral_lengths is the lengths or lateral root nodex. + if not ( + isinstance(lateral_lengths, (float, np.float64)) # length with only one root + or lateral_lengths.ndim != 1 # lenthgs with more than one lateral roots + or lateral_lengths.ndim != 3 # lateral root nodes + ): + raise ValueError( + "Input lateral_lengths should be the lateral root lengths or array have " + "shape (instance, nodes, 2)." + ) + + # get lateral_root_length + if lateral_lengths.ndim != 3: # lateral root nodes + lateral_root_length = np.sum(get_root_lengths(lateral_lengths)) + elif lateral_lengths.ndim != 1: # lenthgs with more than one lateral roots + lateral_root_length = np.sum(lateral_lengths) + else: # length with only one lateral root + lateral_root_length = lateral_lengths + + # get network_length_lower + if isinstance(network_length_lower, (float, np.float64)): + network_length_lower = network_length_lower + elif ( + primary_length.ndim == 2 + and lateral_lengths.ndim == 3 + and network_length_lower.ndim == 3 + ): + network_length_lower = get_network_distribution( + primary_length, lateral_lengths, network_length_lower, fraction, monocots + ) else: + raise ValueError( + "Input network_length_lower should be a float value, otherwise " + "primary_length is maximimum length primary root in shape `(node, 2)` and " + "primary_length is lateral root in shape `(instance, nodes, 2)`." + ) + + # return Nan if lengths less than 0 + if primary_root_length + lateral_root_length < 0: return np.nan + + if monocots: + ratio = network_length_lower / primary_root_length + else: + ratio = network_length_lower / (primary_root_length + lateral_root_length) + return ratio diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 1beccc5..94d1d07 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -1,115 +1,51 @@ -"""Get points function.""" +"""Get traits related to the points.""" import numpy as np -from sleap_roots.bases import get_root_lengths -from sleap_roots.series import Series -from typing import List -def get_pt_ind(pts: np.ndarray, proximal: bool = True) -> np.ndarray: - """Find proximal/distal point index. +def get_all_pts_array( + primary_max_length_pts: np.ndarray, lateral_pts: np.ndarray, monocots: bool = False +) -> np.ndarray: + """Get all landmark points within a given frame as a flat array of coordinates. Args: - pts: Numpy array of points of shape (instances, point, 2). - proximal: Boolean value, where True is proximal (default), False is distal. + primary_max_length_pts: Points of the primary root with maximum length of shape + `(nodes, 2)`. 
+ lateral_pts: Lateral root points of shape `(instances, nodes, 2)`. + monocots: If False (default), returns a combined array of primary and lateral + root points. If True, returns only lateral root points. Returns: - An array of shape (instances,) of proximal or distal point index. + A 2D array of shape (n_points, 2), containing the coordinates of all extracted + points. """ - pt_ind = [] - for i in range(pts.shape[0]): - ind = 1 if proximal else pts.shape[1] - 1 # set initial proximal/distal point - while np.isnan(pts[i, ind]).any(): - ind += 1 if proximal else -1 - if (ind == pts.shape[1] and proximal) or (ind == 0 and not proximal): - break - pt_ind.append(ind) - return pt_ind - + # Check if the input arrays have the right number of dimensions + if primary_max_length_pts.ndim != 2 or lateral_pts.ndim != 3: + raise ValueError( + "Input arrays should have the correct number of dimensions:" + "primary_max_length_pts should be 2-dimensional and lateral_pts should be" + "3-dimensional." + ) -def get_primary_pts(plant: Series, frame: int) -> np.ndarray: - """Get primary root points. + # Check if the last dimension of the input arrays has size 2 (representing x and y coordinates) + if primary_max_length_pts.shape[-1] != 2 or lateral_pts.shape[-1] != 2: + raise ValueError( + "The last dimension of the input arrays should have size 2, representing x" + "and y coordinates." + ) - Args: - plant: Series object representing a plant image series. - frame: frame index + # Flatten the arrays to 2D + primary_max_length_pts = primary_max_length_pts.reshape(-1, 2) + lateral_pts = lateral_pts.reshape(-1, 2) - Return: - An array of primary root points of shape (1, n_points, 2). - If more than one primary root is present, the longest will be used. - """ - # get the primary root points, if >1 primary roots, return the longest primary root - pts_pr = plant.get_primary_points(frame_idx=frame) - if len(pts_pr) == 0: - return pts_pr + # Combine points + if monocots: + pts_all_array = lateral_pts else: - max_length_idx = np.nanargmax(get_root_lengths(pts_pr)) - pts_pr = pts_pr[np.newaxis, max_length_idx] - return pts_pr - - -def get_lateral_pts(plant: Series, frame: int) -> np.ndarray: - """Get lateral root points. - - Args: - plant: Series object representing a plant image series. - frame: frame index - - Return: - An array of primary root points in shape (instance, point, 2) - """ - pts_lr = plant.get_lateral_points(frame_idx=frame) - return pts_lr - - -def get_all_pts(plant: Series, frame: int, monocots: bool = False) -> List[np.ndarray]: - """Get all points within a frame. + # Check if the data types of the arrays are compatible + if primary_max_length_pts.dtype != lateral_pts.dtype: + raise ValueError("Input arrays should have the same data type.") - Args: - plant: Series object representing a plant image series. - frame: frame index - rice: boolean value, where True is rice frame - monocots: If False (the default), returns primary and lateral points - combined. If True, only lateral root points will be returned. This is useful for - monocot species such as rice. - - Return: - A list of numpy arrays containing sets of points of shape - (n_instances, n_points, 2). 
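The combination step follows just below; for orientation, the intended call looks like this (hypothetical landmark arrays):

```
import numpy as np
from sleap_roots.points import get_all_pts_array

# Hypothetical landmarks: one primary root with 3 nodes, two lateral roots with 2 nodes each.
primary_max_length_pts = np.zeros((3, 2))
lateral_pts = np.ones((2, 2, 2))

pts_all = get_all_pts_array(primary_max_length_pts, lateral_pts)  # dicot default
print(pts_all.shape)  # (7, 2): 3 primary points followed by 4 lateral points
```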
- """ - # get primary and lateral root points - pts_pr = get_primary_pts(plant, frame).tolist() - pts_lr = get_lateral_pts(plant, frame).tolist() - - pts_all = pts_lr if monocots else pts_pr + pts_lr - - return pts_all - - -def get_all_pts_array(plant: Series, frame: int, monocots: bool = False) -> np.ndarray: - """Get all points within a frame as a flat array of coordinates. - - Args: - plant: Series object representing a plant image series. - frame: frame index - monocots: If False (the default), returns primary and lateral points - combined. If True, only lateral root points will be returned. This is useful for - monocot species such as rice. - - Return: - An array of all points (primary and optionally lateral) as an array of shape - (n_points, 2). - """ - # get primary and lateral root points - pts_pr = get_primary_pts(plant, frame) - pts_lr = get_lateral_pts(plant, frame) - - pts_all_array = ( - pts_lr.reshape(-1, 2) - if monocots - else np.concatenate( - (np.array(pts_pr).reshape(-1, 2), np.array(pts_lr).reshape(-1, 2)), axis=0 - ) - ) + pts_all_array = np.concatenate((primary_max_length_pts, lateral_pts), axis=0) return pts_all_array diff --git a/sleap_roots/scanline.py b/sleap_roots/scanline.py index 83a3b76..d450654 100644 --- a/sleap_roots/scanline.py +++ b/sleap_roots/scanline.py @@ -60,63 +60,37 @@ def count_scanline_intersections( return np.array(intersection) -def get_scanline_first_ind( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - depth: int = 1080, - width: int = 2048, - n_line: int = 50, - monocots: bool = False, -): +def get_scanline_first_ind(scanline_intersection_counts: np.ndarray): """Get the index of count_scanline_interaction for the first interaction. Args: - primary_pts: Numpy array of primary points of shape (instances, nodes, 2). - lateral_pts: Numpy array of lateral points of shape (instances, nodes, 2). - depth: the depth of cylinder, or number of rows of the image. - width: the width of cylinder, or number of columns of the image. - n_line: number of scan lines, np.nan for no interaction. - monocots: whether True: only lateral roots (e.g., rice), or False: dicots. + scanline_intersection_counts: An array with shape of `(#Nline,)` of intersection + numbers of each scan line. Return: Scalar of count_scanline_interaction index for the first interaction. """ - count_scanline_interaction = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots - ) - if np.where((count_scanline_interaction > 0))[0].shape[0] > 0: - scanline_first_ind = np.where((count_scanline_interaction > 0))[0][0] + # get the first scanline index using scanline_intersection_counts + if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0: + scanline_first_ind = np.where((scanline_intersection_counts > 0))[0][0] return scanline_first_ind else: return np.nan -def get_scanline_last_ind( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, - depth: int = 1080, - width: int = 2048, - n_line: int = 50, - monocots: bool = False, -): +def get_scanline_last_ind(scanline_intersection_counts: np.ndarray): """Get the index of count_scanline_interaction for the last interaction. Args: - primary_pts: Numpy array of primary points of shape (instances, nodes, 2). - lateral_pts: Numpy array of lateral points of shape (instances, nodes, 2). - depth: the depth of cylinder, or number of rows of the image. - width: the width of cylinder, or number of columns of the image. - n_line: number of scan lines, np.nan for no interaction. 
-        monocots: whether True: only lateral roots (e.g., rice), or False: dicots.
+        scanline_intersection_counts: An array with shape of `(#Nline,)` of intersection
+            numbers of each scan line.
 
     Return:
         Scalar of count_scanline_interaction index for the last interaction.
     """
-    count_scanline_interaction = count_scanline_intersections(
-        primary_pts, lateral_pts, depth, width, n_line, monocots
-    )
-    if np.where((count_scanline_interaction > 0))[0].shape[0] > 0:
-        scanline_last_ind = np.where((count_scanline_interaction > 0))[0][-1]
+    # get the last scanline index using scanline_intersection_counts
+    if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0:
+        scanline_last_ind = np.where((scanline_intersection_counts > 0))[0][-1]
         return scanline_last_ind
     else:
         return np.nan
diff --git a/sleap_roots/tips.py b/sleap_roots/tips.py
index da91131..9a23ac8 100644
--- a/sleap_roots/tips.py
+++ b/sleap_roots/tips.py
@@ -19,32 +19,33 @@ def get_tips(pts):
     return tip_pts
 
 
-def get_primary_depth(pts: np.ndarray) -> np.ndarray:
-    """Get primary root tip depth.
-
-    Args:
-        pts: primary root landmarks as array of shape (1, point, 2)
-
-    Returns:
-        Primary root tip depth (location in y-axis).
-    """
-    # get the last point of primary root, if invisible, return nan
-    if pts[:, -1].any() == np.nan:
-        return np.nan
-    else:
-        return pts[:, -1, 1]
-
-
 def get_tip_xs(pts: np.ndarray) -> np.ndarray:
     """Get x coordinations of tip points.
 
     Args:
-        pts: root landmarks as array of shape (instance, point, 2)
+        pts: root landmarks as array of shape (instance, point, 2), or tip points of
+            shape (instance, 2).
 
     Return:
-        An array of tips in x axis (instance,).
+        An array of tip x-coordinates (instance,).
     """
-    _tip_pts = get_tips(pts)
+    if pts.ndim not in (2, 3):
+        raise ValueError(
+            "Input array must be 2-dimensional (n_tips, 2) or "
+            "3-dimensional (n_roots, n_nodes, 2)."
+        )
+
+    if pts.ndim == 3:
+        _tip_pts = get_tips(
+            pts
+        )  # Assuming get_tips returns an array of shape (instance, 2)
+    else:
+        _tip_pts = pts
+
+    if _tip_pts.ndim != 2 or _tip_pts.shape[1] != 2:
+        raise ValueError(
+            "Array of tip points must be 2-dimensional with shape (instance, 2)."
+        )
+
     tip_xs = _tip_pts[:, 0]
     return tip_xs
 
@@ -53,11 +54,28 @@ def get_tip_ys(pts: np.ndarray) -> np.ndarray:
     """Get y coordinations of tip points.
 
     Args:
-        pts: root landmarks as array of shape (instance, point, 2)
+        pts: root landmarks as array of shape (instance, point, 2), or tip points of
+            shape (instance, 2).
 
     Return:
-        An array of tips in y axis (instance,)
+        An array of tip y-coordinates (instance,).
     """
-    _tip_pts = get_tips(pts)
+    if pts.ndim not in (2, 3):
+        raise ValueError(
+            "Input array must be 2-dimensional (n_tips, 2) or "
+            "3-dimensional (n_roots, n_nodes, 2)."
+        )
+
+    if pts.ndim == 3:
+        _tip_pts = get_tips(
+            pts
+        )  # Assuming get_tips returns an array of shape (instance, 2)
+    else:
+        _tip_pts = pts
+
+    if _tip_pts.ndim != 2 or _tip_pts.shape[1] != 2:
+        raise ValueError(
+            "Array of tip points must be 2-dimensional with shape (instance, 2)."
+ ) + tip_ys = _tip_pts[:, 1] return tip_ys diff --git a/sleap_roots/trait_pipeline.py b/sleap_roots/trait_pipeline.py new file mode 100644 index 0000000..91b5460 --- /dev/null +++ b/sleap_roots/trait_pipeline.py @@ -0,0 +1,1266 @@ +"""Extract traits in a pipeline based on the trait graph.""" + +import numpy as np +import pandas as pd +import attrs +from typing import List, Dict, Tuple, Callable, Optional, Any +from fractions import Fraction +import networkx as nx +from pathlib import Path +from sleap_roots.angle import get_root_angle, get_node_ind +from sleap_roots.lengths import get_root_lengths, get_grav_index, get_max_length_pts +from sleap_roots.bases import ( + get_bases, + get_base_ct_density, + get_base_length, + get_base_length_ratio, + get_base_median_ratio, + get_base_tip_dist, + get_base_xs, + get_base_ys, + get_lateral_count, + get_root_pair_widths_projections, +) +from sleap_roots.tips import get_tips, get_tip_xs, get_tip_ys +from sleap_roots.convhull import ( + get_convhull, + get_chull_area, + get_chull_line_lengths, + get_chull_max_width, + get_chull_max_height, + get_chull_perimeter, + get_convhull_features, +) +from sleap_roots.ellipse import ( + fit_ellipse, + get_ellipse_a, + get_ellipse_b, + get_ellipse_ratio, +) +from sleap_roots.networklength import ( + get_bbox, + get_network_distribution_ratio, + get_network_distribution, + get_network_length, + get_network_solidity, + get_network_width_depth_ratio, +) +from sleap_roots.points import get_all_pts_array +from sleap_roots.scanline import ( + count_scanline_intersections, + get_scanline_first_ind, + get_scanline_last_ind, +) +from sleap_roots.series import Series, find_all_series +from sleap_roots.summary import get_summary +import warnings + + +SCALAR_TRAITS = ( + "primary_angle_proximal", + "primary_angle_distal", + "primary_length", + "primary_base_tip_dist", + "lateral_count", + "grav_index", + "base_length", + "base_length_ratio", + "primary_tip_pt_y", + "base_median_ratio", + "base_ct_density", + "chull_perimeter", + "chull_area", + "chull_max_width", + "chull_max_height", + "ellipse_a", + "ellipse_b", + "ellipse_ratio", + "network_width_depth_ratio", + "network_solidity", + "network_length_lower", + "network_distribution_ratio", + "scanline_first_ind", + "scanline_last_ind", +) + +NON_SCALAR_TRAITS = ( + "lateral_angles_proximal", + "lateral_angles_distal", + "lateral_lengths", + "root_widths", + "lateral_base_xs", + "lateral_base_ys", + "lateral_tip_xs", + "lateral_tip_ys", + "chull_line_lengths", + "scanline_intersection_counts", +) + + +warnings.filterwarnings( + "ignore", + message="invalid value encountered in intersection", + category=RuntimeWarning, + module="shapely", +) +warnings.filterwarnings( + "ignore", message="All-NaN slice encountered", category=RuntimeWarning +) +warnings.filterwarnings( + "ignore", message="All-NaN axis encountered", category=RuntimeWarning +) +warnings.filterwarnings( + "ignore", + message="Degrees of freedom <= 0 for slice.", + category=RuntimeWarning, + module="numpy", +) +warnings.filterwarnings( + "ignore", message="Mean of empty slice", category=RuntimeWarning +) +warnings.filterwarnings( + "ignore", + message="invalid value encountered in sqrt", + category=RuntimeWarning, + module="skimage", +) +warnings.filterwarnings( + "ignore", + message="invalid value encountered in double_scalars", + category=RuntimeWarning, +) +warnings.filterwarnings( + "ignore", + message="invalid value encountered in scalar divide", + category=RuntimeWarning, + module="ellipse", +) + + 
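The refactored tip helpers above also accept precomputed tip points, the shape produced by `get_tips` and passed around by the trait definitions below; a minimal sketch with synthetic values:

```
import numpy as np
from sleap_roots.tips import get_tip_xs, get_tip_ys

# Precomputed tip points for two roots, shape (instance, 2).
tip_pts = np.array([[10.0, 40.0], [12.0, 55.0]])
print(get_tip_xs(tip_pts))  # [10. 12.]
print(get_tip_ys(tip_pts))  # [40. 55.]
```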
+@attrs.define +class TraitDef: + """Definition of how to compute a trait. + + Attributes: + name: Unique identifier for the trait. + fn: Function used to compute the trait's value. + input_traits: List of trait names that should be computed before the current + trait and are expected as input positional arguments to `fn`. + scalar: Indicates if the trait is scalar (has a dimension of 0 per frame). If + `True`, the trait is also listed in `SCALAR_TRAITS`. + include_in_csv: `True `indicates the trait should be included in downstream CSV + files. + kwargs: Additional keyword arguments to be passed to the `fn` function. These + arguments are not reused from previously computed traits. + description: String describing the trait for documentation purposes. + + Notes: + The `fn` specified will be called with a pattern like: + + ``` + trait_def = TraitDef( + name="my_trait", + fn=compute_my_trait, + input_traits=["input_trait_1", "input_trait_2"], + scalar=True, + include_in_csv=True, + kwargs={"kwarg1": True} + ) + traits[trait_def.name] = trait_def.fn( + *[traits[input_trait] for input_trait in trait_def.input_traits], + **trait_def.kwargs + ) + ``` + + For this example, the last line is equivalent to: + + ``` + traits["my_trait"] = trait_def.fn( + traits["input_trait_1"], traits["input_trait_2"], + kwarg1=True + ) + ``` + """ + + name: str + fn: Callable + input_traits: List[str] + scalar: bool + include_in_csv: bool + kwargs: Dict[str, Any] = attrs.field(factory=dict) + description: Optional[str] = None + + +def get_traits_value_frame( + primary_pts: np.ndarray, + lateral_pts: np.ndarray, + root_width_tolerance: float = 0.02, + n_line: int = 50, + network_fraction: float = 2 / 3, + monocots: bool = False, +) -> Dict: + """Get SLEAP traits per frame based on graph. + + Args: + primary_pts: primary points + lateral_pts: lateral points + root_width_tolerance: Difference in projection norm between right and left side. + n_line: number of scan lines, np.nan for no interaction. + network_fraction: length found in the lower fration value of the network. + monocots: Boolean value, where false is dicot (default), true is rice. + + Return: + A dictionary with all traits per frame. + """ + # Define the trait computations. 
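The trait list that follows is long, but the evaluation pattern itself is compact. A distilled, hypothetical sketch of that pattern using the `TraitDef` class above (the trait names and lambdas here are invented; the real dependency graph is built the same way further below):

```
import networkx as nx
from sleap_roots.trait_pipeline import TraitDef

# Hypothetical traits: "lengths" is computed from "pts", and "total" from "lengths".
defs = [
    TraitDef(name="lengths", fn=lambda pts: [1.0, 2.0], input_traits=["pts"],
             scalar=False, include_in_csv=False),
    TraitDef(name="total", fn=lambda lengths: sum(lengths), input_traits=["lengths"],
             scalar=True, include_in_csv=True),
]
trait_map = {d.name: d for d in defs}

# Build the dependency graph and evaluate traits in breadth-first order from "pts".
G = nx.DiGraph()
G.add_edges_from((dep, d.name) for d in defs for dep in d.input_traits)
traits = {"pts": [[0.0, 0.0], [3.0, 4.0]]}
for _, name in nx.bfs_tree(G, "pts").edges():
    d = trait_map[name]
    traits[name] = d.fn(*[traits[dep] for dep in d.input_traits], **d.kwargs)

print(traits["total"])  # 3.0
```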
+ trait_definitions = [ + TraitDef( + name="primary_max_length_pts", + fn=get_max_length_pts, + input_traits=["primary_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Points of the primary root with maximum length.", + ), + TraitDef( + name="pts_all_array", + fn=get_all_pts_array, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Landmark points within a given frame as a flat array" + "of coordinates.", + ), + TraitDef( + name="root_widths", + fn=get_root_pair_widths_projections, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=True, + kwargs={"tolerance": root_width_tolerance, "monocots": monocots}, + description="Return estimation of root width using bases of lateral roots.", + ), + TraitDef( + name="lateral_count", + fn=get_lateral_count, + input_traits=["lateral_pts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Get the number of lateral roots.", + ), + TraitDef( + name="lateral_proximal_node_inds", + fn=get_node_ind, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"proximal": True}, + description="Get the indices of the proximal nodes of lateral roots.", + ), + TraitDef( + name="lateral_distal_node_inds", + fn=get_node_ind, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"proximal": False}, + description="Get the indices of the distal nodes of lateral roots.", + ), + TraitDef( + name="lateral_lengths", + fn=get_root_lengths, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of root lengths of shape `(instances,)`.", + ), + TraitDef( + name="lateral_base_pts", + fn=get_bases, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Array of lateral bases `(instances, (x, y))`.", + ), + TraitDef( + name="lateral_tip_pts", + fn=get_tips, + input_traits=["lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Array of lateral tips `(instances, (x, y))`.", + ), + TraitDef( + name="scanline_intersection_counts", + fn=count_scanline_intersections, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=True, + kwargs={ + "depth": 1080, + "width": 2048, + "n_line": n_line, + "monocots": monocots, + }, + description="Array of intersections of each scanline `(#Nline,)`.", + ), + TraitDef( + name="lateral_angles_distal", + fn=get_root_angle, + input_traits=["lateral_pts", "lateral_distal_node_inds"], + scalar=False, + include_in_csv=True, + kwargs={"proximal": False, "base_ind": 0}, + description="Array of lateral distal angles in degrees `(instances,)`.", + ), + TraitDef( + name="lateral_angles_proximal", + fn=get_root_angle, + input_traits=["lateral_pts", "lateral_proximal_node_inds"], + scalar=False, + include_in_csv=True, + kwargs={"proximal": True, "base_ind": 0}, + description="Array of lateral proximal angles in degrees `(instances,)`.", + ), + TraitDef( + name="network_solidity", + fn=get_network_solidity, + input_traits=["network_length", "chull_area"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the total network length divided by the network convex area.", + ), + TraitDef( + name="ellipse", + fn=fit_ellipse, + input_traits=["pts_all_array"], + scalar=False, + include_in_csv=False, + kwargs={}, + 
description="Tuple of (a, b, ratio) containing the semi-major axis length, semi-minor axis length, and the ratio of the major to minor lengths.", + ), + TraitDef( + name="bounding_box", + fn=get_bbox, + input_traits=["pts_all_array"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Tuple of four parameters in bounding box.", + ), + TraitDef( + name="convex_hull", + fn=get_convhull, + input_traits=["pts_all_array"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Convex hull of the points.", + ), + TraitDef( + name="primary_proximal_node_ind", + fn=get_node_ind, + input_traits=["primary_max_length_pts"], + scalar=True, + include_in_csv=False, + kwargs={"proximal": True}, + description="Get the indices of the proximal nodes of primary roots.", + ), + TraitDef( + name="primary_angle_proximal", + fn=get_root_angle, + input_traits=["primary_max_length_pts", "primary_proximal_node_ind"], + scalar=True, + include_in_csv=True, + kwargs={"proximal": True, "base_ind": 0}, + description="Array of primary proximal angles in degrees `(instances,)`.", + ), + TraitDef( + name="primary_distal_node_ind", + fn=get_node_ind, + input_traits=["primary_max_length_pts"], + scalar=True, + include_in_csv=False, + kwargs={"proximal": False}, + description="Get the indices of the distal nodes of primary roots.", + ), + TraitDef( + name="primary_angle_distal", + fn=get_root_angle, + input_traits=["primary_max_length_pts", "primary_distal_node_ind"], + scalar=True, + include_in_csv=True, + kwargs={"proximal": False, "base_ind": 0}, + description="Array of primary distal angles in degrees `(instances,)`.", + ), + TraitDef( + name="primary_length", + fn=get_root_lengths, + input_traits=["primary_max_length_pts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of primary root length.", + ), + TraitDef( + name="primary_base_pt", + fn=get_bases, + input_traits=["primary_max_length_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Array of primary bases `(, (x, y))`.", + ), + TraitDef( + name="primary_tip_pt", + fn=get_tips, + input_traits=["primary_max_length_pts"], + scalar=False, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Array of primary tips `(, (x, y))`.", + ), + TraitDef( + name="network_length_lower", + fn=get_network_distribution, + input_traits=["primary_max_length_pts", "lateral_pts", "bounding_box"], + scalar=True, + include_in_csv=True, + kwargs={"fraction": network_fraction, "monocots": monocots}, + description="Scalar of the root network length in the lower fraction of the plant.", + ), + TraitDef( + name="lateral_base_xs", + fn=get_base_xs, + input_traits=["lateral_base_pts"], + scalar=False, + include_in_csv=True, + kwargs={"monocots": monocots}, + description="Array of the x-coordinates of lateral bases `(instance,)`.", + ), + TraitDef( + name="lateral_base_ys", + fn=get_base_ys, + input_traits=["lateral_base_pts"], + scalar=False, + include_in_csv=True, + kwargs={"monocots": monocots}, + description="Array of the y-coordinates of lateral bases `(instance,)`.", + ), + TraitDef( + name="base_ct_density", + fn=get_base_ct_density, + input_traits=["primary_length", "lateral_base_pts"], + scalar=True, + include_in_csv=True, + kwargs={"monocots": monocots}, + description="Scalar of base count density.", + ), + TraitDef( + name="lateral_tip_xs", + fn=get_tip_xs, + input_traits=["lateral_tip_pts"], + scalar=False, + include_in_csv=True, + 
kwargs={"monocots": monocots}, + description="Array of the x-coordinates of lateral tips `(instance,)`.", + ), + TraitDef( + name="lateral_tip_ys", + fn=get_tip_ys, + input_traits=["lateral_tip_pts"], + scalar=False, + include_in_csv=True, + kwargs={"monocots": monocots}, + description="Array of the y-coordinates of lateral tips `(instance,)`.", + ), + TraitDef( + name="network_distribution_ratio", + fn=get_network_distribution_ratio, + input_traits=["primary_length", "lateral_lengths", "network_length_lower"], + scalar=True, + include_in_csv=True, + kwargs={"fraction": network_fraction, "monocots": monocots}, + description="Scalar of ratio of the root network length in the lower fraction of the plant over all root length.", + ), + TraitDef( + name="network_length", + fn=get_network_length, + input_traits=["primary_length", "lateral_lengths"], + scalar=True, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Scalar of all roots network length.", + ), + TraitDef( + name="primary_base_pt_y", + fn=get_base_ys, + input_traits=["primary_base_pt"], + scalar=True, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Scalar of the y-coordinates of primary base `(instance,)`.", + ), + TraitDef( + name="primary_tip_pt_y", + fn=get_tip_ys, + input_traits=["primary_base_pt"], + scalar=True, + include_in_csv=False, + kwargs={"monocots": monocots}, + description="Scalar of the y-coordinates of primary tip `(instance,)`.", + ), + TraitDef( + name="ellipse_a", + fn=get_ellipse_a, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of semi-major axis length.", + ), + TraitDef( + name="ellipse_b", + fn=get_ellipse_b, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of semi-minor axis length.", + ), + TraitDef( + name="network_width_depth_ratio", + fn=get_network_width_depth_ratio, + input_traits=["bounding_box"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of bounding box width to depth ratio of root network.", + ), + TraitDef( + name="chull_perimeter", + fn=get_chull_perimeter, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull perimeter.", + ), + TraitDef( + name="chull_area", + fn=get_chull_area, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull area.", + ), + TraitDef( + name="chull_max_width", + fn=get_chull_max_width, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull maximum width.", + ), + TraitDef( + name="chull_max_height", + fn=get_chull_max_height, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull maximum height.", + ), + TraitDef( + name="chull_line_lengths", + fn=get_chull_line_lengths, + input_traits=["convex_hull"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of line lengths connecting any two vertices on the convex hull.", + ), + TraitDef( + name="base_length", + fn=get_base_length, + input_traits=["lateral_base_ys"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the distance between the top and deepest base y-coordinates.", + ), + TraitDef( + name="base_median_ratio", + fn=get_base_median_ratio, + input_traits=["lateral_base_ys", "primary_tip_pt_y"], + scalar=True, + include_in_csv=True, + 
kwargs={"monocots": monocots}, + description="Scalar of base median ratio.", + ), + TraitDef( + name="grav_index", + fn=get_grav_index, + input_traits=["primary_length", "primary_base_tip_dist"], + scalar=True, + include_in_csv=True, + kwargs={"monocots": monocots}, + description="Scalar of primary root gravity index.", + ), + TraitDef( + name="base_length_ratio", + fn=get_base_length_ratio, + input_traits=["primary_length", "base_length"], + scalar=True, + include_in_csv=True, + kwargs={"monocots": monocots}, + description="Scalar of base length ratio.", + ), + TraitDef( + name="primary_base_tip_dist", + fn=get_base_tip_dist, + input_traits=["primary_base_pt_y", "primary_tip_pt_y"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of distances from primary base to tip.", + ), + TraitDef( + name="ellipse_ratio", + fn=get_ellipse_ratio, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of ratio of the minor to major lengths.", + ), + TraitDef( + name="scanline_last_ind", + fn=get_scanline_last_ind, + input_traits=["scanline_intersection_counts"], + scalar=True, + include_in_csv=True, + kwargs={ + "depth": 1080, + "width": 2048, + "n_line": n_line, + "monocots": monocots, + }, + description="Scalar of count_scanline_interaction index for the last interaction.", + ), + TraitDef( + name="scanline_first_ind", + fn=get_scanline_first_ind, + input_traits=["scanline_intersection_counts"], + scalar=True, + include_in_csv=True, + kwargs={ + "depth": 1080, + "width": 2048, + "n_line": n_line, + "monocots": monocots, + }, + description="Scalar of count_scanline_interaction index for the first interaction.", + ), + ] + + # Map trait names to their definitions. + trait_map = {trait_def.name: trait_def for trait_def in trait_definitions} + + # trait_map = { + # # get_bases(pts: np.ndarray,monocots) -> np.ndarray + # "primary_base_pt": (get_bases, ["primary_pts"], {"monocots": monocots}), + # # get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray + # "primary_angle_proximal": ( + # get_root_angle, + # ["primary_pts"], + # {"proximal": True, "base_ind": 0}, + # ), + # "primary_angle_distal": ( + # get_root_angle, + # ["primary_pts"], + # {"proximal": False, "base_ind": 0}, + # ), + # # get_root_lengths(pts: np.ndarray) -> np.ndarray + # "primary_length": (get_root_lengths, ["primary_pts"], {}), + # # get_tips(pts) + # "primary_tip_pt": (get_tips, ["primary_pts"], {}), + # # fit_ellipse(pts: np.ndarray) -> Tuple[float, float, float] + # "ellipse": (fit_ellipse, ["pts_all_array"], {}), + # # get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float] + # "bounding_box": (get_bbox, ["pts_all_array"], {}), + # # get_root_pair_widths_projections(lateral_pts, primary_pts, tolerance,monocots) + # "root_widths": ( + # get_root_pair_widths_projections, + # ["primary_pts", "lateral_pts"], + # {"root_width_tolerance": root_width_tolerance, "monocots": monocots}, + # ), + # # get_convhull_features(pts: Union[np.ndarray, ConvexHull]) -> Tuple[float, float, float, float] + # "convex_hull": (get_convhull_features, ["pts_all_array"], {}), + # # get_lateral_count(pts: np.ndarray) + # "lateral_count": (get_lateral_count, ["lateral_pts"], {}), + # # # get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray + # "lateral_angles_proximal": ( + # get_root_angle, + # ["lateral_pts"], + # {"proximal": True, "base_ind": 0}, + # ), + # "lateral_angles_distal": ( + # get_root_angle, + # ["lateral_pts"], + # {"proximal": 
False, "base_ind": 0}, + # ), + # # get_root_lengths(pts: np.ndarray) -> np.ndarray + # "lateral_lengths": (get_root_lengths, ["lateral_pts"], {}), + # # get_bases(pts: np.ndarray,monocots) -> np.ndarray + # "lateral_base_pts": (get_bases, ["lateral_pts"], {"monocots": monocots}), + # # get_tips(pts) + # "lateral_tip_pts": (get_tips, ["lateral_pts"], {}), + # # get_base_ys(pts: np.ndarray) -> np.ndarray + # # or just based on primary_base_pt, but the primary_base_pt trait must generate before + # # "primary_base_pt_y": (get_pt_ys, [data["primary_base_pt"]]), + # "primary_base_pt_y": (get_base_ys, ["primary_pts"], {}), + # # get_base_ct_density(primary_pts, lateral_pts) + # "base_ct_density": ( + # get_base_ct_density, + # [ + # "primary_pts", + # "lateral_pts", + # ], + # {"monocots": monocots}, + # ), + # # get_network_solidity(primary_pts: np.ndarray, lateral_pts: np.ndarray, pts_all_array: np.ndarray, monocots: bool = False,) -> float + # "network_solidity": ( + # get_network_solidity, + # ["primary_pts", "lateral_pts", "chull_area"], + # {"monocots": monocots}, + # ), + # # get_network_distribution_ratio(primary_pts: np.ndarray,lateral_pts: np.ndarray,pts_all_array: np.ndarray,fraction: float = 2 / 3, monocots: bool = False) -> float: + # "network_distribution_ratio": ( + # get_network_distribution_ratio, + # [ + # "primary_length", + # "lateral_lengths", + # "network_length_lower", + # ], + # {"network_fraction": network_fraction, "monocots": monocots}, + # ), + # # get_network_distribution(primary_pts: np.ndarray,lateral_pts: np.ndarray,pts_all_array: np.ndarray,fraction: float = 2 / 3, monocots: bool = False) -> float: + # "network_length_lower": ( + # get_network_distribution, + # ["primary_pts", "lateral_pts", "bounding_box"], + # {"network_fraction": network_fraction, "monocots": monocots}, + # ), + # # get_network_width_depth_ratio(pts: np.ndarray) -> float + # "network_width_depth_ratio": ( + # get_network_width_depth_ratio, + # ["bounding_box"], + # {}, + # ), + # # get_tip_ys(pts: np.ndarray) -> np.ndarray + # "primary_tip_pt_y": (get_tip_ys, ["primary_tip_pt"], {}), + # # get_ellipse_a(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) + # "ellipse_a": (get_ellipse_a, ["ellipse"], {}), + # # get_ellipse_b(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) + # "ellipse_b": (get_ellipse_b, ["ellipse"], {}), + # # get_ellipse_ratio(pts_all_array: Union[np.ndarray, Tuple[float, float, float]]) + # "ellipse_ratio": (get_ellipse_ratio, ["ellipse"], {}), + # # get_chull_perimeter(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) + # "chull_perimeter": (get_chull_perimeter, ["convex_hull"], {}), + # # get_chull_area(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) + # "chull_area": (get_chull_area, ["convex_hull"], {}), + # # get_chull_max_width(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) + # "chull_max_width": (get_chull_max_width, ["convex_hull"], {}), + # # get_chull_max_height(pts: Union[np.ndarray, ConvexHull, Tuple[float, float, float, float]]) + # "chull_max_height": (get_chull_max_height, ["convex_hull"], {}), + # # get_chull_line_lengths(pts: Union[np.ndarray, ConvexHull]) -> np.ndarray + # "chull_line_lengths": (get_chull_line_lengths, ["convex_hull"], {}), + # # scanline_intersection_counts: + # "scanline_intersection_counts": ( + # count_scanline_intersections, + # [primary_pts, lateral_pts], + # {"depth": 1080, "width": 2048, "n_line": 50, "monocots": monocots}, + # ), + # # 
get_base_xs(pts: np.ndarray) -> np.ndarray + # "lateral_base_xs": (get_base_xs, ["lateral_base_pts"], {"monocots": monocots}), + # # get_base_ys(pts: np.ndarray) -> np.ndarray + # "lateral_base_ys": (get_base_ys, ["lateral_base_pts"], {"monocots": monocots}), + # # get_tip_xs(pts: np.ndarray) -> np.ndarray + # "lateral_tip_xs": (get_tip_xs, ["lateral_tip_pts"], {"monocots": monocots}), + # # get_tip_ys(pts: np.ndarray) -> np.ndarray + # "lateral_tip_ys": (get_tip_ys, ["lateral_tip_pts"], {"monocots": monocots}), + # # get_base_tip_dist(pts: np.ndarray) -> np.ndarray + # "primary_base_tip_dist": ( + # get_base_tip_dist, + # { + # "base_pts": "primary_base_pt", + # "tip_pts": "primary_tip_pt", + # "pts": "primary_pts", + # }, + # ), + # # get_base_median_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray) + # "base_median_ratio": ( + # get_base_median_ratio, + # ["lateral_base_ys", "primary_tip_pt_y"], + # {"monocots": monocots}, + # ), + # # get_scanline_last_ind(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False) + # "scanline_last_ind": ( + # get_scanline_last_ind, + # [primary_pts, lateral_pts, 1080, 2048, n_line, monocots], + # ), + # # get_scanline_first_ind(primary_pts: np.ndarray,lateral_pts: np.ndarray,depth: int = 1080, width: int = 2048, n_line: int = 50, monocots: bool = False) + # "scanline_first_ind": ( + # get_scanline_first_ind, + # [primary_pts, lateral_pts, 1080, 2048, n_line, monocots], + # ), + # # get_base_length(pts: np.ndarray) + # "base_length": (get_base_length, [lateral_pts, monocots]), + # # get_grav_index(pts: np.ndarray) + # "grav_index": (get_grav_index, [primary_pts]), + # # get_base_length_ratio(primary_pts: np.ndarray, lateral_pts: np.ndarray) + # "base_length_ratio": (get_base_length_ratio, [primary_pts, lateral_pts]), + # } + + # Initialize edges with precomputed top-level traits. + edges = [("pts", "primary_pts"), ("pts", "lateral_pts")] + + # Infer edges from trait map. + # for output_trait, (_, input_traits, _) in trait_map.items(): + # for input_trait in input_traits: + # edges.append((input_trait, output_trait)) + for trait_def in trait_definitions: + for input_trait in trait_def.input_traits: + edges.append((input_trait, trait_def.name)) + + # Compute breadth-first ordering. + G = nx.DiGraph() + G.add_edges_from(edges) + trait_computation_order = [ + dst for (src, dst) in list(nx.bfs_tree(G, "pts").edges())[2:] + ] + + # Initialize traits container with initial points. + traits = {"primary_pts": primary_pts, "lateral_pts": lateral_pts} + + # Compute traits! + for trait_name in trait_computation_order: + # fn, input_traits, kwargs = trait_map[trait_name] + trait_def = trait_map[trait_name] + + traits[trait_name] = trait_def.fn( + *[traits[input_trait] for input_trait in trait_def.input_traits], + **trait_def.kwargs, + ) + + return traits + + +def get_traits_value_plant( + h5, + monocots: bool = False, + primary_name: str = "primary_multi_day", + lateral_name: str = "lateral_3_nodes", + root_width_tolerance: float = 0.02, + n_line: int = 50, + network_fraction: float = 2 / 3, + write_csv: bool = False, + csv_suffix: str = ".traits.csv", +) -> Tuple[Dict, pd.DataFrame, str]: + """Get detailed SLEAP traits for every frame of a plant, based on the graph. + + Args: + h5: The h5 file representing the plant image series. + monocots: A boolean value indicating whether the plant is a monocot (True) + or a dicot (False) (default). + primary_name: Name of the primary root predictions. 
The predictions file is + expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. + lateral_name: Name of the lateral root predictions. The predictions file is + expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. + root_width_tolerance: The difference in the projection norm between + the right and left side of the root. + n_line: The number of scan lines. Use np.nan for no interaction. + network_fraction: The length found in the lower fraction value of the network. + write_csv: A boolean value. If True, it writes per plant detailed + CSVs with traits for every instance on every frame. + csv_suffix: If write_csv=True, the CSV file will be saved with the + h5 path + csv_suffix. + + Returns: + A tuple containing a dictionary and a DataFrame with all traits per plant, + and the plant name. The Dataframe has root traits per instance and frame + where each row corresponds to a frame in the H5 file. The plant_name is + given by the h5 file. + """ + plant = Series.load(h5, primary_name=primary_name, lateral_name=lateral_name) + plant_name = plant.series_name + # get number of frames per plant + n_frame = len(plant) + + data_plant = [] + # get traits for each frames in a row + for frame in range(n_frame): + primary, lateral = plant[frame] + + gt_instances_pr = primary.user_instances + primary.unused_predictions + gt_instances_lr = lateral.user_instances + lateral.unused_predictions + + if len(gt_instances_lr) == 0: + lateral_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + else: + lateral_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) + + if len(gt_instances_pr) == 0: + primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + else: + primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) + + data = get_traits_value_frame( + primary_pts, + lateral_pts, + root_width_tolerance, + n_line, + network_fraction, + monocots, + ) + + data["plant_name"] = plant_name + data["frame_idx"] = frame + data_plant.append(data) + data_plant_df = pd.DataFrame(data_plant) + + # reorganize the column position + column_names = data_plant_df.columns.tolist() + column_names = [column_names[-2]] + [column_names[-1]] + column_names[:-2] + data_plant_df = data_plant_df[column_names] + + # convert the data in scalar column to the value without [] + columns_to_convert = data_plant_df.columns[ + data_plant_df.apply( + lambda x: all( + isinstance(val, np.ndarray) and val.shape == (1,) for val in x + ) + ) + ] + data_plant_df[columns_to_convert] = data_plant_df[columns_to_convert].apply( + lambda x: x.apply(lambda val: val[0]) + ) + + if write_csv: + csv_name = Path(h5).with_suffix(f"{csv_suffix}") + data_plant_df.to_csv(csv_name, index=False) + return data_plant, data_plant_df, plant_name + + +def get_traits_value_plant_summary( + h5, + monocots: bool = False, + primary_name: str = "longest_3do_6nodes", + lateral_name: str = "main_3do_6nodes", + root_width_tolerance: float = 0.02, + n_line: int = 50, + network_fraction: float = 2 / 3, + write_csv: bool = False, + csv_suffix: str = ".traits.csv", + write_summary_csv: bool = False, + summary_csv_suffix: str = ".summary_traits.csv", +) -> pd.DataFrame: + """Get summary statistics of SLEAP traits per plant based on graph. + + Args: + h5: The h5 file representing the plant image series. + monocots: A boolean value indicating whether the plant is a monocot (True) + or a dicot (False) (default). + primary_name: Name of the primary root predictions. 
The predictions file is + expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. + lateral_name: Name of the lateral root predictions. The predictions file is + expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. + root_width_tolerance: The difference in the projection norm between + the right and left side of the root. + n_line: The number of scan lines. Use np.nan for no interaction. + network_fraction: The length found in the lower fraction value of the network. + write_csv: A boolean value. If True, it writes per plant detailed + CSVs with traits for every instance on every frame. + csv_suffix: If write_csv=True, the CSV file will be saved with the name + h5 path + csv_suffix. + write_summary_csv: Boolean value, where true is write summarized csv file. + summary_csv_suffix: If write_summary_csv=True, the CSV file with the summary + statistics per plant will be saved with the name + h5 path + summary_csv_suffix. + + Return: + A DataFrame with summary statistics of all traits per plant. + """ + data_plant, data_plant_df, plant_name = get_traits_value_plant( + h5, + monocots, + primary_name, + lateral_name, + root_width_tolerance, + n_line, + network_fraction, + write_csv, + csv_suffix, + ) + + # get summarized non-scalar traits per frame + data_plant_frame_summary = [] + data_plant_frame_summary_non_scalar = {} + + for i in range(len(NON_SCALAR_TRAITS)): + trait = data_plant_df[NON_SCALAR_TRAITS[i]] + + if not trait.isna().all(): + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fmin" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else (np.nanmin(x) if len(x) > 0 else np.nan) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fmax" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else (np.nanmax(x) if len(x) > 0 else np.nan) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fmean" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else (np.nanmean(x) if len(x) > 0 else np.nan) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fmedian" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else (np.nanmedian(x) if len(x) > 0 else np.nan) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fstd" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else (np.nanstd(x) if len(x) > 0 else np.nan) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc5" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else (np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 5)) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc25" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else ( + np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 25) + ) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc75" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else ( + np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 75) + ) + ) + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc95" + ] = trait.apply( + lambda x: x + if isinstance(x, (np.floating, float, np.integer, int)) + else ( + np.nan if np.isnan(x).all() else np.percentile(x[~pd.isna(x)], 95) 
+ ) + ) + else: + data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fmin"] = np.nan + data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fmax"] = np.nan + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fmean" + ] = np.nan + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fmedian" + ] = np.nan + data_plant_frame_summary_non_scalar[NON_SCALAR_TRAITS[i] + "_fstd"] = np.nan + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc5" + ] = np.nan + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc25" + ] = np.nan + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc75" + ] = np.nan + data_plant_frame_summary_non_scalar[ + NON_SCALAR_TRAITS[i] + "_fprc95" + ] = np.nan + + # get summarized scalar traits per plant + column_names = data_plant_df.columns.tolist() + data_plant_frame_summary = {} + for i in range(len(SCALAR_TRAITS)): + if SCALAR_TRAITS[i] in column_names: + trait = data_plant_df[SCALAR_TRAITS[i]] + if trait.shape[0] > 0: + if not ( + isinstance(trait[0], (np.floating, float)) + or isinstance(trait[0], (np.integer, int)) + ): + values = np.array([element[0] for element in trait]) + trait = values + trait = trait.astype(float) + trait = np.reshape(trait, (len(trait), 1)) + ( + trait_min, + trait_max, + trait_mean, + trait_median, + trait_std, + trait_prc5, + trait_prc25, + trait_prc75, + trait_prc95, + ) = get_summary(trait) + + data_plant_frame_summary[SCALAR_TRAITS[i] + "_min"] = trait_min + data_plant_frame_summary[SCALAR_TRAITS[i] + "_max"] = trait_max + data_plant_frame_summary[SCALAR_TRAITS[i] + "_mean"] = trait_mean + data_plant_frame_summary[SCALAR_TRAITS[i] + "_median"] = trait_median + data_plant_frame_summary[SCALAR_TRAITS[i] + "_std"] = trait_std + data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc5"] = trait_prc5 + data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc25"] = trait_prc25 + data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc75"] = trait_prc75 + data_plant_frame_summary[SCALAR_TRAITS[i] + "_prc95"] = trait_prc95 + + # append the summarized non-scalar traits per plant + data_plant_frame_summary_key = list(data_plant_frame_summary_non_scalar.keys()) + for j in range(len(data_plant_frame_summary_non_scalar)): + trait = data_plant_frame_summary_non_scalar[data_plant_frame_summary_key[j]] + ( + trait_min, + trait_max, + trait_mean, + trait_median, + trait_std, + trait_prc5, + trait_prc25, + trait_prc75, + trait_prc95, + ) = get_summary(trait) + + data_plant_frame_summary[data_plant_frame_summary_key[j] + "_min"] = trait_min + data_plant_frame_summary[data_plant_frame_summary_key[j] + "_max"] = trait_max + data_plant_frame_summary[data_plant_frame_summary_key[j] + "_mean"] = trait_mean + data_plant_frame_summary[ + data_plant_frame_summary_key[j] + "_median" + ] = trait_median + data_plant_frame_summary[data_plant_frame_summary_key[j] + "_std"] = trait_std + data_plant_frame_summary[data_plant_frame_summary_key[j] + "_prc5"] = trait_prc5 + data_plant_frame_summary[ + data_plant_frame_summary_key[j] + "_prc25" + ] = trait_prc25 + data_plant_frame_summary[ + data_plant_frame_summary_key[j] + "_prc75" + ] = trait_prc75 + data_plant_frame_summary[ + data_plant_frame_summary_key[j] + "_prc95" + ] = trait_prc95 + data_plant_frame_summary["plant_name"] = [plant_name] + data_plant_frame_summary_df = pd.DataFrame(data_plant_frame_summary) + + # reorganize the column position + column_names = data_plant_frame_summary_df.columns.tolist() + column_names = [column_names[-1]] + 
column_names[:-1] + data_plant_frame_summary_df = data_plant_frame_summary_df[column_names] + + if write_summary_csv: + summary_csv_name = Path(h5).with_suffix(f"{summary_csv_suffix}") + data_plant_frame_summary_df.to_csv(summary_csv_name, index=False) + return data_plant_frame_summary_df + + +def get_all_plants_traits( + data_folders: List[str], + primary_name: str, + lateral_name: str, + root_width_tolerance: float = 0.02, + n_line: int = 50, + network_fraction: Fraction = Fraction(2, 3), + write_per_plant_details: bool = False, + per_plant_details_csv_suffix: str = ".traits.csv", + write_per_plant_summary: bool = False, + per_plant_summary_csv_suffix: str = ".summary_traits.csv", + monocots: bool = False, + all_plants_csv_name: str = "all_plants_traits.csv", +) -> pd.DataFrame: + """Get a DataFrame with summary traits from all plants in the given data folders. + + Args: + h5: The h5 file representing the plant image series. + monocots: A boolean value indicating whether the plant is a monocot (True) + or a dicot (False) (default). + primary_name: Name of the primary root predictions. The predictions file is + expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. + lateral_name: Name of the lateral root predictions. The predictions file is + expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. + root_width_tolerance: The difference in the projection norm between + the right and left side of the root. + n_line: The number of scan lines. Use np.nan for no interaction. + network_fraction: The length found in the lower fraction value of the network. + write_per_plant_details: A boolean value. If True, it writes per plant detailed + CSVs with traits for every instance. + per_plant_details_csv_suffix: If write_csv=True, the CSV file will be saved + with the name h5 path + csv_suffix. + write_per_plant_summary: A boolean value. If True, it writes per plant summary + CSVs. + per_plant_summary_csv_suffix: If write_summary_csv=True, the CSV file with the + summary statistics per plant will be saved with the name + h5 path + summary_csv_suffix. + all_plants_csv_name: The name of the output CSV file containing all plants' + summary traits. + + Returns: + A pandas DataFrame with summary root traits for all plants in the data folders. + Each row is a sample. + """ + h5_series = find_all_series(data_folders) + + all_traits = [] + for h5 in h5_series: + plant_traits = get_traits_value_plant_summary( + h5, + monocots=monocots, + primary_name=primary_name, + lateral_name=lateral_name, + root_width_tolerance=root_width_tolerance, + n_line=n_line, + network_fraction=network_fraction, + write_csv=write_per_plant_details, + csv_suffix=per_plant_details_csv_suffix, + write_summary_csv=write_per_plant_summary, + summary_csv_suffix=per_plant_summary_csv_suffix, + ) + plant_traits["path"] = h5 + all_traits.append(plant_traits) + + all_traits_df = pd.concat(all_traits, ignore_index=True) + + all_traits_df.to_csv(all_plants_csv_name, index=False) + return all_traits_df diff --git a/sleap_roots/traitsgraph.py b/sleap_roots/traitsgraph.py deleted file mode 100644 index cf2d9d3..0000000 --- a/sleap_roots/traitsgraph.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Create traits graph and get destinations.""" - -import networkx as nx - - -def get_traits_graph(): - """Get traits graph using networkx. - - Args: - None - - Return: - Destination nodes. 
- """ - G = nx.DiGraph() - edge_list = [ - ("pts", "primary_pts"), - ("pts", "lateral_pts"), - ("primary_pts", "primary_base_pt"), - ("primary_pts", "primary_angle_proximal"), - ("primary_pts", "primary_angle_distal"), - ("primary_pts", "primary_length"), - ("primary_base_pt", "primary_base_pt_y"), - ("primary_pts", "primary_tip_pt"), - ("primary_base_pt_y", "primary_base_tip_dist"), - ("primary_tip_pt_y", "primary_base_tip_dist"), - ("primary_tip_pt_y", "primary_depth"), - ("lateral_pts", "lateral_count"), - ("primary_tip_pt", "primary_tip_pt_y"), - ("primary_length_max", "grav_index"), - ("primary_base_tip_dist", "grav_index"), - ("lateral_base_ys", "base_length"), - ("base_length", "base_length_ratio"), - ("primary_length_max", "base_length_ratio"), - ("lateral_base_ys", "base_median_ratio"), - ("primary_tip_pt_y", "base_median_ratio"), - ("lateral_base_pts", "base_ct_density"), - ("primary_length", "base_ct_density"), - ("convex_hull", "chull_perimeter"), - ("convex_hull", "chull_area"), - ("convex_hull", "chull_max_width"), - ("convex_hull", "chull_max_height"), - ("primary_pts", "ellipse"), - ("lateral_pts", "ellipse"), - ("ellipse", "ellipse_a"), - ("ellipse", "ellipse_b"), - ("ellipse_a", "ellipse_ratio"), - ("ellipse_b", "ellipse_ratio"), - ("chull_area", "network_solidity"), - ("lateral_lengths", "network_solidity"), - ("primary_length", "network_solidity"), - ("lateral_pts", "bounding_box"), - ("primary_pts", "bounding_box"), - ("bounding_box", "network_width_depth_ratio"), - ("lateral_lengths", "network_distribution_ratio"), - ("primary_length", "network_distribution_ratio"), - ("primary_length", "network_length_lower"), - ("lateral_lengths", "network_length_lower"), - ("bounding_box", "network_distribution_ratio"), - ("bounding_box", "network_length_lower"), - ("scanline_intersection_counts", "scanline_last_ind"), - ("scanline_intersection_counts", "scanline_first_ind"), - ("lateral_pts", "lateral_angles_proximal"), - ("lateral_pts", "lateral_angles_distal"), - ("lateral_pts", "lateral_lengths"), - ("primary_pts", "stem_widths"), - ("lateral_pts", "lateral_base_pts"), - ("lateral_base_pts", "stem_widths"), - ("lateral_base_pts", "lateral_base_xs"), - ("lateral_base_pts", "lateral_base_ys"), - ("lateral_pts", "lateral_tip_pts"), - ("lateral_tip_pts", "lateral_tip_xs"), - ("lateral_tip_pts", "lateral_tip_ys"), - ("primary_pts", "convex_hull"), - ("lateral_pts", "convex_hull"), - ("convex_hull", "chull_line_lengths"), - ("primary_pts", "scanline_intersection_counts"), - ("lateral_pts", "scanline_intersection_counts"), - ] - - G.add_edges_from(edge_list) - dts = [dst for (src, dst) in list(nx.bfs_tree(G, "pts").edges())[2:]] - return dts diff --git a/tests/test_angle.py b/tests/test_angle.py index 7abf8d3..666e667 100644 --- a/tests/test_angle.py +++ b/tests/test_angle.py @@ -123,42 +123,56 @@ def test_get_node_ind(canola_h5): pts = primary.numpy() proximal = True node_ind = get_node_ind(pts, proximal) - np.testing.assert_array_equal(node_ind, [1]) + np.testing.assert_array_equal(node_ind, 1) # test get_node_ind function using root that without second node def test_get_node_ind_nan2(pts_nan2): proximal = True node_ind = get_node_ind(pts_nan2, proximal) - np.testing.assert_array_equal(node_ind, [2]) + np.testing.assert_array_equal(node_ind, 2) # test get_node_ind function using root that without second and third nodes def test_get_node_ind_nan3(pts_nan3): proximal = True node_ind = get_node_ind(pts_nan3, proximal) - np.testing.assert_array_equal(node_ind, [3]) + 
np.testing.assert_array_equal(node_ind, 0) # test get_node_ind function using two roots/instances def test_get_node_ind_nan32(pts_nan32): proximal = True node_ind = get_node_ind(pts_nan32, proximal) - np.testing.assert_array_equal(node_ind, [3, 2]) + np.testing.assert_array_equal(node_ind, [0, 2]) # test get_node_ind function using root that without last node def test_get_node_ind_nan6(pts_nan6): proximal = False node_ind = get_node_ind(pts_nan6, proximal) - np.testing.assert_array_equal(node_ind, [4]) + np.testing.assert_array_equal(node_ind, 4) # test get_node_ind function using root with all nan node def test_get_node_ind_nanall(pts_nanall): proximal = False node_ind = get_node_ind(pts_nanall, proximal) - np.testing.assert_array_equal(node_ind, [0]) + np.testing.assert_array_equal(node_ind, np.nan) + + +# test get_node_ind function using root with pts_nan32_5node +def test_get_node_ind_5node(pts_nan32_5node): + proximal = False + node_ind = get_node_ind(pts_nan32_5node, proximal) + np.testing.assert_array_equal(node_ind, [4, 4]) + + +# test get_node_ind function (proximal) using root with pts_nan32_5node +def test_get_node_ind_5node_proximal(pts_nan32_5node): + proximal = True + node_ind = get_node_ind(pts_nan32_5node, proximal) + np.testing.assert_array_equal(node_ind, [0, 2]) # test canola get_root_angle function (base node to distal node angle) @@ -169,8 +183,8 @@ def test_get_root_angle_distal(canola_h5): primary, lateral = series[0] pts = primary.numpy() proximal = False - angs = get_root_angle(pts, proximal) - assert angs.shape == (1,) + node_ind = get_node_ind(pts, proximal) + angs = get_root_angle(pts, node_ind, proximal) assert pts.shape == (1, 6, 2) np.testing.assert_almost_equal(angs, 7.7511306, decimal=3) @@ -183,7 +197,8 @@ def test_get_root_angle_proximal_rice(rice_h5): primary, lateral = series[0] pts = primary.numpy() proximal = True - angs = get_root_angle(pts, proximal) + node_ind = get_node_ind(pts, proximal) + angs = get_root_angle(pts, node_ind, proximal) assert angs.shape == (2,) assert pts.shape == (2, 6, 2) np.testing.assert_almost_equal(angs, [17.3180819, 3.2692877], decimal=3) @@ -192,7 +207,8 @@ def test_get_root_angle_proximal_rice(rice_h5): # test get_root_angle function using two roots/instances (base node to proximal node angle) def test_get_root_angle_proximal(pts_nan32): proximal = True - angs = get_root_angle(pts_nan32, proximal) + node_ind = get_node_ind(pts_nan32, proximal) + angs = get_root_angle(pts_nan32, node_ind, proximal) assert angs.shape == (2,) np.testing.assert_almost_equal(angs, [np.nan, 1.7291381], decimal=3) @@ -200,14 +216,15 @@ def test_get_root_angle_proximal(pts_nan32): # test get_root_angle function using two roots/instances (base node to proximal node angle) def test_get_root_angle_proximal_5node(pts_nan32_5node): proximal = True - angs = get_root_angle(pts_nan32_5node, proximal) + node_ind = get_node_ind(pts_nan32_5node, proximal) + angs = get_root_angle(pts_nan32_5node, node_ind, proximal) assert angs.shape == (2,) np.testing.assert_almost_equal(angs, [np.nan, 2.3339111], decimal=3) # test get_root_angle function using root/instance with all nan value -def test_get_root_angle_proximal_5node(pts_nanall): +def test_get_root_angle_proximal_allnan(pts_nanall): proximal = True - angs = get_root_angle(pts_nanall, proximal) - assert angs.shape == (1,) + node_ind = get_node_ind(pts_nanall, proximal) + angs = get_root_angle(pts_nanall, node_ind, proximal) np.testing.assert_almost_equal(angs, np.nan, decimal=3) diff --git 
a/tests/test_bases.py b/tests/test_bases.py index 614188e..56c5b25 100644 --- a/tests/test_bases.py +++ b/tests/test_bases.py @@ -2,18 +2,13 @@ get_bases, get_base_ct_density, get_base_tip_dist, - get_grav_index, get_lateral_count, - get_root_lengths, - get_root_lengths_max, get_base_xs, get_base_ys, get_base_length, get_base_length_ratio, - get_primary_depth, get_root_pair_widths_projections, ) -from sleap_roots.points import get_lateral_pts from sleap_roots import Series import numpy as np import pytest @@ -180,97 +175,32 @@ def test_bases_no_roots(pts_no_roots): # test get_base_tip_dist with standard points def test_get_base_tip_dist_standard(pts_standard): - distance = get_base_tip_dist(pts_standard) + distance = get_base_tip_dist(pts=pts_standard) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [2.82842712, 2.82842712], decimal=7) # test get_base_tip_dist with roots without bases def test_get_base_tip_dist_no_bases(pts_no_bases): - distance = get_base_tip_dist(pts_no_bases) + distance = get_base_tip_dist(pts=pts_no_bases) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [np.nan, np.nan], decimal=7) # test get_base_tip_dist with roots with one base def test_get_base_tip_dist_one_base(pts_one_base): - distance = get_base_tip_dist(pts_one_base) + distance = get_base_tip_dist(pts=pts_one_base) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [2.82842712, np.nan], decimal=7) # test get_base_tip_dist with no roots def test_get_base_tip_dist_no_roots(pts_no_roots): - distance = get_base_tip_dist(pts_no_roots) + distance = get_base_tip_dist(pts=pts_no_roots) assert distance.shape == (2,) np.testing.assert_almost_equal(distance, [np.nan, np.nan], decimal=7) -# test get_grav_index function -def test_get_grav_index(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - grav_index = get_grav_index(pts) - np.testing.assert_almost_equal(grav_index, 0.08898137324716636) - - -def test_get_root_lengths(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - assert pts.shape == (1, 6, 2) - - root_lengths = get_root_lengths(pts) - assert root_lengths.shape == (1,) - np.testing.assert_array_almost_equal(root_lengths, [971.050417]) - - pts = lateral.numpy() - assert pts.shape == (5, 3, 2) - - root_lengths = get_root_lengths(pts) - assert root_lengths.shape == (5,) - np.testing.assert_array_almost_equal( - root_lengths, [20.129579, 62.782368, 80.268003, 34.925591, 3.89724] - ) - - -def test_get_root_lengths_no_roots(pts_no_bases): - root_lengths = get_root_lengths(pts_no_bases) - assert root_lengths.shape == (2,) - np.testing.assert_array_almost_equal(root_lengths, np.array([np.nan, np.nan])) - - -def test_get_root_lengths_one_point(pts_one_base): - root_lengths = get_root_lengths(pts_one_base) - assert root_lengths.shape == (2,) - np.testing.assert_array_almost_equal( - root_lengths, np.array([2.82842712475, np.nan]) - ) - - -# test get_root_lengths_max function with lengths_normal -def test_get_root_lengths_max_normal(lengths_normal): - max_length = get_root_lengths_max(lengths_normal) - np.testing.assert_array_almost_equal(max_length, 329.4) - - -# test get_root_lengths_max function with lengths_with_nan -def test_get_root_lengths_max_with_nan(lengths_with_nan): - max_length = 
get_root_lengths_max(lengths_with_nan) - np.testing.assert_array_almost_equal(max_length, 329.4) - - -# test get_root_lengths_max function with lengths_all_nan -def test_get_root_lengths_max_all_nan(lengths_all_nan): - max_length = get_root_lengths_max(lengths_all_nan) - np.testing.assert_array_almost_equal(max_length, np.nan) - - # test get_lateral_count function with canola def test_get_lateral_count(canola_h5): series = Series.load( @@ -294,6 +224,17 @@ def test_get_base_xs_canola(canola_h5): np.testing.assert_almost_equal(base_xs[1], 1112.5506591796875, decimal=3) +# test get_base_xs with rice +def test_get_base_xs_rice(rice_h5): + monocots = True + plant = Series.load( + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + pts_lr = get_lateral_pts(plant=plant, frame=0) + base_xs = get_base_xs(pts_lr, monocots) + assert np.isnan(base_xs) + + # test get_base_xs with pts_standard def test_get_base_xs_standard(pts_standard): base_xs = get_base_xs(pts_standard) @@ -321,6 +262,17 @@ def test_get_base_ys_canola(canola_h5): np.testing.assert_almost_equal(base_ys[1], 228.0966796875, decimal=3) +# test get_base_ys with rice +def test_get_base_ys_rice(rice_h5): + monocots = True + plant = Series.load( + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + pts_lr = get_lateral_pts(plant=plant, frame=0) + base_ys = get_base_ys(pts_lr, monocots) + assert np.isnan(base_ys) + + # test get_base_ys with pts_standard def test_get_base_ys_standard(pts_standard): base_ys = get_base_ys(pts_standard) @@ -346,6 +298,16 @@ def test_get_base_length_canola(canola_h5): np.testing.assert_almost_equal(base_length, 83.69914245605469, decimal=3) +# test get_base_length with rice +def test_get_base_length_rice(rice_h5): + plant = Series.load( + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + pts_lr = get_lateral_pts(plant=plant, frame=0) + base_length = get_base_length(pts_lr, monocots=True) + assert np.isnan(base_length) + + # test get_base_length with pts_standard def test_get_base_length_standard(pts_standard): base_length = get_base_length(pts_standard) @@ -377,21 +339,16 @@ def test_get_base_ct_density_canola(canola_h5): np.testing.assert_almost_equal(base_ct_density, 0.004119, decimal=5) -# test get_primary_depth function with defined primary_pts -def test_get_primary_depth(primary_pts): - primary_depth = get_primary_depth(primary_pts) - np.testing.assert_almost_equal(primary_depth, 808.12585449, decimal=3) - - -# test get_primary_depth function with canola -def test_get_primary_depth(canola_h5): +# test get_base_ct_density function with rice example +def test_get_base_ct_density_rice(rice_h5): + monocots = True series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) - primary, lateral = series[0] - primary_pts = primary.numpy() - primary_depth = get_primary_depth(primary_pts) - np.testing.assert_almost_equal(primary_depth, 1020.9813842773438, decimal=3) + primary_pts = get_primary_pts(plant=series, frame=0) + lateral_pts = get_lateral_pts(plant=series, frame=0) + base_ct_density = get_base_ct_density(primary_pts, lateral_pts, monocots) + assert np.isnan(base_ct_density) # test get_base_length_ratio with canola @@ -406,7 +363,7 @@ def test_get_base_length_ratio(canola_h5): np.testing.assert_almost_equal(base_length_ratio, 0.086, decimal=3) -def test_stem_width(canola_h5): +def 
test_root_width(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) @@ -416,5 +373,5 @@ def test_stem_width(canola_h5): assert primary_pts.shape == (1, 6, 2) assert lateral_pts.shape == (5, 3, 2) - stem_widths = get_root_pair_widths_projections(lateral_pts, primary_pts, 0.02) - np.testing.assert_array_almost_equal(stem_widths, [31.603239]) + root_widths = get_root_pair_widths_projections(lateral_pts, primary_pts, 0.02) + assert np.isnan(root_widths) diff --git a/tests/test_ellipse.py b/tests/test_ellipse.py index 564b1ec..6a8ef6d 100644 --- a/tests/test_ellipse.py +++ b/tests/test_ellipse.py @@ -6,6 +6,7 @@ get_ellipse_b, get_ellipse_ratio, ) +from sleap_roots.lengths import get_max_length_pts from sleap_roots.points import get_all_pts_array @@ -36,7 +37,13 @@ def test_get_ellipse_a(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) ellipse_a = get_ellipse_a(pts_all_array) np.testing.assert_almost_equal(ellipse_a, 398.1275346610801, decimal=3) @@ -45,7 +52,13 @@ def test_get_ellipse_b(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) ellipse_b = get_ellipse_b(pts_all_array) np.testing.assert_almost_equal(ellipse_b, 115.03734180292595, decimal=3) @@ -54,6 +67,28 @@ def test_get_ellipse_ratio(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) ellipse_ratio = get_ellipse_ratio(pts_all_array) np.testing.assert_almost_equal(ellipse_ratio, 3.460854783511295, decimal=3) + + +def test_get_ellipse_ratio_ellipse(canola_h5): + plant = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + primary, lateral = plant[0] + primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) + lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=False + ) + ellipse = fit_ellipse(pts_all_array) + ellipse_ratio = get_ellipse_ratio(ellipse) + np.testing.assert_almost_equal(ellipse_ratio, 3.460854783511295, decimal=3) diff --git a/tests/test_graphpipeline.py b/tests/test_graphpipeline.py index befc797..8e302dd 100644 --- a/tests/test_graphpipeline.py +++ b/tests/test_graphpipeline.py @@ -1,4 +1,4 @@ -from sleap_roots.graphpipeline import ( +from sleap_roots.trait_pipeline import ( get_traits_value_frame, get_traits_value_plant, get_traits_value_plant_summary, diff --git a/tests/test_lengths.py 
b/tests/test_lengths.py new file mode 100644 index 0000000..b1d3942 --- /dev/null +++ b/tests/test_lengths.py @@ -0,0 +1,207 @@ +from sleap_roots.lengths import ( + get_grav_index, + get_root_lengths, + get_root_lengths_max, +) +from sleap_roots import Series +import numpy as np +import pytest + + +@pytest.fixture +def pts_standard(): + return np.array( + [ + [ + [1, 2], + [3, 4], + ], + [ + [5, 6], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def pts_no_bases(): + return np.array( + [ + [ + [np.nan, np.nan], + [3, 4], + ], + [ + [np.nan, np.nan], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def pts_one_base(): + return np.array( + [ + [ + [1, 2], + [3, 4], + ], + [ + [np.nan, np.nan], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def pts_no_roots(): + return np.array( + [ + [ + [np.nan, np.nan], + [np.nan, np.nan], + ], + [ + [np.nan, np.nan], + [np.nan, np.nan], + ], + ] + ) + + +@pytest.fixture +def pts_not_contiguous(): + return np.array( + [ + [ + [1, 2], + [np.nan, np.nan], + [3, 4], + ], + [ + [5, 6], + [np.nan, np.nan], + [7, 8], + ], + ] + ) + + +@pytest.fixture +def primary_pts(): + return np.array( + [ + [ + [852.17755127, 216.95648193], + [850.17755127, 472.83520508], + [844.45300293, 472.83520508], + [837.03405762, 588.5123291], + [828.87963867, 692.72009277], + [816.71142578, 808.12585449], + ] + ] + ) + + +@pytest.fixture +def lateral_pts(): + return np.array( + [ + [ + [852.17755127, 216.95648193], + [np.nan, np.nan], + [837.03405762, 588.5123291], + [828.87963867, 692.72009277], + [816.71142578, 808.12585449], + ], + [ + [852.17755127, 216.95648193], + [844.45300293, 472.83520508], + [837.03405762, 588.5123291], + [828.87963867, 692.72009277], + [816.71142578, 808.12585449], + ], + ] + ) + + +@pytest.fixture +def lengths_normal(): + return np.array([145, 234, 329.4]) + + +@pytest.fixture +def lengths_with_nan(): + return np.array([145, 234, 329.4, np.nan]) + + +@pytest.fixture +def lengths_all_nan(): + return np.array([np.nan, np.nan, np.nan]) + + +# test get_grav_index function +def test_get_grav_index(canola_h5): + series = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + primary_pts = get_primary_pts(plant=series, frame=0) + grav_index = get_grav_index(pts=primary_pts) + np.testing.assert_almost_equal(grav_index, 0.08898137324716636) + + +def test_get_root_lengths(canola_h5): + series = Series.load( + canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + ) + primary, lateral = series[0] + pts = primary.numpy() + assert pts.shape == (1, 6, 2) + + root_lengths = get_root_lengths(pts) + assert root_lengths.shape == (1,) + np.testing.assert_array_almost_equal(root_lengths, [971.050417]) + + pts = lateral.numpy() + assert pts.shape == (5, 3, 2) + + root_lengths = get_root_lengths(pts) + assert root_lengths.shape == (5,) + np.testing.assert_array_almost_equal( + root_lengths, [20.129579, 62.782368, 80.268003, 34.925591, 3.89724] + ) + + +def test_get_root_lengths_no_roots(pts_no_bases): + root_lengths = get_root_lengths(pts_no_bases) + assert root_lengths.shape == (2,) + np.testing.assert_array_almost_equal(root_lengths, np.array([np.nan, np.nan])) + + +def test_get_root_lengths_one_point(pts_one_base): + root_lengths = get_root_lengths(pts_one_base) + assert root_lengths.shape == (2,) + np.testing.assert_array_almost_equal( + root_lengths, np.array([2.82842712475, np.nan]) + ) + + +# test get_root_lengths_max function with lengths_normal +def test_get_root_lengths_max_normal(lengths_normal): + max_length 
= get_root_lengths_max(lengths_normal) + np.testing.assert_array_almost_equal(max_length, 329.4) + + +# test get_root_lengths_max function with lengths_with_nan +def test_get_root_lengths_max_with_nan(lengths_with_nan): + max_length = get_root_lengths_max(lengths_with_nan) + np.testing.assert_array_almost_equal(max_length, 329.4) + + +# test get_root_lengths_max function with lengths_all_nan +def test_get_root_lengths_max_all_nan(lengths_all_nan): + max_length = get_root_lengths_max(lengths_all_nan) + np.testing.assert_array_almost_equal(max_length, np.nan) diff --git a/tests/test_networklength.py b/tests/test_networklength.py index 6735d87..e0660d4 100644 --- a/tests/test_networklength.py +++ b/tests/test_networklength.py @@ -1,6 +1,8 @@ import pytest import numpy as np from sleap_roots import Series +from sleap_roots.convhull import get_chull_area, get_convhull +from sleap_roots.lengths import get_max_length_pts, get_root_lengths from sleap_roots.networklength import get_bbox from sleap_roots.networklength import get_network_distribution from sleap_roots.networklength import get_network_distribution_ratio @@ -81,86 +83,129 @@ def test_get_network_width_depth_ratio_nan(pts_nan3): np.testing.assert_almost_equal(ratio, np.nan, decimal=7) -def test_get_network_solidity(canola_h5): +def test_get_network_length(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=False) + lateral_lengths = get_root_lengths(lateral_pts) monocots = False - ratio = get_network_solidity(primary_pts, lateral_pts, pts_all_array, monocots) - np.testing.assert_almost_equal(ratio, 0.012578941125511587, decimal=7) + length = get_network_length(primary_length, lateral_lengths, monocots) + np.testing.assert_almost_equal(length, 1173.0531992388217, decimal=7) -def test_get_network_solidity_rice(rice_h5): +def test_get_network_length_rice(rice_h5): series = Series.load( rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=True) + lateral_lengths = get_root_lengths(lateral_pts) monocots = True - ratio = get_network_solidity(primary_pts, lateral_pts, pts_all_array, monocots) - np.testing.assert_almost_equal(ratio, 0.17930631242462894, decimal=7) + length = get_network_length(primary_length, lateral_lengths, monocots) + np.testing.assert_almost_equal(length, 798.5726441151357, decimal=7) -def test_get_network_distribution(canola_h5): +def test_get_network_solidity(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=False) - fraction = 2 
/ 3 + lateral_lengths = get_root_lengths(lateral_pts) monocots = False - root_length = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + network_length = get_network_length(primary_length, lateral_lengths, monocots) + + # get chull_area + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots ) - np.testing.assert_almost_equal(root_length, 589.4322131363684, decimal=7) + convex_hull = get_convhull(pts_all_array) + chull_area = get_chull_area(convex_hull) + + ratio = get_network_solidity(network_length, chull_area) + np.testing.assert_almost_equal(ratio, 0.012578941125511587, decimal=7) -def test_get_network_distribution_rice(rice_h5): +def test_get_network_solidity_rice(rice_h5): series = Series.load( rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral_lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=True) - fraction = 2 / 3 + lateral_lengths = get_root_lengths(lateral_pts) monocots = True - root_length = get_network_distribution( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + network_length = get_network_length(primary_length, lateral_lengths, monocots) + + # get chull_area + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots ) - np.testing.assert_almost_equal(root_length, 475.89810040497025, decimal=7) + convex_hull = get_convhull(pts_all_array) + chull_area = get_chull_area(convex_hull) + + ratio = get_network_solidity(network_length, chull_area) + np.testing.assert_almost_equal(ratio, 0.17930631242462894, decimal=7) -def test_get_network_length(canola_h5): +def test_get_network_distribution(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) lateral_pts = lateral.numpy() monocots = False - length = get_network_length(primary_pts, lateral_pts, monocots) - np.testing.assert_almost_equal(length, 1173.0531992388217, decimal=7) + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + fraction = 2 / 3 + monocots = False + root_length = get_network_distribution( + primary_max_length_pts, lateral_pts, pts_all_array, fraction, monocots + ) + np.testing.assert_almost_equal(root_length, 589.4322131363684, decimal=7) -def test_get_network_length_rice(rice_h5): +def test_get_network_distribution_rice(rice_h5): series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" ) primary, lateral = series[0] primary_pts = primary.numpy() + primary_max_length_pts = get_max_length_pts(primary_pts) lateral_pts = lateral.numpy() + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + fraction = 2 / 3 monocots = True - length = get_network_length(primary_pts, lateral_pts, monocots) - np.testing.assert_almost_equal(length, 798.5726441151357, decimal=7) + root_length = get_network_distribution( + primary_max_length_pts, lateral_pts, pts_all_array, fraction, monocots + ) + np.testing.assert_almost_equal(root_length, 
475.89810040497025, decimal=7) def test_get_network_distribution_ratio(canola_h5): @@ -169,12 +214,29 @@ def test_get_network_distribution_ratio(canola_h5): ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=False) + lateral_lengths = get_root_lengths(lateral_pts) + # get pts_all_array + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + # get network_length_lower + network_length_lower = get_network_distribution( + primary_max_length_pts, lateral_pts, pts_all_array + ) + monocots = False fraction = 2 / 3 monocots = False ratio = get_network_distribution_ratio( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + primary_length, + lateral_lengths, + network_length_lower, + fraction, + monocots, ) np.testing.assert_almost_equal(ratio, 0.5024769665338648, decimal=7) @@ -185,11 +247,29 @@ def test_get_network_distribution_ratio_rice(rice_h5): ) primary, lateral = series[0] primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length = get_root_lengths(primary_max_length_pts) + # get lateral lengths lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array(plant=series, frame=0, monocots=True) - fraction = 2 / 3 + lateral_lengths = get_root_lengths(lateral_pts) + # get pts_all_array + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + # get network_length_lower + network_length_lower = get_network_distribution( + primary_max_length_pts, lateral_pts, pts_all_array + ) monocots = True + fraction = 2 / 3 + monocots = False ratio = get_network_distribution_ratio( - primary_pts, lateral_pts, pts_all_array, fraction, monocots + primary_length, + lateral_lengths, + network_length_lower, + fraction, + monocots, ) + np.testing.assert_almost_equal(ratio, 0.5959358912579489, decimal=7) diff --git a/tests/test_points.py b/tests/test_points.py index 4944933..3fdf368 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -1,205 +1,41 @@ -import numpy as np -import pytest from sleap_roots import Series +from sleap_roots.lengths import get_max_length_pts, get_root_lengths from sleap_roots.points import ( - get_pt_ind, - get_primary_pts, - get_lateral_pts, - get_all_pts, get_all_pts_array, ) -@pytest.fixture -def pts_nan2(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, 472.83520508], - [844.45300293, 472.83520508], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ] - ] - ) - - -@pytest.fixture -def pts_nan3(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ] - ] - ) - - -@pytest.fixture -def pts_nan6(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [np.nan, np.nan], - ] - ] - ) - - -@pytest.fixture -def pts_nanall(): - return np.array( - [ - [ - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - ] - ] - ) - - -@pytest.fixture -def pts_nan32(): - return 
np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [844.45300293, 472.83520508], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - ] - ) - - -@pytest.fixture -def pts_nan32_5node(): - return np.array( - [ - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [np.nan, np.nan], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - [ - [852.17755127, 216.95648193], - [np.nan, np.nan], - [837.03405762, 588.5123291], - [828.87963867, 692.72009277], - [816.71142578, 808.12585449], - ], - ] - ) - - -# test get_pt_ind function -def test_get_pt_ind(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - proximal = True - node_ind = get_pt_ind(pts, proximal) - np.testing.assert_array_equal(node_ind, [1]) - - -# test get_pt_ind function using root that without second node -def test_get_pt_ind_nan2(pts_nan2): - proximal = True - node_ind = get_pt_ind(pts_nan2, proximal) - np.testing.assert_array_equal(node_ind, [2]) - - -# test get_pt_ind function using root that without second and third nodes -def test_get_pt_ind_nan3(pts_nan3): - proximal = True - node_ind = get_pt_ind(pts_nan3, proximal) - np.testing.assert_array_equal(node_ind, [3]) - - -# test get_pt_ind function using two roots/instances -def test_get_pt_ind_nan32(pts_nan32): - proximal = True - node_ind = get_pt_ind(pts_nan32, proximal) - np.testing.assert_array_equal(node_ind, [3, 2]) - - -# test get_pt_ind function using root that without last node -def test_get_pt_ind_nan6(pts_nan6): - proximal = False - node_ind = get_pt_ind(pts_nan6, proximal) - np.testing.assert_array_equal(node_ind, [4]) - - -# test get_pt_ind function using root with all nan node -def test_get_pt_ind_nanall(pts_nanall): - proximal = False - node_ind = get_pt_ind(pts_nanall, proximal) - np.testing.assert_array_equal(node_ind, [0]) - - -# test get_primary_pts function -def test_get_primary_pts(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts_pr = get_primary_pts(plant=plant, frame=0) - assert pts_pr.shape == (1, 6, 2) - - -# test get_lateral_pts function -def test_get_lateral_pts(canola_h5): +# test get_all_pts_array function +def test_get_all_pts_array(canola_h5): plant = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) - pts_lr = get_lateral_pts(plant=plant, frame=0) - assert pts_lr.shape == (5, 3, 2) - - -# test get_all_pts function -def test_get_all_pts(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" + primary, lateral = plant[0] + primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + # get lateral_lengths + lateral_pts = lateral.numpy() + monocots = False + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots ) - pts_all = get_all_pts(plant=plant, frame=0, monocots=False) - assert len(pts_all) == 6 - assert len(pts_all[0]) == 6 - assert len(pts_all[1]) == 3 + assert pts_all_array.shape[0] == 21 # test get_all_pts_array function -def test_get_all_pts_array(canola_h5): +def test_get_all_pts_array_rice(rice_h5): 
plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - pts_all_array = get_all_pts_array(plant=plant, frame=0, monocots=False) - assert pts_all_array.shape[0] == 21 + rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + ) + primary, lateral = plant[0] + primary_pts = primary.numpy() + # get primary length + primary_max_length_pts = get_max_length_pts(primary_pts) + # get lateral_lengths + lateral_pts = lateral.numpy() + monocots = True + pts_all_array = get_all_pts_array( + primary_max_length_pts, lateral_pts, monocots=monocots + ) + assert pts_all_array.shape[0] == 12 diff --git a/tests/test_scanline.py b/tests/test_scanline.py index fd5cfa2..22e1c5a 100644 --- a/tests/test_scanline.py +++ b/tests/test_scanline.py @@ -101,9 +101,10 @@ def test_get_scanline_first_ind(canola_h5): width = 2048 n_line = 50 monocots = False - scanline_first_ind = get_scanline_first_ind( + scanline_intersection_counts = count_scanline_intersections( primary_pts, lateral_pts, depth, width, n_line, monocots ) + scanline_first_ind = get_scanline_first_ind(scanline_intersection_counts) np.testing.assert_equal(scanline_first_ind, 6) @@ -119,7 +120,8 @@ def test_get_scanline_last_ind(canola_h5): width = 2048 n_line = 50 monocots = True - scanline_last_ind = get_scanline_last_ind( + scanline_intersection_counts = count_scanline_intersections( primary_pts, lateral_pts, depth, width, n_line, monocots ) + scanline_last_ind = get_scanline_last_ind(scanline_intersection_counts) np.testing.assert_equal(scanline_last_ind, 15) diff --git a/tests/test_tips.py b/tests/test_tips.py index 19c299a..0c218ab 100644 --- a/tests/test_tips.py +++ b/tests/test_tips.py @@ -1,5 +1,4 @@ -from sleap_roots.points import get_lateral_pts -from sleap_roots.tips import get_tips, get_primary_depth, get_tip_xs, get_tip_ys +from sleap_roots.tips import get_tips, get_tip_xs, get_tip_ys from sleap_roots import Series import numpy as np import pytest @@ -81,18 +80,6 @@ def test_tips_one_tip(pts_one_tip): np.testing.assert_array_equal(tips, [[3, 4], [np.nan, np.nan]]) -# test get_primary_depth with standard points -def test_get_primary_depth_standard(pt_standard): - primary_depth = get_primary_depth(pt_standard) - np.testing.assert_array_almost_equal(primary_depth, 4) - - -# test get_primary_depth with nan tip points -def test_get_primary_depth_nan(pt_nan_tip): - primary_depth = get_primary_depth(pt_nan_tip) - np.testing.assert_array_almost_equal(primary_depth, np.nan) - - # test get_tip_xs with canola def test_get_tip_xs_canola(canola_h5): plant = Series.load( diff --git a/tests/test_traitsgraph.py b/tests/test_traitsgraph.py deleted file mode 100644 index b4faa53..0000000 --- a/tests/test_traitsgraph.py +++ /dev/null @@ -1,7 +0,0 @@ -from sleap_roots.traitsgraph import get_traits_graph - - -def test_get_traits_graph(): - dts = get_traits_graph() - assert len(dts) == 43 - assert dts[0] == "primary_base_pt"
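For orientation, the reworked tests above exercise the trait functions through explicit intermediates rather than a `(plant, frame)` pair. The sketch below condenses that call pattern for a single frame of the canola fixture; the `.h5` path is a placeholder, and the snippet only mirrors the signatures exercised by the tests rather than the pipeline's own wiring.

from sleap_roots import Series
from sleap_roots.angle import get_node_ind, get_root_angle
from sleap_roots.convhull import get_chull_area, get_convhull
from sleap_roots.lengths import get_max_length_pts, get_root_lengths
from sleap_roots.networklength import get_network_length, get_network_solidity
from sleap_roots.points import get_all_pts_array

series = Series.load(
    "canola.h5",  # placeholder path to an image series
    primary_name="primary_multi_day",
    lateral_name="lateral_3_nodes",
)
primary, lateral = series[0]   # frame 0
primary_pts = primary.numpy()  # (instances, nodes, 2)
lateral_pts = lateral.numpy()
monocots = False

# Angles now receive the precomputed proximal/distal node index.
node_ind = get_node_ind(primary_pts, proximal=True)
proximal_angles = get_root_angle(primary_pts, node_ind, proximal=True)

# Network traits are assembled from root lengths plus the convex hull of all points.
primary_max_length_pts = get_max_length_pts(primary_pts)
primary_length = get_root_lengths(primary_max_length_pts)
lateral_lengths = get_root_lengths(lateral_pts)
network_length = get_network_length(primary_length, lateral_lengths, monocots)
pts_all = get_all_pts_array(primary_max_length_pts, lateral_pts, monocots=monocots)
chull_area = get_chull_area(get_convhull(pts_all))
network_solidity = get_network_solidity(network_length, chull_area)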
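Each computed trait is then reduced to the same nine statistics per plant (min, max, mean, median, std, and the 5th/25th/75th/95th percentiles), with an extra per-frame `_f*` pass for non-scalar traits, as seen in `get_traits_value_plant_summary` above. The helper below is not the package's `get_summary`; it is only an illustrative NaN-aware stand-in that returns the same nine-value tuple, with `primary_length` used as a hypothetical trait name.

import numpy as np

def nine_number_summary(values):
    # Stand-in for the (min, max, mean, median, std, prc5, prc25, prc75, prc95)
    # tuple unpacked from get_summary(); ignores NaNs and guards the all-NaN case.
    values = np.asarray(values, dtype=float).ravel()
    if values.size == 0 or np.isnan(values).all():
        return (np.nan,) * 9
    return (
        np.nanmin(values),
        np.nanmax(values),
        np.nanmean(values),
        np.nanmedian(values),
        np.nanstd(values),
        np.nanpercentile(values, 5),
        np.nanpercentile(values, 25),
        np.nanpercentile(values, 75),
        np.nanpercentile(values, 95),
    )

# How the suffixed summary columns are assembled for one hypothetical trait:
suffixes = ["min", "max", "mean", "median", "std", "prc5", "prc25", "prc75", "prc95"]
row = {
    f"primary_length_{s}": v
    for s, v in zip(suffixes, nine_number_summary([971.05, np.nan, 968.3]))
}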
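Finally, a minimal usage sketch of the renamed top-level entry point. The data folder and output CSV name are placeholders, the model names are the ones used by the canola tests in this patch, and each folder is expected to contain the `.h5` series next to their `{h5_path}.{model_name}.predictions.slp` files as the docstrings above describe.

from sleap_roots.trait_pipeline import get_all_plants_traits

all_traits_df = get_all_plants_traits(
    data_folders=["data/canola"],   # scanned for .h5 series via find_all_series()
    primary_name="primary_multi_day",
    lateral_name="lateral_3_nodes",
    monocots=False,
    write_per_plant_details=True,   # per-plant "*.traits.csv" files
    write_per_plant_summary=True,   # per-plant "*.summary_traits.csv" files
    all_plants_csv_name="all_plants_traits.csv",
)
print(all_traits_df.shape)  # one row per plant, summary trait columns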