Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 5, 2024
1 parent 12f3ad3 commit ea0241c
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
causal=causal
causal=causal,
)


Expand Down Expand Up @@ -151,15 +151,15 @@ def get_traces(self, start_frame, end_frame, channel_indices):

if self.filter_mode == "sos":
if causal:
filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk)))
filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk)))
else:
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
if causal:
filtered_traces = np.flip(scipy.signal.lfilter(b, a, np.flip(traces_chunk), axis=0))
filtered_traces = np.flip(scipy.signal.lfilter(b, a, np.flip(traces_chunk), axis=0))
else:
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -287,10 +287,11 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)


class CausalFilterRecording(FilterRecording):
"""
Implements backwards causal filter to correct for hardware induced phase shift
Parameters
----------
recording : Recording
Expand All @@ -312,12 +313,23 @@ class CausalFilterRecording(FilterRecording):
filter_recording : CausalFilterRecording
The CausalFilterRecording recording extractor object
"""

name = "causal_filter"

def __init__(self, recording, band=[300.0], margin_ms=5.0, btype = "highpass",filter_order = 1, dtype=None,**filter_kwargs):
def __init__(
self, recording, band=[300.0], margin_ms=5.0, btype="highpass", filter_order=1, dtype=None, **filter_kwargs
):
FilterRecording.__init__(
self, recording, band=band, margin_ms=margin_ms, dtype=dtype, btype = btype, filter_order = filter_order,causal = True,
**filter_kwargs)
self,
recording,
band=band,
margin_ms=margin_ms,
dtype=dtype,
btype=btype,
filter_order=filter_order,
causal=True,
**filter_kwargs,
)
dtype = fix_dtype(recording, dtype)
self._kwargs = dict(recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str)
self._kwargs.update(filter_kwargs)
Expand Down

0 comments on commit ea0241c

Please sign in to comment.