From 9bc5e1854a927e2f31e6cdcb31502b4a11473f5d Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Fri, 29 Mar 2024 20:47:01 -0700 Subject: [PATCH] Return NaNs as default in root widths --- sleap_roots/bases.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/sleap_roots/bases.py b/sleap_roots/bases.py index c3f3e09..ac7aafc 100644 --- a/sleap_roots/bases.py +++ b/sleap_roots/bases.py @@ -245,23 +245,22 @@ def get_root_widths( Returns: - If `return_inds` is False (default): - Returns an array of distances between the bases of matched roots. An empty - array is returned if no matching indices are found. + Returns an array of distances between the bases of matched roots. If no + matched indices are found, NaN is returned. - If `return_inds` is True: Returns a tuple containing the following four elements: - - matched_dists: Distances between the bases of matched roots. An empty - array is returned if no matched indices - are found. + - matched_dists: Distances between the bases of matched roots. If no + matched indices are found, NaN is returned. - matched_indices: List of tuples, each containing the indices of matched roots on the left and right sides. A list containing a tuple of NaNs is returned if no matched indices are found. - left_bases_final: (n, 2) array containing the (x, y) - coordinates of the left bases of the matched roots. An empty array - of shape (0, 2) is returned if no matched indices are found. + coordinates of the left bases of the matched roots. An array of + NaNs is returned if no matched indices are found. - right_bases_final: (n, 2) array containing the (x, y) - coordinates of the right bases of the matched roots. An empty array - of shape (0, 2) is returned if no matched indices are found. + coordinates of the right bases of the matched roots. An array of + NaNs is returned if no matched indices are found. """ # Validate tolerance if tolerance <= 0: @@ -275,11 +274,11 @@ def get_root_widths( if primary_max_length_pts.shape[1] != 2 or lateral_pts.shape[2] != 2: raise ValueError("The last dimension should contain x and y coordinates") - # Initialize default return values with shapes that match the expected output - default_dists = np.full((0,), np.nan) # Array filled with NaN values - default_indices = [(np.nan, np.nan)] - default_left_bases = np.full((0, 2), np.nan) # 2D array filled with NaN values - default_right_bases = np.full((0, 2), np.nan) # 2D array filled with NaN values + # Initialize default return values + default_dists = np.nan + default_indices = [(np.nan, np.nan)] # List of tuples with NaN values + default_left_bases = np.full((1, 2), np.nan) # 2D array filled with NaN values + default_right_bases = np.full((1, 2), np.nan) # 2D array filled with NaN values # Check for minimum length, or all NaNs in arrays if (