Skip to content

Commit

Permalink
Small fix to find_global_peaks_rough (#28)
Browse files Browse the repository at this point in the history
* find_global_peaks_rough_fix

* lint fix

* ruff lint fix

* Format file

---------

Co-authored-by: gitttt-1234 <[email protected]>
  • Loading branch information
alckasoc and gitttt-1234 authored Dec 15, 2023
1 parent 03b404d commit ff6801b
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions sleap_nn/inference/peak_finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,30 +96,29 @@ def find_global_peaks_rough(
Returns:
A tuple of (peak_points, peak_vals).
peak_points: float32 tensor of shape (samples, channels, 2), where the last axis
indicates peak locations in xy order.
peak_vals: float32 tensor of shape (samples, channels) containing the values at
the peak points.
"""
# Find the maximum values and their indices along the height and width axes.
max_values, max_indices_y = torch.max(cms, dim=2, keepdim=True)
max_values, max_indices_x = torch.max(max_values, dim=3, keepdim=True)

max_indices_x = max_indices_x.squeeze(dim=(2, 3)) # (samples, channels)
max_indices_y = max_indices_y.max(dim=3).values # (samples, channels, 1)
max_values = max_values.squeeze(-1).squeeze(-1) # (samples, channels)
peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1).to(
torch.float32
)

# Find the maximum values and their indices along the height and width axes.
amax_values, amax_indices_x = torch.max(cms, dim=3, keepdim=True)
amax_values, amax_indices_y = torch.max(amax_values, dim=2, keepdim=True)
amax_indices_y = amax_indices_y.squeeze(dim=(2, 3))
peak_points = torch.cat(
[max_indices_x.unsqueeze(-1), amax_indices_y.unsqueeze(-1)], dim=-1
).to(torch.float32)
max_values = max_values.squeeze(-1).squeeze(-1)
# Create masks for values below the threshold.
below_threshold_mask = max_values < threshold

# Replace values below the threshold with NaN.
peak_points[below_threshold_mask] = float("nan")

max_values[below_threshold_mask] = float(0)
return peak_points, max_values


Expand Down

0 comments on commit ff6801b

Please sign in to comment.