Skip to content

Commit

Permalink
Add enforcement for np.sort and np.argsort (#918)
Browse files Browse the repository at this point in the history
* set mergesort as default and disable unstable kinds

* add unittest

* formatting

* formatting

* change name to sort_enforcement

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* break long error messages

* keep the original sorting in numpy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused import

* always use stablesort

* add numba-supported version of stableargsort

* use better naming for stablesort

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use jitable to allow both regular function and numba-decorated function for highest_density_region

* remove redundant numba_sort

* explicitly import stablesort from strax for numba decorated functions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* consistent import style within one module

* remove unused import

* add sorting error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable numba support for stable_sort

* consistent import style for stable sort

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add kwargs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify docstring for stable_sort

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove kwargs

* update variable name

* update test_sort with hypothesis

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rewrite highest_density_region to decouple stable_sort from the numba part

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused import

* break long lines

* remove numba decorator for the main function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* rewrite hitlets to use non-numba HDR region

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* format hitlets.py

* unify growing_result import to fix mypy error

* remove redundant space

* Remove unnecessary indent

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: dachengx <[email protected]>
  • Loading branch information
3 people authored Nov 14, 2024
1 parent c3dd2e1 commit 8489aa2
Show file tree
Hide file tree
Showing 14 changed files with 356 additions and 152 deletions.
1 change: 1 addition & 0 deletions strax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Glue the package together
# See https://www.youtube.com/watch?v=0oTh1CXRaQ0 if this confuses you
# The order of subpackes is not invariant, since we use strax.xxx inside strax
from .sort_enforcement import *
from .utils import *
from .chunk import *
from .dtypes import *
Expand Down
11 changes: 6 additions & 5 deletions strax/processing/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# for these fundamental functions, we throw warnings each time they are called

import strax
from strax import stable_sort, stable_argsort
import numba
from numba.typed import List
import numpy as np
Expand Down Expand Up @@ -37,23 +38,23 @@ def sort_by_time(x):
# Faster sorting:
x = _sort_by_time_and_channel(x, channel, channel.max() + 1)
elif "channel" in x.dtype.names:
x = np.sort(x, order=("time", "channel"))
x = stable_sort(x, order=("time", "channel"))
else:
x = np.sort(x, order=("time",))
x = stable_sort(x, order=("time",))
return x


@numba.jit(nopython=True, nogil=True, cache=True)
def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind="mergesort"):
    """Sort a record array by time, breaking ties by channel.

    Assumes you have no more than 10k channels, and records don't span more than 11 days.
    (5-10x) faster than strax.stable_sort(order=...), as strax.stable_sort looks at all fields.

    :param x: Structured array with at least a "time" field.
    :param channel: Channel number for each entry of x.
    :param max_channel_plus_one: One above the largest channel number present.
    :param sort_kind: Sorting algorithm; "mergesort" keeps the sort stable.
    :return: x reordered by (time, channel).
    """
    # I couldn't get fast argsort on multiple keys to work in numba
    # So, let's make a single key...
    sort_key = (x["time"] - x["time"].min()) * max_channel_plus_one + channel
    sort_i = stable_argsort(sort_key, kind=sort_kind)
    return x[sort_i]


Expand Down Expand Up @@ -426,7 +427,7 @@ def _touching_windows(
thing_start, thing_end, container_start, container_end, window=0, endtime_sort_kind="mergesort"
):
n = len(thing_start)
container_end_argsort = np.argsort(container_end, kind=endtime_sort_kind)
container_end_argsort = stable_argsort(container_end, kind=endtime_sort_kind)

# we search twice, first for the beginning of the interval, then for the end
left_i = right_i = 0
Expand Down
103 changes: 59 additions & 44 deletions strax/processing/hitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def concat_overlapping_hits(hits, extensions, pmt_channels, start, end):
return hits


@strax.utils.growing_result(strax.hit_dtype, chunk_size=int(1e4))
@strax.growing_result(strax.hit_dtype, chunk_size=int(1e4))
@numba.njit(nogil=True, cache=True)
def _concat_overlapping_hits(
hits,
Expand Down Expand Up @@ -499,23 +499,56 @@ def _conditional_entropy(hitlets, template, flat=False, square_data=False):
return res


@export
@numba.njit(cache=True)
def _compute_simple_edges(interval_indices, dt):
    """Return the (left, right) edge times in ns, without fractional refinement.

    The left edge is the start index of the first interval; the right edge is
    the largest end index over all intervals. Both are scaled by the sample
    length dt.
    """
    start_index = interval_indices[0, 0]
    outermost_end = interval_indices[1, np.argmax(interval_indices[1, :])]
    return start_index * dt, outermost_end * dt


@export
@numba.njit(cache=True)
def _compute_fractional_edges(interval_indices, data, area_fraction_amplitude, dt):
    """Return the (left, right) edge times in ns with sub-sample (fractional) precision.

    The outermost samples of the region are refined by linearly interpolating
    between each outer sample and its neighbor just outside the region
    (treated as amplitude 0 when the sample lies at the array border).
    """
    first = interval_indices[0, 0]
    # -1 since the stored end index corresponds to the outer edge
    last = interval_indices[1, np.argmax(interval_indices[1, :])] - 1

    # Amplitudes of the outermost samples of the region
    amp_first = data[first]
    amp_last = data[last]

    # Amplitudes of the adjacent samples outside the region (0 at the border)
    outside_first = data[first - 1] if first - 1 >= 0 else 0
    outside_last = data[last + 1] if last + 1 < len(data) else 0

    # Fractions of a sample to extend on each side; amp_first == outside_first
    # cannot occur by the definition of the highest density region
    frac_left = (amp_first - area_fraction_amplitude) / (amp_first - outside_first)
    frac_right = (amp_last - area_fraction_amplitude) / (amp_last - outside_last)

    return (first + 0.5 - frac_left) * dt, (last + 0.5 + frac_right) * dt


@export
def highest_density_region_width(
data, fractions_desired, dt=1, fractionl_edges=False, _buffer_size=100
data, fractions_desired, dt=1, fractional_edges=False, _buffer_size=100
):
"""Function which computes the left and right edge based on the outer most sample for the
highest density region of a signal.
Defines a 100% fraction as the sum over all positive samples in a waveform.
Args:
data: Data of a signal, e.g. hitlet or peak including zero length encoding
fractions_desired: Area fractions for which HDR should be computed
dt: Sample length in ns
fractional_edges: If true computes width as fractional time
_buffer_size: Maximal number of allowed intervals
:param data: Data of a signal, e.g. hitlet or peak including zero length encoding.
:param fractions_desired: Area fractions for which the highest density region should be
computed.
:param dt: Sample length in ns.
:param fractionl_edges: If true computes width as fractional time depending on the covered area
between the current and next sample.
:param _buffer_size: Maximal number of allowed intervals. If signal exceeds number e.g. due to
noise width computation is skipped.
Returns:
np.ndarray: Array of shape (len(fractions_desired), 2) containing left and right edges
"""
res = np.zeros((len(fractions_desired), 2), dtype=np.float32)
Expand All @@ -525,49 +558,31 @@ def highest_density_region_width(
res[:] = np.nan
return res

inter, amps = strax.highest_density_region(
# Use the pure-python implementation for HDR computation
intervals, amps = strax.highest_density_region(
data,
fractions_desired,
only_upper_part=True,
_buffer_size=_buffer_size,
)

for index_area_fraction, (interval_indicies, area_fraction_amplitude) in enumerate(
zip(inter, amps)
# Deal with each area fraction separately
for index_area_fraction, (interval_indices, area_fraction_amplitude) in enumerate(
zip(intervals, amps)
):
if np.all(interval_indicies[:] == -1):
if np.all(interval_indices[:] == -1):
res[index_area_fraction, :] = np.nan
continue

if not fractionl_edges:
res[index_area_fraction, 0] = interval_indicies[0, 0] * dt
res[index_area_fraction, 1] = (
interval_indicies[1, np.argmax(interval_indicies[1, :])] * dt
)
if not fractional_edges:
left, right = _compute_simple_edges(interval_indices, dt)
res[index_area_fraction, 0] = left
res[index_area_fraction, 1] = right
else:
left = interval_indicies[0, 0]
# -1 since value corresponds to outer edge:
right = interval_indicies[1, np.argmax(interval_indicies[1, :])] - 1

# Get amplitudes of outer most samples
# and amplitudes of adjacent samples (if any)
left_amp = data[left]
right_amp = data[right]

next_left_amp = 0
if (left - 1) >= 0:
next_left_amp = data[left - 1]
next_right_amp = 0
if (right + 1) < len(data):
next_right_amp = data[right + 1]

# Compute fractions and new left and right edges, the case
# left_amp == next_left_amp cannot occure by the definition
# of the highest density region.
fl = (left_amp - area_fraction_amplitude) / (left_amp - next_left_amp)
fr = (right_amp - area_fraction_amplitude) / (right_amp - next_right_amp)

res[index_area_fraction, 0] = (left + 0.5 - fl) * dt
res[index_area_fraction, 1] = (right + 0.5 + fr) * dt
left, right = _compute_fractional_edges(
interval_indices, data, area_fraction_amplitude, dt
)
res[index_area_fraction, 0] = left
res[index_area_fraction, 1] = right

return res
5 changes: 2 additions & 3 deletions strax/processing/peak_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import numba

import strax
from strax import utils
from strax.dtypes import peak_dtype, DIGITAL_SUM_WAVEFORM_CHANNEL
from strax.dtypes import DIGITAL_SUM_WAVEFORM_CHANNEL

export, __all__ = strax.exporter()


@export
@utils.growing_result(dtype=peak_dtype(), chunk_size=int(1e4))
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True, cache=True)
def find_peaks(
hits,
Expand Down
2 changes: 1 addition & 1 deletion strax/processing/peak_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def compute_widths(peaks, select_peaks_indices=None):
desired_fr = np.concatenate([0.5 - desired_widths / 2, 0.5 + desired_widths / 2])

# We lose the 50% fraction with this operation, let's add it back
desired_fr = np.sort(np.unique(np.append(desired_fr, [0.5])))
desired_fr = strax.stable_sort(np.unique(np.append(desired_fr, [0.5])))

fr_times = index_of_fraction(peaks[select_peaks_indices], desired_fr)
fr_times *= peaks["dt"][select_peaks_indices].reshape(-1, 1)
Expand Down
Loading

0 comments on commit 8489aa2

Please sign in to comment.