Skip to content

Commit

Permalink
Refactor pipelines into classes (#51)
Browse files Browse the repository at this point in the history
* Reorganizing the trait_map and modify ellipse, network functions.

* Update ellipse argument default setting.

* Reorganize ellipse and network functions by reducing arguments.

* Map convex hull traits

* Change primary to lateral when monocots is True

* Get tips x and y coordinates uses network map

* Change "stem" to "root"

* Fix tip y map

* Change root width back to take lateral_pts

* Changing the order of positional arguments to match others (primary is first)

* Fix plotting for using sleap-io API

* Make positional arguments consistent

* Refactor `get_base_xs` to use graph

* Map `scanline_intersection_counts` and use keyword arguments

* Refactor `get_base_ys`, `get_base_length`, and `base_ct_density` to use graph and add comments. Delete duplicate `primary_depth`.

* Clean up dependencies. Fix tip_ys.

* Refactor `get_root_lengths_max` for use with graph

* Refactor `get_base_tip_dist` to make base and tip pts or all points optional arguments

* Delete `primary_depth`

* Delete traitsgraph

* Delete traitsgraph dependencies

* Refactor base-related traits to use graph optionally

* Delete traits graph dependency

* Use `get_primary_pts` from series class

* Delete `get_primary_depth` tests

* Fix trait map for base traits

* Delete test for traitsgraph.py

* Standardize trait definition in trait map

* Change "graph" to "trait"

* Fix docstrings in `get_bases`

* Use `TraitDef` class

* Fix docstrings

* Add argument to class `TraitDef` whether to include in csv or if scalar

* Change `attr` to `attrs`

* Add `lengths.py` for length-related traits.

* Add `primary_max_length_pts` to trait definitions

* Add `pts_all_array` and `convex_hull` trait definitions

* Fix docstring

* Import base-related trait to `lengths.py`

* Make sure arrays of points are 2-dimensional

* Streamline point-related functions

* Vectorize `get_node_ind`

* Add trait definitions until `lateral_lengths`

* Delete unnecessary code

* Use node_ind for `get_root_angle` function.

* Modify base functions by assuming primary_pts as the primary_length_max.

* Modify argument pts as Optional in `get_base_tip_dist` function

* Modify argument pts as Optional in `get_grav_index` function

* Draft the trait_definitions using the defined TraitDef class.

* Uppercase the `get_root_angle` function arg description.

* Add test_lengths module for lengths-related functions.

* Remove lengths-related functions from test_bases.

* Set pts as Optional argument for `get_grav_index` function.

* Change the module name for importing lengths-related functions.

* Remove importing the points functions, only keep `get_all_pts_array`.

* Test ellipse-related functions.

* Redo the function `get_node_ind`.

* Test function `get_node_ind`.

* Angle function reset node_ind to array if only one value.

* Angle function return nan if all Nan node, return value if single array.

* Test angle functions.

* Add network_width_depth_ratio in trait_definitions.

* Reorganize arguments of `get_network_distribution_ratio` function.

* Add `network_length` trait before calculating `network_solidity`.

* Update `primary_root_length` function with calculated lengths.

* Update `get_network_solidity` function with calculated network_length.

* Test network-related functions.

* Test points function (`get_all_pts_array`).

* Update and test scanline functions using calculated scanline counts.

* Refactor `get_root_pair_widths_projections` to take in `primary_max_length_pts`

* Cleanup trait map

* Fix tests for base-related traits

* Add test for `get_max_lengths_pts`

* Refactored `get_base_ct_density` to take `primary_length_max` and `lateral_base_pts` as arguments

* Fixed multi-line strings

* Refactor base-related traits

* Refactor base-related traits and tests

* Test root-length-related traits

* Test tip-related traits

* Refactor convex-hull-related traits

* Test convex-hull functions

* Lint

* Lint

* Lint

* Lint

* Lint

* Fix kwargs involving `get_tips` in trait map

* Fix input for pipeline tests

* Refactor network related functions

* Test pipeline

* Refactor scanline function

* Start refactoring pipeline into classes

* Finish refactoring trait pipelines into classes

* Runtime fixes

* More refactoring to minimize redundant code across pipeline types

* Rename module and fix tests

* Add missing renamed modules

* Fix summary tests

* Fix Series to load video directly to bypass path resolution issues

* Lint

* Lint

---------

Co-authored-by: Lin Wang <[email protected]>
Co-authored-by: Elizabeth Berrigan <[email protected]>
  • Loading branch information
3 people authored Aug 17, 2023
1 parent c565235 commit 18958b2
Show file tree
Hide file tree
Showing 27 changed files with 2,570 additions and 2,166 deletions.
5 changes: 2 additions & 3 deletions sleap_roots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import sleap_roots.scanline
import sleap_roots.series
import sleap_roots.summary
import sleap_roots.traitsgraph
import sleap_roots.graphpipeline
from sleap_roots.graphpipeline import get_all_plants_traits
import sleap_roots.trait_pipelines
from sleap_roots.trait_pipelines import DicotPipeline, TraitDef
from sleap_roots.series import Series

# Define package version.
Expand Down
103 changes: 82 additions & 21 deletions sleap_roots/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,112 @@
import math


def get_node_ind(pts: np.ndarray, proximal=True) -> np.ndarray:
"""Find nproximal/distal node index.
def get_node_ind(pts: np.ndarray, proximal: bool = True) -> np.ndarray:
"""Find proximal/distal node index.
Args:
pts: Numpy array of points of shape (instances, nodes, 2).
pts: Numpy array of points of shape (instances, nodes, 2) or (nodes, 2).
proximal: Boolean value, where true is proximal (default), false is distal.
Returns:
An array of shape (instances,) of proximal or distal node index.
"""
node_ind = []
for i in range(pts.shape[0]):
ind = 1 if proximal else pts.shape[1] - 1 # set initial proximal/distal node
while np.isnan(pts[i, ind]).any():
ind += 1 if proximal else -1
if (ind == pts.shape[1] and proximal) or (ind == 0 and not proximal):
break
node_ind.append(ind)
# Check if pts is a numpy array
if not isinstance(pts, np.ndarray):
raise TypeError("Input pts should be a numpy array.")

# Check if pts has 2 or 3 dimensions
if pts.ndim not in [2, 3]:
raise ValueError("Input pts should have 2 or 3 dimensions.")

# Check if the last dimension of pts has size 2
if pts.shape[-1] != 2:
raise ValueError(
"The last dimension of the input pts should have size 2,"
"representing x and y coordinates."
)

# Check if pts is 2D, if so, reshape to 3D
if pts.ndim == 2:
pts = pts[np.newaxis, ...]

# Identify where NaN values exist
nan_mask = np.isnan(pts).any(axis=-1)

# If only NaN values, return NaN
if nan_mask.all():
return np.nan

if proximal:
# For proximal, we want the first non-NaN node in the first half root
# get the first half nan mask (exclude the base node)
node_proximal = nan_mask[:, 1 : int((nan_mask.shape[1] + 1) / 2)]
# get the nearest non-Nan node index
node_ind = np.argmax(~node_proximal, axis=-1)
# if there is no non-Nan node, set value of 99
node_ind[node_proximal.all(axis=1)] = 99
node_ind = node_ind + 1 # adjust indices by adding one (base node)
else:
# For distal, we want the last non-NaN node in the last half root
# get the last half nan mask
node_distal = nan_mask[:, int(nan_mask.shape[1] / 2) :]
# get the farest non-Nan node
node_ind = (node_distal[:, ::-1] == False).argmax(axis=1)
node_ind[node_distal.all(axis=1)] = -95 # set value if no non-Nan node
node_ind = pts.shape[1] - node_ind - 1 # adjust indices by reversing

# reset indices of 0 (base node) if no non-Nan node
node_ind[node_ind == 100] = 0

# If pts was originally 2D, return a scalar instead of a single-element array
if pts.shape[0] == 1:
return node_ind[0]

# If only one root, return a scalar instead of a single-element array
if node_ind.shape[0] == 1:
return node_ind[0]

return node_ind


def get_root_angle(pts: np.ndarray, proximal=True, base_ind=0) -> np.ndarray:
def get_root_angle(
pts: np.ndarray, node_ind: np.ndarray, proximal: bool = True, base_ind=0
) -> np.ndarray:
"""Find angles for each root.
Args:
pts: Numpy array of points of shape (instances, nodes, 2).
node_ind: Primary or lateral root node index.
proximal: Boolean value, where true is proximal (default), false is distal.
base_ind: Index of base node in the skeleton (default: 0).
Returns:
An array of shape (instances,) of angles in degrees, modulo 360.
"""
node_ind = get_node_ind(pts, proximal) # get proximal or distal node index
# if node_ind is a single int value, make it as array to keep consistent
if not isinstance(node_ind, np.ndarray):
node_ind = [node_ind]

if np.isnan(node_ind).all():
return np.nan

if pts.ndim == 2:
pts = np.expand_dims(pts, axis=0)

angs_root = []
for i in range(len(node_ind)):
# filter out the cases if all nan nodes in last/first half part
# to calculate proximal/distal angle
if (node_ind[i] < math.ceil(pts.shape[1] / 2) and proximal) or (
node_ind[i] >= math.floor(pts.shape[1] / 2) and not (proximal)
):
# if the node_ind is 0, do NOT calculate angs
if node_ind[i] == 0:
angs = np.nan
else:
xy = pts[i, node_ind[i], :] - pts[i, base_ind, :] # center on base node
# calculate the angle and convert to the start with gravity direction
ang = np.arctan2(-xy[1], xy[0]) * 180 / np.pi
angs = abs(ang + 90) if ang < 90 else abs(-(360 - 90 - ang))
else:
angs = np.nan
angs_root.append(angs)
return np.array(angs_root)
angs_root = np.array(angs_root)

# If only one root, return a scalar instead of a single-element array
if angs_root.shape[0] == 1:
return angs_root[0]
return angs_root
Loading

0 comments on commit 18958b2

Please sign in to comment.