Skip to content

Commit

Permalink
Return NaNs as default in root widths
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Mar 30, 2024
1 parent eb99831 commit 9bc5e18
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions sleap_roots/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,23 +245,22 @@ def get_root_widths(
Returns:
- If `return_inds` is False (default):
Returns an array of distances between the bases of matched roots. An empty
array is returned if no matching indices are found.
Returns an array of distances between the bases of matched roots. If no
matched indices are found, NaN is returned.
- If `return_inds` is True:
Returns a tuple containing the following four elements:
- matched_dists: Distances between the bases of matched roots. An empty
array is returned if no matched indices
are found.
- matched_dists: Distances between the bases of matched roots. If no
matched indices are found, NaN is returned.
- matched_indices: List of tuples, each containing the indices
of matched roots on the left and right sides. A list containing a
tuple of NaNs is returned if no matched indices are found.
- left_bases_final: (n, 2) array containing the (x, y)
coordinates of the left bases of the matched roots. An empty array
of shape (0, 2) is returned if no matched indices are found.
coordinates of the left bases of the matched roots. An array of
NaNs is returned if no matched indices are found.
- right_bases_final: (n, 2) array containing the (x, y)
coordinates of the right bases of the matched roots. An empty array
of shape (0, 2) is returned if no matched indices are found.
coordinates of the right bases of the matched roots. An array of
NaNs is returned if no matched indices are found.
"""
# Validate tolerance
if tolerance <= 0:
Expand All @@ -275,11 +274,11 @@ def get_root_widths(
if primary_max_length_pts.shape[1] != 2 or lateral_pts.shape[2] != 2:
raise ValueError("The last dimension should contain x and y coordinates")

# Initialize default return values with shapes that match the expected output
default_dists = np.full((0,), np.nan) # Array filled with NaN values
default_indices = [(np.nan, np.nan)]
default_left_bases = np.full((0, 2), np.nan) # 2D array filled with NaN values
default_right_bases = np.full((0, 2), np.nan) # 2D array filled with NaN values
# Initialize default return values
default_dists = np.nan
default_indices = [(np.nan, np.nan)] # List of tuples with NaN values
default_left_bases = np.full((1, 2), np.nan) # 2D array filled with NaN values
default_right_bases = np.full((1, 2), np.nan) # 2D array filled with NaN values

# Check for minimum length, or all NaNs in arrays
if (
Expand Down

0 comments on commit 9bc5e18

Please sign in to comment.