Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filtering #329

Merged
merged 48 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
c816da4
Adding filtering.py
gviejo Aug 2, 2024
3ad5ad9
Merge remote-tracking branch 'origin/change_slice_api' into filtering
BalzaniEdoardo Aug 7, 2024
ebe5e76
improved filtering
BalzaniEdoardo Aug 7, 2024
bc20031
added regressoin test
BalzaniEdoardo Aug 7, 2024
77dab5f
added unit testing
BalzaniEdoardo Aug 7, 2024
86ca588
test dytpe
BalzaniEdoardo Aug 7, 2024
098de14
removed simple test
BalzaniEdoardo Aug 7, 2024
287523e
improved tests
BalzaniEdoardo Aug 7, 2024
fca58a6
Merge branch 'dev' into filtering
gviejo Aug 13, 2024
d1851a9
do not run actions on draft PR
BalzaniEdoardo Aug 15, 2024
b6f8724
linted
BalzaniEdoardo Aug 15, 2024
3467a85
added tests
BalzaniEdoardo Aug 15, 2024
9a1ed8a
switch to sos filter
BalzaniEdoardo Aug 15, 2024
b1916bf
removed unused import
BalzaniEdoardo Aug 15, 2024
f669f95
linted
BalzaniEdoardo Aug 15, 2024
6c5d88a
few changes
gviejo Aug 15, 2024
4b34f23
Merge branch 'filtering' of github.com:pynapple-org/pynapple into fil…
gviejo Aug 15, 2024
009477a
Update pynapple/process/filtering.py
BalzaniEdoardo Aug 16, 2024
0d21879
Update pynapple/process/filtering.py
BalzaniEdoardo Aug 16, 2024
292ee29
changed to sampling_frequency
BalzaniEdoardo Aug 16, 2024
58e3ca1
linting
gviejo Aug 16, 2024
58a9554
Adding tests and notebook for filtering module
gviejo Sep 1, 2024
4b12da6
Fixing notebooks
gviejo Sep 2, 2024
eb93976
Fixing sinc
gviejo Sep 2, 2024
4cb76ce
added exception
BalzaniEdoardo Sep 6, 2024
e9468b0
added test exception
BalzaniEdoardo Sep 6, 2024
9774346
fixes lowpass and highpass by normalizing
BalzaniEdoardo Sep 6, 2024
55105f2
fix tests
BalzaniEdoardo Sep 6, 2024
2c9a194
fix highpass
BalzaniEdoardo Sep 6, 2024
c3a681a
fix highpass
BalzaniEdoardo Sep 6, 2024
affd113
fix bandstop
BalzaniEdoardo Sep 6, 2024
031cae4
fix bandpass
BalzaniEdoardo Sep 7, 2024
bb9ed18
added a parameter for spectral density
BalzaniEdoardo Sep 7, 2024
d96087e
edited tutorial
BalzaniEdoardo Sep 7, 2024
8300626
Update
gviejo Sep 9, 2024
c19d785
linted
BalzaniEdoardo Sep 9, 2024
ef948f0
removed repeated code. Simplified validate
BalzaniEdoardo Sep 9, 2024
d76aa3d
Update
gviejo Sep 9, 2024
900c4b5
nan fix
BalzaniEdoardo Sep 9, 2024
bbf6298
fix all nans
BalzaniEdoardo Sep 9, 2024
6af28c0
added description of bands. Fixed typing
BalzaniEdoardo Sep 11, 2024
6309c92
adding jax compatibility
gviejo Sep 11, 2024
6bb18bb
Merge branch 'filtering' of github.com:pynapple-org/pynapple into fil…
gviejo Sep 11, 2024
5f3eb77
improved flow
BalzaniEdoardo Sep 11, 2024
1309dab
Merge branch 'filtering' of github.com:pynapple-org/pynapple into fil…
gviejo Sep 12, 2024
d4d4941
Fix tests for pynajax
gviejo Sep 12, 2024
764d6b6
Updating
gviejo Sep 12, 2024
a86090c
Update
gviejo Sep 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ on:
branches: [ main ]
pull_request:
branches: [ main, dev ]
types:
gviejo marked this conversation as resolved.
Show resolved Hide resolved
- opened
- reopened
- synchronize
- ready_for_review

jobs:
lint:
Expand Down
1 change: 1 addition & 0 deletions pynapple/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
compute_eventcorrelogram,
)
from .decoding import decode_1d, decode_2d
from .filtering import compute_filtered_signal
from .perievent import (
compute_event_trigger_average,
compute_perievent,
Expand Down
91 changes: 91 additions & 0 deletions pynapple/process/filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Filtering module."""

from numbers import Number

import numpy as np
from scipy.signal import butter, sosfiltfilt

from .. import core as nap


def compute_filtered_signal(
data, freq_band, filter_type="bandpass", order=4, sampling_frequency=None
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Apply a Butterworth filter to the provided signal.

This function performs bandpass filtering on time series data
using a Butterworth filter. The filter can be configured to be of
type "bandpass", "bandstop", "highpass", or "lowpass".

Parameters
----------
data : Tsd, TsdFrame, or TsdTensor
The signal to be filtered.
freq_band : tuple of (float, float) or float
Cutoff frequency(ies) in Hz.
- For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies.
- For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency.
filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional
The type of frequency filter to apply. Default is "bandpass".
order : int, optional
The order of the Butterworth filter. Higher values result in sharper frequency cutoffs.
Default is 4.
sampling_frequency : float, optional
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved
The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data.

Returns
-------
filtered_data : Tsd, TsdFrame, or TsdTensor
The filtered signal, with the same data type as the input.

Raises
------
ValueError
If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}.
If `freq_band` is not a float for "lowpass" and "highpass" filters.
If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters.

Notes
-----
The cutoff frequency is defined as the frequency at which the amplitude of the signal
is reduced by -3 dB (decibels).
"""
if sampling_frequency is None:
sampling_frequency = data.rate

if filter_type not in ["lowpass", "highpass", "bandpass", "bandstop"]:
raise ValueError(
f"Unrecognized filter type {filter_type}. "
"filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'."
)
elif filter_type in ["lowpass", "highpass"] and not isinstance(freq_band, Number):
raise ValueError(
"Low/high-pass filter specification requires a single frequency. "
f"{freq_band} provided instead!"
)
elif filter_type in ["bandpass", "bandstop"]:
try:
if len(freq_band) != 2 or not all(
isinstance(fq, Number) for fq in freq_band
):
raise ValueError

Check warning on line 72 in pynapple/process/filtering.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/filtering.py#L72

Added line #L72 was not covered by tests
except Exception:
raise ValueError(
"Band-pass/stop filter specification requires two frequencies. "
f"{freq_band} provided instead!"
)

sos = butter(
order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos"
)

out = np.zeros_like(data.d)
for ep in data.time_support:
slc = data.get_slice(start=ep.start[0], end=ep.end[0])
out[slc] = sosfiltfilt(sos, data.d[slc], axis=0)

kwargs = dict(t=data.t, d=out, time_support=data.time_support)
if isinstance(data, nap.TsdFrame):
kwargs["columns"] = data.columns
return data.__class__(**kwargs)
172 changes: 172 additions & 0 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import pytest
import pynapple as nap
import numpy as np
from scipy import signal
from contextlib import nullcontext as does_not_raise


@pytest.fixture
def sample_data():
# Create a sample Tsd data object
t = np.linspace(0, 1, 500)
d = np.sin(2 * np.pi * 10 * t) + np.random.normal(0, 0.5, t.shape)
time_support = nap.IntervalSet(start=[0], end=[1])
return nap.Tsd(t=t, d=d, time_support=time_support)


@pytest.mark.parametrize("freq", [10, 100])
@pytest.mark.parametrize("order", [2, 4, 6])
@pytest.mark.parametrize("btype", ["lowpass", "highpass"])
@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)])
@pytest.mark.parametrize(
"ep",
[
nap.IntervalSet(start=[0], end=[1]),
nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]),
nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1])
]
)
def test_filtering_single_freq_match_sci(freq, order, btype, shape: tuple, ep):

t = np.linspace(0, 1, shape[0])
y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape))

if len(shape) == 1:
tsd = nap.Tsd(t, y, time_support=ep)
elif len(shape) == 2:
tsd = nap.TsdFrame(t, y, time_support=ep)
else:
tsd = nap.TsdTensor(t, y, time_support=ep)
out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order)
sos = signal.butter(order, freq, fs=tsd.rate, btype=btype, output="sos")
out_sci = []
for iset in ep:
out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0))
out_sci = np.concatenate(out_sci, axis=0)
np.testing.assert_array_equal(out.d, out_sci)


@pytest.mark.parametrize("freq", [[10, 30], [100,150]])
@pytest.mark.parametrize("order", [2, 4, 6])
@pytest.mark.parametrize("btype", ["bandpass", "bandstop"])
@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)])
@pytest.mark.parametrize(
"ep",
[
nap.IntervalSet(start=[0], end=[1]),
nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]),
nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1])
]
)
def test_filtering_freq_band_match_sci(freq, order, btype, shape: tuple, ep):

t = np.linspace(0, 1, shape[0])
y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape))

if len(shape) == 1:
tsd = nap.Tsd(t, y, time_support=ep)
elif len(shape) == 2:
tsd = nap.TsdFrame(t, y, time_support=ep)
else:
tsd = nap.TsdTensor(t, y, time_support=ep)
out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order)
sos = signal.butter(order, freq, fs=tsd.rate, btype=btype, output="sos")
out_sci = []
for iset in ep:
out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0))
out_sci = np.concatenate(out_sci, axis=0)
np.testing.assert_array_equal(out.d, out_sci)


@pytest.mark.parametrize("freq", [10, 100])
@pytest.mark.parametrize("order", [2, 4, 6])
@pytest.mark.parametrize("btype", ["lowpass", "highpass"])
@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)])
@pytest.mark.parametrize(
"ep",
[
nap.IntervalSet(start=[0], end=[1]),
nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]),
nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1])
]
)
def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep):
t = np.linspace(0, 1, shape[0])
y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape))

if len(shape) == 1:
tsd = nap.Tsd(t, y, time_support=ep)
elif len(shape) == 2:
tsd = nap.TsdFrame(t, y, time_support=ep, columns=np.arange(10, 10 + y.shape[1]))
else:
tsd = nap.TsdTensor(t, y, time_support=ep)
out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order)
assert isinstance(out, type(tsd))
assert np.all(out.t == tsd.t)
assert np.all(out.time_support == tsd.time_support)
if isinstance(tsd, nap.TsdFrame):
assert np.all(tsd.columns == out.columns)


@pytest.mark.parametrize("freq", [[10, 30], [100, 150]])
@pytest.mark.parametrize("order", [2, 4, 6])
@pytest.mark.parametrize("btype", ["bandpass", "bandstop"])
@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)])
@pytest.mark.parametrize(
"ep",
[
nap.IntervalSet(start=[0], end=[1]),
nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]),
nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1])
]
)
def test_filtering_freq_band_dtype(freq, order, btype, shape: tuple, ep):
t = np.linspace(0, 1, shape[0])
y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape))

if len(shape) == 1:
tsd = nap.Tsd(t, y, time_support=ep)
elif len(shape) == 2:
tsd = nap.TsdFrame(t, y, time_support=ep, columns=np.arange(10, 10 + y.shape[1]))
else:
tsd = nap.TsdTensor(t, y, time_support=ep)
out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order)
assert isinstance(out, type(tsd))
assert np.all(out.t == tsd.t)
assert np.all(out.time_support == tsd.time_support)
if isinstance(tsd, nap.TsdFrame):
assert np.all(tsd.columns == out.columns)


@pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [
((5, 15), "bandpass", 4, does_not_raise()),
((5, 15), "bandstop", 4, does_not_raise()),
(10, "highpass", 4, does_not_raise()),
(10, "lowpass", 4, does_not_raise()),
((5, 15), "invalid_filter", 4, pytest.raises(ValueError, match="Unrecognized filter type")),
(10, "bandpass", 4, pytest.raises(ValueError, match="Band-pass/stop filter specification requires two frequencies")),
((5, 15), "highpass", 4, pytest.raises(ValueError, match="Low/high-pass filter specification requires a single frequency")),
(None, "bandpass", 4, pytest.raises(ValueError, match="Band-pass/stop filter specification requires two frequencies")),
((None, 1), "highpass", 4, pytest.raises(ValueError, match="Low/high-pass filter specification requires a single frequency"))
])
def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, expected_exception):
with expected_exception:
filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type, order)
if not expected_exception:
assert isinstance(filtered_data, type(sample_data))
assert filtered_data.d.shape == sample_data.d.shape


# Test with edge-case frequencies close to Nyquist frequency
@pytest.mark.parametrize("nyquist_fraction", [0.99, 0.999])
@pytest.mark.parametrize("order", [2, 4])
def test_filtering_nyquist_edge_case(nyquist_fraction, order, sample_data):
nyquist_freq = 0.5 * sample_data.rate
freq = nyquist_freq * nyquist_fraction

out = nap.filtering.compute_filtered_signal(
sample_data, freq_band=freq, filter_type="lowpass", order=order
)
assert isinstance(out, type(sample_data))
np.testing.assert_allclose(out.t, sample_data.t)
np.testing.assert_allclose(out.time_support, sample_data.time_support)
Loading