diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 0c779db4..8c991883 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -6,6 +6,11 @@ on:
     branches: [ main ]
   pull_request:
     branches: [ main, dev ]
+    types:
+      - opened
+      - reopened
+      - synchronize
+      - ready_for_review
 
 jobs:
   lint:
diff --git a/.gitignore b/.gitignore
index 04a63643..d88327c1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,12 +1,15 @@
+*.npz
 *.nwb
 *.pickle
 *.py.md5
-*.npz
+#*.npz
 /docs/generated/gallery/*.md
 /docs/generated/gallery/*.ipynb
 /docs/generated/gallery/*.py
 /docs/generated/gallery/*.zip
+/tests/npzfilestest
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 965b2dda..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,11 +0,0 @@
-include AUTHORS.rst
-include CONTRIBUTING.rst
-include HISTORY.rst
-include LICENSE
-include README.rst
-
-recursive-include tests *
-recursive-exclude * __pycache__
-recursive-exclude * *.py[co]
-
-recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
diff --git a/README.md b/README.md
index 9f7a984b..83638ec0 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,15 @@ pynapple is a light-weight python library for neurophysiological data analysis.
 New release :fire:
 ------------------
 
+### pynapple >= 0.7
+
+Pynapple now implements signal processing. For example, to filter a 1250 Hz sampled time series between 10 Hz and 20 Hz:
+
+```python
+nap.apply_bandpass_filter(signal, (10, 20), fs=1250)
+```
+New functions include power spectral density and Morlet wavelet decomposition. See the [documentation](https://pynapple-org.github.io/pynapple/reference/process/) for more details.
+
 ### pynapple >= 0.6
 
 Starting with 0.6, [`IntervalSet`](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) objects are behaving as immutable numpy ndarray. Before 0.6, you could select an interval within an `IntervalSet` object with:
@@ -45,8 +54,6 @@ With pynapple>=0.6, the slicing is similar to numpy and it returns an `IntervalS
 new_intervalset = intervalset[0]
 ```
 
-See the [documentation](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) for more details.
-
 ### pynapple >= 0.4
 
 Starting with 0.4, pynapple rely on the [numpy array container](https://numpy.org/doc/stable/user/basics.dispatch.html) approach instead of Pandas for the time series. Pynapple builtin functions will remain the same except for functions inherited from Pandas.
diff --git a/docs/AUTHORS.md b/docs/AUTHORS.md
index 6d9ff76d..d6182d39 100644
--- a/docs/AUTHORS.md
+++ b/docs/AUTHORS.md
@@ -5,14 +5,15 @@ Development Lead
 ----------------
 
 - Guillaume Viejo
+- Edoardo Balzani
 
 Contributors
 ------------
 
-- Edoardo Balzani
 - Adrien Peyrache
 - Dan Levenstein
 - Sofia Skromne Carrasco
 - Davide Spalla
-- Luigi Petrucco
\ No newline at end of file
+- Luigi Petrucco
+- ... [and many more!](https://github.com/pynapple-org/pynapple/graphs/contributors)
\ No newline at end of file
diff --git a/docs/HISTORY.md b/docs/HISTORY.md
index 07ba8ce0..a64d469e 100644
--- a/docs/HISTORY.md
+++ b/docs/HISTORY.md
@@ -6,7 +6,29 @@
 Another postdoc in the lab, Francesco Battaglia, then made major contributions to the code base of *TSToolbox*.
 Around 2016-2017, Luke Sjulson started *TSToolbox2*, still in Matlab and which includes some important changes.
 In 2018, Francesco started neuroseries, a Python package built on Pandas. It was quickly adopted in Adrien's lab, especially by Guillaume Viejo, a postdoc in the lab.
 Gradually, the majority of the lab was using it and new functions were constantly added.
-In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*. The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects.
+In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*.
+The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects.
+
+Since 2023, the development of pynapple is led by [Guillaume Viejo](https://www.simonsfoundation.org/people/guillaume-viejo/)
+and [Edoardo Balzani](https://www.simonsfoundation.org/people/edoardo-balzani/) at the Center for Computational Neuroscience
+of the Flatiron Institute.
+
+
+
+0.7.0 (2024-09-16)
+------------------
+
+- Morlet wavelet spectrogram with a utility for plotting the wavelets.
+- (Mean) power spectral density. Returns a Pandas DataFrame.
+- Convolve function works for any dimension of time series and any dimensions of kernel.
+- `dtype` argument in the count function.
+- `get_slice`: public method with a simplified API; arguments `start`, `end`, `time_units`. Returns a slice that matches the behavior of `Base.get`.
+- `_get_slice`: private method; adds the argument `mode`, which can be "after_t", "before_t", "closest_t", or "restrict".
+- `split` method for `IntervalSet`. Argument is `interval_size` in time units.
+- Changed os import to pathlib.
+- Fixed pickling issue. TsGroup can now be saved as pickle.
+- TsGroup can be created from an iterable of Ts/Tsd objects.
+- IntervalSet can be created from (start, end) pairs.
 
 0.6.6 (2024-05-28)
 ------------------
diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py
new file mode 100644
index 00000000..87516578
--- /dev/null
+++ b/docs/api_guide/tutorial_pynapple_filtering.py
@@ -0,0 +1,340 @@
+# -*- coding: utf-8 -*-
+"""
+Filtering
+=========
+
+The filtering module holds the functions for frequency manipulation:
+
+- `nap.apply_bandstop_filter`
+- `nap.apply_lowpass_filter`
+- `nap.apply_highpass_filter`
+- `nap.apply_bandpass_filter`
+
+The functions have similar calling signatures. For example, to filter a 1000 Hz signal between
+10 and 20 Hz using a Butterworth filter:
+
+```{python}
+>>> new_tsd = nap.apply_bandpass_filter(tsd, (10, 20), fs=1000, mode='butter')
+```
+
+Currently, the filtering module provides two methods for frequency manipulation: `butter`
+for a recursive Butterworth filter and `sinc` for a Windowed-sinc convolution. This notebook provides
+a comparison of the two methods.
+"""
+
+# %%
+# !!! warning
+#     This tutorial uses matplotlib for displaying the figure
+#
+#     You can install all with `pip install matplotlib requests tqdm seaborn`
+#
+# mkdocs_gallery_thumbnail_number = 1
+#
+# Now, import the necessary libraries:
+
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn
+
+import pynapple as nap
+
+custom_params = {"axes.spines.right": False, "axes.spines.top": False}
+seaborn.set_theme(context='notebook', style="ticks", rc=custom_params)
+
+# %%
+# ***
+# Introduction
+# ------------
+#
+# We start by generating a signal with multiple frequencies (2, 10 and 50 Hz).
+fs = 1000  # sampling frequency
+t = np.linspace(0, 2, fs * 2)
+f2 = np.cos(t*2*np.pi*2)
+f10 = np.cos(t*2*np.pi*10)
+f50 = np.cos(t*2*np.pi*50)
+
+sig = nap.Tsd(t=t, d=f2+f10+f50 + np.random.normal(0, 0.5, len(t)))
+
+# %%
+# Let's plot it.
+fig = plt.figure(figsize = (15, 5))
+plt.plot(sig)
+plt.xlabel("Time (s)")
+
+
+# %%
+# We can compute the Fourier transform of `sig` to verify that all the frequencies are there.
+psd = nap.compute_power_spectral_density(sig, fs, norm=True)
+
+fig = plt.figure(figsize = (15, 5))
+plt.plot(np.abs(psd))
+plt.xlabel("Frequency (Hz)")
+plt.ylabel("Amplitude")
+plt.xlim(0, 100)
+
+
+# %%
+# Let's say we would like to see only the 10 Hz component.
+# We can use the function `apply_bandpass_filter` with mode `butter` for Butterworth.
+
+sig_butter = nap.apply_bandpass_filter(sig, (8, 12), fs, mode='butter')
+
+# %%
+# Let's compare it to the `sinc` mode for Windowed-sinc.
+sig_sinc = nap.apply_bandpass_filter(sig, (8, 12), fs, mode='sinc', transition_bandwidth=0.003)
+
+# %%
+# Let's plot it.
+fig = plt.figure(figsize = (15, 5))
+plt.subplot(211)
+plt.plot(t, f10, '-', color = 'gray', label = "10 Hz component")
+plt.xlim(0, 1)
+plt.legend()
+plt.subplot(212)
+# plt.plot(sig, alpha=0.5)
+plt.plot(sig_butter, label = "Butterworth")
+plt.plot(sig_sinc, '--', label = "Windowed-sinc")
+plt.legend()
+plt.xlabel("Time (s)")
+plt.xlim(0, 1)
+
+
+# %%
+# This gives similar results except at the edges.
+#
+# Another use of filtering is to remove some frequencies. Here we can try to remove
+# the 50 Hz component in the signal.
+
+sig_butter = nap.apply_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter')
+sig_sinc = nap.apply_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc', transition_bandwidth=0.004)
+
+
+# %%
+# Let's plot it.
+fig = plt.figure(figsize = (15, 5))
+plt.subplot(211)
+plt.plot(t, sig, '-', color = 'gray', label = "Original signal")
+plt.xlim(0, 1)
+plt.legend()
+plt.subplot(212)
+plt.plot(sig_butter, label = "Butterworth")
+plt.plot(sig_sinc, '--', label = "Windowed-sinc")
+plt.legend()
+plt.xlabel("Time (s)")
+plt.xlim(0, 1)
+
+
+# %%
+# Let's see what frequencies remain.
+
+psd_butter = nap.compute_power_spectral_density(sig_butter, fs, norm=True)
+psd_sinc = nap.compute_power_spectral_density(sig_sinc, fs, norm=True)
+
+fig = plt.figure(figsize = (10, 5))
+plt.plot(np.abs(psd_butter), label = "Butterworth filter")
+plt.plot(np.abs(psd_sinc), label = "Windowed-sinc convolution")
+plt.legend()
+plt.xlabel("Frequency (Hz)")
+plt.ylabel("Amplitude")
+plt.xlim(0, 70)
+
+
+# %%
+# The remainder of this notebook compares the two modes.
+#
+# ***
+# Frequency Responses
+# -------------------
+#
+# We can inspect the frequency response of a filter by plotting its power spectral density (PSD).
+# To do this, we can use the `get_filter_frequency_response` function, which returns a pandas Series with the frequencies
+# as the index and the PSD as values.
+#
+# Let's extract the frequency response of a Butterworth filter and a sinc low-pass filter.
+
+# compute the frequency response of the filters
+psd_butter = nap.get_filter_frequency_response(
+    200, fs, "lowpass", "butter", order=8
+)
+psd_sinc = nap.get_filter_frequency_response(
+    200, fs, "lowpass", "sinc", transition_bandwidth=0.1
+)
+
+# %%
+# ...and plot it.
+
+# compute the transition bandwidth
+tb_butter = psd_butter[psd_butter > 0.99].index.max(), psd_butter[psd_butter < 0.01].index.min()
+tb_sinc = psd_sinc[psd_sinc > 0.99].index.max(), psd_sinc[psd_sinc < 0.01].index.min()
+
+fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(15, 5))
+fig.suptitle("Frequency response", fontsize="x-large")
+axs[0].set_title("Butterworth Filter")
+axs[0].plot(psd_butter)
+axs[0].axvspan(0, tb_butter[0], alpha=0.4, color="green", label="Pass Band")
+axs[0].axvspan(*tb_butter, alpha=0.4, color="orange", label="Transition Band")
+axs[0].axvspan(tb_butter[1], 500, alpha=0.4, color="red", label="Stop Band")
+axs[0].legend().get_frame().set_alpha(1.)
+axs[0].set_xlim(0, 500)
+axs[0].set_xlabel("Frequency (Hz)")
+axs[0].set_ylabel("Amplitude")
+
+axs[1].set_title("Sinc Filter")
+axs[1].plot(psd_sinc)
+axs[1].axvspan(0, tb_sinc[0], alpha=0.4, color="green", label="Pass Band")
+axs[1].axvspan(*tb_sinc, alpha=0.4, color="orange", label="Transition Band")
+axs[1].axvspan(tb_sinc[1], 500, alpha=0.4, color="red", label="Stop Band")
+axs[1].legend().get_frame().set_alpha(1.)
+axs[1].set_xlabel("Frequency (Hz)")
+
+print(f"Transition band butterworth filter: ({int(tb_butter[0])}Hz, {int(tb_butter[1])}Hz)")
+print(f"Transition band sinc filter: ({int(tb_sinc[0])}Hz, {int(tb_sinc[1])}Hz)")
+
+# %%
+# The frequency band with response close to one will be preserved by the filtering (pass band),
+# the band with response close to zero will be discarded (stop band), and the band in between will be partially attenuated
+# (transition band).
+#
+# ??? note "Transition Bandwidth (Click to expand/collapse)"
+#     Here, we define the transition band as the range where the amplitude attenuation is between 99% and 1%.
+#     The `transition_bandwidth` parameter of the sinc filter is approximately the width of the transition
+#     band normalized by the sampling frequency. In the example above, if you divide the transition band width
+#     of 122Hz by the sampling frequency of 1000Hz, you get 0.122, which is close to the 0.1 value set.
+#
+# You can modulate the width of the transition band by setting the `order` parameter of the Butterworth filter
+# or the `transition_bandwidth` parameter of the sinc filter.
+# First, let's get the frequency response for a Butterworth low-pass filter with different orders:
+
+butter_freq = {
+    order: nap.get_filter_frequency_response(250, fs, "lowpass", "butter", order=order)
+    for order in [2, 4, 6]}
+
+# %%
+# ... and then the frequency response for the Windowed-sinc equivalent with different transition bandwidths.
+sinc_freq = {
+    tb: nap.get_filter_frequency_response(250, fs, "lowpass", "sinc", transition_bandwidth=tb)
+    for tb in [0.002, 0.02, 0.2]}
+
+# %%
+# Let's plot the frequency response of both.
+
+fig = plt.figure(figsize = (20, 10))
+gs = plt.GridSpec(2, 2)
+for order in butter_freq.keys():
+    plt.subplot(gs[0, 0])
+    plt.plot(butter_freq[order], label = f"order={order}")
+    plt.ylabel('Amplitude')
+    plt.legend()
+    plt.title("Butterworth recursive")
+    plt.subplot(gs[1, 0])
+    plt.plot(20*np.log10(butter_freq[order]), label = f"order={order}")
+    plt.xlabel('Frequency [Hz]')
+    plt.ylabel('Amplitude [dB]')
+    plt.ylim(-200,20)
+    plt.legend()
+
+for tb in sinc_freq.keys():
+    plt.subplot(gs[0, 1])
+    plt.plot(sinc_freq[tb], label= f"width={tb}")
+    plt.ylabel('Amplitude')
+    plt.legend()
+    plt.title("Windowed-sinc conv.")
+    plt.subplot(gs[1, 1])
+    plt.plot(20*np.log10(sinc_freq[tb]), label= f"width={tb}")
+    plt.xlabel('Frequency [Hz]')
+    plt.ylabel('Amplitude [dB]')
+    plt.ylim(-200,20)
+    plt.legend()
+
+# %%
+# ⚠️ **Warning:** In some cases, a transition bandwidth that is too high generates a kernel that is too short.
+# The amplitude of the original signal will then be lower than expected.
+# In this case, the solution is to decrease the transition bandwidth when using the windowed-sinc mode.
+# Note that this increases the length of the kernel significantly.
+# Let's see it with the band-pass filter.
+
+
+sinc_freq = {
+    tb: nap.get_filter_frequency_response((100, 200), fs, "bandpass", "sinc", transition_bandwidth=tb)
+    for tb in [0.004, 0.2]}
+
+
+fig = plt.figure(figsize = (20, 10))
+for tb in sinc_freq.keys():
+    plt.plot(sinc_freq[tb], label= f"width={tb}")
+    plt.ylabel('Amplitude')
+    plt.legend()
+    plt.title("Windowed-sinc conv.")
+    plt.legend()
+
+
+
+# %%
+# ***
+# Performance
+# ------------
+# Let's compare the performance of each mode when varying the number of time points and the number of dimensions.
+from time import perf_counter
+
+def get_mean_perf(tsd, mode, n=10):
+    tmp = np.zeros(n)
+    for i in range(n):
+        t1 = perf_counter()
+        _ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
+        t2 = perf_counter()
+        tmp[i] = t2 - t1
+    return [np.mean(tmp), np.std(tmp)]
+
+def benchmark_time_points(mode):
+    times = []
+    for T in np.arange(1000, 100000, 20000):
+        time_array = np.arange(T)/1000
+        data_array = np.random.randn(len(time_array))
+        startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
+        ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
+        tsd = nap.Tsd(t=time_array, d=data_array, time_support=ep)
+        times.append([T]+get_mean_perf(tsd, mode))
+    return np.array(times)
+
+def benchmark_dimensions(mode):
+    times = []
+    for n in np.arange(1, 100, 10):
+        time_array = np.arange(10000)/1000
+        data_array = np.random.randn(len(time_array), n)
+        startend = np.linspace(0, time_array[-1], 10000//100).reshape(10000//200, 2)
+        ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
+        tsd = nap.TsdFrame(t=time_array, d=data_array, time_support=ep)
+        times.append([n]+get_mean_perf(tsd, mode))
+    return np.array(times)
+
+
+times_sinc = benchmark_time_points(mode="sinc")
+times_butter = benchmark_time_points(mode="butter")
+
+dims_sinc = benchmark_dimensions(mode="sinc")
+dims_butter = benchmark_dimensions(mode="butter")
+
+
+plt.figure(figsize = (16, 5))
+plt.subplot(121)
+for arr, label in zip(
+    [times_sinc, times_butter],
+    ["Windowed-sinc", "Butter"],
+    ):
+    plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
+    plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)
+plt.legend()
+plt.xlabel("Number of time points")
+plt.ylabel("Time (s)")
+plt.title("Low pass filtering benchmark")
+plt.subplot(122)
+for arr, label in zip(
+    [dims_sinc, dims_butter],
+    ["Windowed-sinc", "Butter"],
+    ):
+    plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
+    plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)
+plt.legend()
+plt.xlabel("Number of dimensions")
+plt.ylabel("Time (s)")
+plt.title("Low pass filtering benchmark")
diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py
new file mode 100644
index 00000000..39db2432
--- /dev/null
+++ b/docs/api_guide/tutorial_pynapple_spectrum.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+"""
+Power spectral density
+======================
+
+See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package.
+
+"""
+
+# %%
+# !!! warning
+#     This tutorial uses matplotlib for displaying the figure
+#
+#     You can install all with `pip install matplotlib requests tqdm seaborn`
+#
+# Now, import the necessary libraries:
+#
+# mkdocs_gallery_thumbnail_number = 3
+
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn
+
+seaborn.set_theme()
+
+import pynapple as nap
+
+# %%
+# ***
+# Generating a signal
+# ------------------
+# Let's generate a dummy signal with 2 Hz and 10 Hz sinusoids and white noise.
+#
+
+F = [2, 10]
+
+Fs = 2000
+t = np.arange(0, 200, 1/Fs)
+sig = nap.Tsd(
+    t=t,
+    d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 3, len(t)),
+    time_support = nap.IntervalSet(0, 200)
+    )
+
+# %%
+# Let's plot it.
+plt.figure()
+plt.plot(sig.get(0, 0.4))
+plt.title("Signal")
+plt.xlabel("Time (s)")
+
+
+
+# %%
+# Computing power spectral density (PSD)
+# --------------------------------------
+#
+# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density`. With `norm=True`, the output of the FFT is divided by the length of the signal.
+
+psd = nap.compute_power_spectral_density(sig, norm=True)
+
+# %%
+# Pynapple returns a pandas DataFrame.
+
+print(psd)
+
+# %%
+# It is then easy to plot it.
+
+plt.figure()
+plt.plot(np.abs(psd))
+plt.xlabel("Frequency (Hz)")
+plt.ylabel("Amplitude")
+
+
+# %%
+# Note that the output of the FFT is truncated to positive frequencies. To get positive and negative frequencies, you can set `full_range=True`.
+# By default, the function returns the frequencies up to the Nyquist frequency.
+# Let's zoom in on the first 20 Hz.
+
+plt.figure()
+plt.plot(np.abs(psd))
+plt.xlabel("Frequency (Hz)")
+plt.ylabel("Amplitude")
+plt.xlim(0, 20)
+
+
+# %%
+# We find the two frequencies 2 and 10 Hz.
+#
+# By default, pynapple assumes a constant sampling rate and a single epoch. For example, computing the FFT over more than 1 epoch will raise an error.
+double_ep = nap.IntervalSet([0, 50], [20, 100])
+
+try:
+    nap.compute_power_spectral_density(sig, ep=double_ep)
+except ValueError as e:
+    print(e)
+
+
+# %%
+# Computing mean PSD
+# ------------------
+#
+# It is possible to compute an average PSD over multiple epochs with the function `nap.compute_mean_power_spectral_density`.
+#
+# In this case, the argument `interval_size` determines the duration of each epoch upon which the FFT is computed.
+# If no epochs are passed, the function will split the `time_support`.
+#
+# In this case, the FFT will be computed over epochs of 20 seconds.
+
+mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=20.0, norm=True)
+
+
+# %%
+# Let's compare `mean_psd` to `psd`. In both cases, the output is normalized.
+
+plt.figure()
+plt.plot(np.abs(psd), label='PSD')
+plt.plot(np.abs(mean_psd), label='Mean PSD (20s)')
+plt.xlabel("Frequency (Hz)")
+plt.ylabel("Amplitude")
+plt.legend()
+plt.xlim(0, 15)
+
+# %%
+# As we can see, `nap.compute_mean_power_spectral_density` was able to smooth out the noise.
+
+
diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py
new file mode 100644
index 00000000..20246d0f
--- /dev/null
+++ b/docs/api_guide/tutorial_pynapple_wavelets.py
@@ -0,0 +1,501 @@
+# -*- coding: utf-8 -*-
+"""
+Wavelet Transform
+=================
+
+This tutorial covers the use of `nap.compute_wavelet_transform` to do continuous wavelet transform. By default, pynapple uses Morlet wavelets.
+
+Wavelets are a great tool for capturing changes in the spectral characteristics of a signal over time. As neural signals change
+and develop over time, wavelet decompositions can aid both visualization and analysis.
+
+The function `nap.generate_morlet_filterbank` can help parametrize and visualize the Morlet wavelets.
+
+See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package.
+
+This tutorial was made by [Kipp Freud](https://kippfreud.com/).
+
+"""
+
+# %%
+# !!! warning
+#     This tutorial uses matplotlib for displaying the figure
+#
+#     You can install all with `pip install matplotlib requests tqdm seaborn`
+#
+# mkdocs_gallery_thumbnail_number = 9
+#
+# Now, import the necessary libraries:
+
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn
+
+seaborn.set_theme()
+
+import pynapple as nap
+
+# %%
+# ***
+# Generating a Dummy Signal
+# ------------------
+# Let's generate a dummy signal to analyse with wavelets!
+#
+# Our dummy dataset will contain two components: a low-frequency 2 Hz sinusoid combined
+# with a sinusoid which increases in frequency from 5 to 15 Hz throughout the signal.
+
+Fs = 2000
+t = np.linspace(0, 5, Fs * 5)
+two_hz_phase = t * 2 * np.pi * 2
+two_hz_component = np.sin(two_hz_phase)
+increasing_freq_component = np.sin(t * (5 + t) * np.pi * 2)
+sig = nap.Tsd(
+    d=two_hz_component + increasing_freq_component + np.random.normal(0, 0.1, 10000),
+    t=t,
+)
+
+# %%
+# Let's plot it.
+fig, ax = plt.subplots(3, constrained_layout=True, figsize=(10, 5))
+ax[0].plot(t, two_hz_component)
+ax[0].set_title("2Hz Component")
+ax[1].plot(t, increasing_freq_component)
+ax[1].set_title("Increasing Frequency Component")
+ax[2].plot(sig)
+ax[2].set_title("Dummy Signal")
+[ax[i].margins(0) for i in range(3)]
+[ax[i].set_ylim(-2.5, 2.5) for i in range(3)]
+[ax[i].set_xlabel("Time (s)") for i in range(3)]
+[ax[i].set_ylabel("Signal") for i in range(3)]
+
+
+# %%
+# ***
+# Getting our Morlet Wavelet Filter Bank
+# ------------------
+# We will be decomposing our dummy signal using wavelets of different frequencies. These wavelets
+# can be examined using the `generate_morlet_filterbank` function. Here we will use the default parameters
+# to define a Morlet filter bank that we will later use to deconstruct the signal.
+
+# Define the frequency of the wavelets in our filter bank
+freqs = np.linspace(1, 25, num=25)
+# Get the filter bank
+filter_bank = nap.generate_morlet_filterbank(
+    freqs, Fs, gaussian_width=1.5, window_length=1.0
+)
+
+
+# %%
+# Let's plot some of the wavelets.
+
+
+def plot_filterbank(filter_bank, freqs, title):
+    fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7))
+    for f_i in range(filter_bank.shape[1]):
+        ax.plot(filter_bank[:, f_i].real() + f_i * 1.5)
+        ax.text(-5.5, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left")
+
+    ax.set_yticks([])
+    ax.set_xlim(-5, 5)
+    ax.set_xlabel("Time (s)")
+    ax.set_title(title)
+
+
+title = "Morlet Wavelet Filter Bank (Real Components): gaussian_width=1.5, window_length=1.0"
+plot_filterbank(filter_bank, freqs, title)
+
+# %%
+# ***
+# Continuous wavelet transform
+# ----------------------------
+# Here we will use the `compute_wavelet_transform` function to decompose our signal using the filter bank shown
+# above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and
+# frequency information for analysis. We will calculate this decomposition and plot its corresponding
+# scalogram (which is another name for a time-frequency decomposition using wavelets).
+
+# Compute the wavelet transform using the parameters above
+mwt = nap.compute_wavelet_transform(
+    sig, fs=Fs, freqs=freqs, gaussian_width=1.5, window_length=1.0
+)
+
+# %%
+# `mwt`, for Morlet wavelet transform, is a `TsdFrame`. Each column is the result of the convolution of the signal with one wavelet.
+
+print(mwt)
+
+# %%
+# Let's plot it.
+
+
+def plot_timefrequency(freqs, powers, ax=None):
+    im = ax.imshow(np.abs(powers), aspect="auto")
+    ax.invert_yaxis()
+    ax.set_xlabel("Time (s)")
+    ax.set_ylabel("Frequency (Hz)")
+    ax.get_xaxis().set_visible(False)
+    ax.set(yticks=[np.argmin(np.abs(freqs - val)) for val in freqs], yticklabels=freqs)
+    ax.grid(False)
+    return im
+
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+fig.suptitle("Wavelet Decomposition")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3])
+
+ax0 = plt.subplot(gs[0, 0])
+im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0)
+cbar = fig.colorbar(im, ax=ax0, orientation="vertical")
+
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(sig)
+ax1.set_ylabel("Signal")
+ax1.set_xlabel("Time (s)")
+ax1.margins(0)
+
+# %%
+# ***
+# Reconstructing the Slow Oscillation and Phase
+# ------------------
+# We can see that the decomposition has picked up on the 2Hz component of the signal, as well as the component with
+# increasing frequency. In this section, we will extract just the 2Hz component from the wavelet decomposition,
+# and see how it compares to the original signal.
+
+# Get the index of the 2Hz frequency
+two_hz_freq_idx = np.where(freqs == 2.0)[0]
+# The 2Hz component is the real component of the wavelet decomposition at this index
+slow_oscillation = np.real(mwt[:, two_hz_freq_idx])
+# The 2Hz wavelet phase is the angle of the wavelet decomposition at this index
+slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx])
+
+# %%
+# Let's plot it.
+fig = plt.figure(constrained_layout=True, figsize=(10, 4))
+axd = fig.subplot_mosaic(
+    [["signal"], ["phase"]],
+    height_ratios=[1, 0.4],
+)
+axd["signal"].plot(sig, label="Raw Signal", alpha=0.5)
+axd["signal"].plot(slow_oscillation, label="2Hz Reconstruction")
+axd["signal"].legend()
+axd["signal"].set_ylabel("Signal")
+
+axd["phase"].plot(slow_oscillation_phase, alpha=0.5)
+axd["phase"].set_ylabel("Phase (rad)")
+axd["phase"].set_xlabel("Time (s)")
+[axd[k].margins(0) for k in ["signal", "phase"]]
+
+# %%
+# ***
+# Adding in the 15Hz Oscillation
+# ------------------
+# Let's see what happens if we also add the 15 Hz component of the wavelet decomposition to the reconstruction. We
+# will extract the 15 Hz component, and also the 15 Hz wavelet power over time. The wavelet power tells us to what
+# extent the 15 Hz frequency is present in our signal at different times.
+#
+# Finally, we will add this 15 Hz reconstruction to the one shown above, to see if it improves our reconstructed
+# signal.
+
+# Get the index of the 15 Hz frequency
+fifteen_hz_freq_idx = np.where(freqs == 15.0)[0]
+# The 15 Hz component is the real component of the wavelet decomposition at this index
+fifteenHz_oscillation = np.real(mwt[:, fifteen_hz_freq_idx])
+# The 15 Hz power is the absolute value of the wavelet decomposition at this index
+fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx])
+
+# %%
+# Let's plot it.
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 4))
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 1.0])
+
+ax0 = plt.subplot(gs[0, 0])
+ax0.plot(fifteenHz_oscillation, label="15Hz Reconstruction")
+ax0.plot(fifteenHz_oscillation_power, label="15Hz Power")
+ax0.set_xticklabels([])
+
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(sig, label="Raw Signal", alpha=0.5)
+ax1.plot(
+    slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction"
+)
+ax1.set_xlabel("Time (s)")
+
+[
+    (a.margins(0), a.legend(), a.set_ylim(-2.5, 2.5), a.set_ylabel("Signal"))
+    for a in [ax0, ax1]
+]
+
+
+# %%
+# ***
+# Adding ALL the Oscillations!
+# ------------------
+# We will now learn how to interpret the parameters of the wavelet, and in particular how to trade off the
+# accuracy in the frequency decomposition with the accuracy in the time domain reconstruction.
+
+# Up to this point we have used default wavelet and normalization parameters.
+#
+# Let's now add together the real components of all frequency bands to recreate a version of the original signal.
+
+combined_oscillations = np.real(np.sum(mwt, axis=1))
+
+# %%
+# Let's plot it.
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4))
+ax.plot(sig, alpha=0.5, label="Signal")
+ax.plot(combined_oscillations, label="Wavelet Reconstruction", alpha=0.5)
+ax.set_xlabel("Time (s)")
+ax.set_ylabel("Signal")
+ax.set_title("Wavelet Reconstruction of Signal")
+ax.set_ylim(-6, 6)
+ax.margins(0)
+ax.legend()
+
+
+# %%
+# ***
+# Parametrization
+# ------------------
+# Our reconstruction seems to get the amplitude modulations of our signal correct, but the amplitude is overestimated,
+# in particular towards the end of the time period. Often, this is due to a suboptimal choice of parameters, which
+# can lead to a low spatial or temporal resolution. Let's visualize what changing our parameters does to the
+# underlying wavelets.
+
+window_lengths = [1.0, 3.0]
+gaussian_widths = [1.0, 3.0]
+colors = np.array([["r", "g"], ["b", "y"]])
+fig, ax = plt.subplots(
+    len(window_lengths) + 1,
+    len(gaussian_widths) + 1,
+    constrained_layout=True,
+    figsize=(10, 8),
+)
+for row_i, wl in enumerate(window_lengths):
+    for col_i, gw in enumerate(gaussian_widths):
+        wavelet = nap.generate_morlet_filterbank(
+            np.array([1.0]), 1000, gaussian_width=gw, window_length=wl, precision=12
+        )[:, 0].real()
+        ax[row_i, col_i].plot(wavelet, c=colors[row_i, col_i])
+        fft = nap.compute_power_spectral_density(wavelet)
+        for i, j in [(row_i, -1), (-1, col_i)]:
+            ax[i, j].plot(fft.abs(), c=colors[row_i, col_i])
+for i in range(len(window_lengths)):
+    for j in range(len(gaussian_widths)):
+        ax[i, j].set(xlabel="Time (s)", yticks=[])
+for ci, gw in enumerate(gaussian_widths):
+    ax[0, ci].set_title(f"gaussian_width={gw}", fontsize=10)
+for ri, wl in enumerate(window_lengths):
+    ax[ri, 0].set_ylabel(f"window_length={wl}", fontsize=10)
+fig.suptitle("Parametrization Visualization (1 Hz Wavelet)")
+ax[-1, -1].set_visible(False)
+for i in range(len(window_lengths)):
+    ax[-1, i].set(
+        xlim=(0, 2), yticks=[], ylabel="Frequency Response", xlabel="Frequency (Hz)"
+    )
+for i in range(len(gaussian_widths)):
+    ax[i, -1].set(
+        xlim=(0, 2), yticks=[], ylabel="Frequency Response", xlabel="Frequency (Hz)"
+    )
+
+# %%
+# Increasing `window_length` increases the number of wavelet cycles present in the oscillations, and
+# correspondingly increases the time window that the wavelet covers.
+#
+# The `gaussian_width` parameter determines the shape of the gaussian window being convolved with the sinusoidal
+# component of the wavelet.
+#
+# Both of these parameters can be tweaked to control the trade-off between time resolution and frequency resolution.

+# %%
+# ***
+# Effect of `gaussian_width`
+# ------------------
+# Let's increase `gaussian_width` to 7.5 and see the effect on the resultant filter bank.
+
+freqs = np.linspace(1, 25, num=25)
+filter_bank = nap.generate_morlet_filterbank(
+    freqs, 1000, gaussian_width=7.5, window_length=1.0
+)
+
+plot_filterbank(
+    filter_bank,
+    freqs,
+    "Morlet Wavelet Filter Bank (Real Components): gaussian_width=7.5, window_length=1.0",
+)
+
+# %%
+# ***
+# Let's see what effect this has on the Wavelet Scalogram which is generated...
+mwt = nap.compute_wavelet_transform(
+    sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=1.0
+)
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+fig.suptitle("Wavelet Decomposition")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3])
+
+ax0 = plt.subplot(gs[0, 0])
+im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0)
+cbar = fig.colorbar(im, ax=ax0, orientation="vertical")
+
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(sig)
+ax1.set_ylabel("Signal")
+ax1.set_xlabel("Time (s)")
+ax1.margins(0)
+
+# %%
+# ***
+# And let's see if that has an effect on the reconstructed version of the signal.
+
+combined_oscillations = mwt.sum(axis=1).real()
+
+# %%
+# Let's plot it.
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4))
+ax.plot(sig, alpha=0.5, label="Signal")
+ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5)
+[ax.spines[sp].set_visible(False) for sp in ["right", "top"]]
+ax.set_xlabel("Time (s)")
+ax.set_ylabel("Signal")
+ax.set_title("Wavelet Reconstruction of Signal")
+ax.set_ylim(-6, 6)
+ax.margins(0)
+ax.legend()
+
+# %%
+# There's a small improvement, but perhaps we can do better.
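+
+# %%
+# Beyond eyeballing the overlap, we can put a number on reconstruction quality.
+# The cell below is a quick sketch (it assumes the `sig` and
+# `combined_oscillations` variables computed above): the root mean squared
+# error between the raw signal and its wavelet reconstruction.
+
+rmse = np.sqrt(np.mean((sig.values - np.ravel(combined_oscillations.values)) ** 2))
+print(f"Reconstruction RMSE: {rmse:.3f}")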
+
+
+# %%
+# ***
+# Effect of `window_length`
+# ------------------
+# Let's increase `window_length` to 2.0 and see the effect on the resultant filter bank.
+
+freqs = np.linspace(1, 25, num=25)
+filter_bank = nap.generate_morlet_filterbank(
+    freqs, 1000, gaussian_width=7.5, window_length=2.0
+)
+
+plot_filterbank(
+    filter_bank,
+    freqs,
+    "Morlet Wavelet Filter Bank (Real Components): gaussian_width=7.5, window_length=2.0",
+)
+
+# %%
+# ***
+# Let's see what effect this has on the Wavelet Scalogram which is generated...
+mwt = nap.compute_wavelet_transform(
+    sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0
+)
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+fig.suptitle("Wavelet Decomposition")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3])
+
+ax0 = plt.subplot(gs[0, 0])
+im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0)
+cbar = fig.colorbar(im, ax=ax0, orientation="vertical")
+
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(sig)
+ax1.set_ylabel("Signal")
+ax1.set_xlabel("Time (s)")
+ax1.margins(0)
+
+# %%
+# ***
+# And let's see if that has an effect on the reconstructed version of the signal.
+
+combined_oscillations = np.real(np.sum(mwt, axis=1))
+
+# %%
+# Let's plot it.
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4))
+ax.plot(sig, alpha=0.5, label="Signal")
+ax.plot(combined_oscillations, label="Wavelet Reconstruction", alpha=0.5)
+ax.set_xlabel("Time (s)")
+ax.set_ylabel("Signal")
+ax.set_title("Wavelet Reconstruction of Signal")
+ax.margins(0)
+ax.set_ylim(-6, 6)
+ax.legend()
+
+
+# %%
+# ***
+# Effect of L1 vs L2 normalization
+# ------------------
+# `compute_wavelet_transform` contains two options for normalization: L1 and L2.
+# By default, L1 is used as it creates cleaner-looking decomposition images.
+#
+# L1 normalization often increases the contrast between significant and insignificant coefficients.
+# This can result in a sharper and more defined visual representation, making patterns and structures within
+# the signal more evident.
+#
+# L2 normalization is directly related to the energy of the signal. By normalizing using the
+# L2 norm, you ensure that the transformed coefficients preserve the energy distribution of the original signal.
+#
+# Let's compare two wavelet decompositions, each generated with a different normalization strategy.
+
+mwt_l1 = nap.compute_wavelet_transform(
+    sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l1"
+)
+mwt_l2 = nap.compute_wavelet_transform(
+    sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l2"
+)
+
+# %%
+# Let's plot both scalograms and see the difference.
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+fig.suptitle("Wavelet Decomposition - L1 Normalization")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3])
+ax0 = plt.subplot(gs[0, 0])
+im = plot_timefrequency(freqs[:], np.transpose(mwt_l1[:, :].values), ax=ax0)
+cbar = fig.colorbar(im, ax=ax0, orientation="vertical")
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(sig)
+ax1.set_ylabel("Signal")
+ax1.set_xlabel("Time (s)")
+ax1.margins(0)
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+fig.suptitle("Wavelet Decomposition - L2 Normalization")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3])
+ax0 = plt.subplot(gs[0, 0])
+im = plot_timefrequency(freqs[:], np.transpose(mwt_l2[:, :].values), ax=ax0)
+cbar = fig.colorbar(im, ax=ax0, orientation="vertical")
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(sig)
+ax1.set_ylabel("Signal")
+ax1.set_xlabel("Time (s)")
+ax1.margins(0)
+
+# %%
+# We see that the L1-normalized scalogram is visually clearer; the 5-15 Hz component of the signal is
+# as powerful as the 2 Hz component, so it makes sense that they should be shown with the same power in the scalogram.
+# Let's reconstruct the signal using both decompositions and see the resulting reconstruction...
+
+# %%
+
+combined_oscillations_l1 = np.real(np.sum(mwt_l1, axis=1))
+combined_oscillations_l2 = np.real(np.sum(mwt_l2, axis=1))
+
+# %%
+# Let's plot it.
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4))
+ax.plot(sig, label="Signal", linewidth=3, alpha=0.6, c="b")
+ax.plot(combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6)
+ax.plot(combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6)
+ax.set_xlabel("Time (s)")
+ax.set_ylabel("Signal")
+ax.set_title("Wavelet Reconstruction of Signal")
+ax.margins(0)
+ax.set_ylim(-6, 6)
+ax.legend()
+
+# %%
+# We see that the reconstruction from the L2-normalized decomposition matches the original signal much more closely;
+# this is because L2 normalization preserves the energy of the original signal in its reconstruction.
diff --git a/docs/examples/tutorial_human_dataset.py b/docs/examples/tutorial_human_dataset.py
index f84cbef6..caeb5f0e 100644
--- a/docs/examples/tutorial_human_dataset.py
+++ b/docs/examples/tutorial_human_dataset.py
@@ -189,36 +189,23 @@
 # ------------------
 #
 # Now that we have the PETH of spiking, we can go one step further. We will plot the mean firing rate of this cell aligned to the boundary for each trial type. Doing this in Pynapple is very simple!
-
-bin_size = 0.2  # 200ms bin size
-step_size = 0.01  # 10ms step size, to make overlapping bins
-winsize = int(bin_size / step_size)  # Window size
-
-# %%
+#
 # Use Pynapple to compute binned spike counts
-
-counts_NB = NB_peth.count(step_size)  # Spike counts binned in 10ms steps, for NB trials
-counts_HB = HB_peth.count(step_size)  # Spike counts binned in 10ms steps, for HB trials
+bin_size = 0.01
+counts_NB = NB_peth.count(bin_size)  # Spike counts binned in 10ms steps, for NB trials
+counts_HB = HB_peth.count(bin_size)  # Spike counts binned in 10ms steps, for HB trials
 
 # %%
-# Smooth the binned spike counts using a window of size 20, for both trial types
+# Compute firing rate for both trial types
 
-counts_NB = (
-    counts_NB.as_dataframe()
-    .rolling(winsize, win_type="gaussian", min_periods=1, center=True, axis=0)
-    .mean(std=0.2 * winsize)
-)
-counts_HB = (
-    counts_HB.as_dataframe()
-    .rolling(winsize, win_type="gaussian", min_periods=1, center=True, axis=0)
-    .mean(std=0.2 * winsize)
-)
+fr_NB = counts_NB / bin_size
+fr_HB = counts_HB / bin_size
 
 # %%
-# Compute firing rate for both trial types
+# Smooth the firing rate with a gaussian window with std=4*bin_size
+fr_NB = fr_NB.smooth(bin_size*4)
+fr_HB = fr_HB.smooth(bin_size*4)
 
-fr_NB = counts_NB * winsize
-fr_HB = counts_HB * winsize
 
 # %%
 # Compute the mean firing rate for both trial types
@@ -228,9 +215,9 @@
 
 # %%
 # Compute standard error of mean (SEM) of the firing rate for both trial types
-
-error_NB = fr_NB.sem(axis=1)
-error_HB = fr_HB.sem(axis=1)
+from scipy.stats import sem
+error_NB = sem(fr_NB, axis=1)
+error_HB = sem(fr_HB, axis=1)
 
 # %%
 # Plot the mean +/- SEM of firing rate for both trial types
diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py
new file mode 100644
index 00000000..c500ca82
--- /dev/null
+++ b/docs/examples/tutorial_phase_preferences.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+"""
+Spikes-phase coupling
+=====================
+
+In this tutorial we will learn how to isolate phase information using band-pass filtering and combine it
+with spiking data, to find phase preferences of spiking units.
+
+Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track.
+
+This tutorial was made by [Kipp Freud](https://kippfreud.com/) & Guillaume Viejo.
+"""
+
+# %%
+# !!! warning
+#     This tutorial uses matplotlib for displaying the figure
+#
+#     You can install all with `pip install matplotlib requests tqdm seaborn`
+#
+# mkdocs_gallery_thumbnail_number = 6
+#
+# First, import the necessary libraries:
+
+import math
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import requests
+import scipy
+import seaborn
+import tqdm
+
+custom_params = {"axes.spines.right": False, "axes.spines.top": False}
+seaborn.set_theme(context='notebook', style="ticks", rc=custom_params)
+
+import pynapple as nap
+
+# %%
+# ***
+# Downloading the data
+# ------------------
+# Let's download the data and save it locally.
+
+path = "Achilles_10252013_EEG.nwb"
+if path not in os.listdir("."):
+    r = requests.get("https://osf.io/2dfvp/download", stream=True)
+    block_size = 1024 * 1024
+    with open(path, "wb") as f:
+        for data in tqdm.tqdm(
+            r.iter_content(block_size),
+            unit="MB",
+            unit_scale=True,
+            total=math.ceil(int(r.headers.get("content-length", 0)) // block_size),
+        ):
+            f.write(data)
+
+
+# %%
+# ***
+# Loading the data
+# ------------------
+# Let's load and print the full dataset.
+
+data = nap.load_file(path)
+FS = 1250  # We know from the methods of the paper
+print(data)
+
+
+# %%
+# ***
+# Selecting slices
+# -----------------------------------
+# For later visualization, we define an interval of 3 seconds of data during REM sleep.
+
+ep_ex_rem = nap.IntervalSet(
+    data["rem"]["start"][0] + 97.0,
+    data["rem"]["start"][0] + 100.0,
+)
+# %%
+# Here we restrict the LFP to the REM epochs.
+tsd_rem = data["eeg"][:,0].restrict(data["rem"])
+
+# We will also extract spike times from all units in our dataset
+# which occur during REM sleep
+spikes = data["units"].restrict(data["rem"])
+
+# %%
+# ***
+# Plotting the LFP Activity
+# -----------------------------------
+# We should first plot our REM Local Field Potential data.
+
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3))
+ax.plot(tsd_rem.restrict(ep_ex_rem))
+ax.set_title("REM Local Field Potential")
+ax.set_ylabel("LFP (a.u.)")
+ax.set_xlabel("time (s)")
+
+
+
+# %%
+# ***
+# Getting the Wavelet Decomposition
+# -----------------------------------
+# As we would expect, it looks like we have a very strong theta oscillation within our data
+# - this is a common feature of REM sleep. Let's perform a wavelet decomposition,
+# as we did in the last tutorial, to get a more informative breakdown of the
+# frequencies present in the data.
+#
+# We must define the frequency set that we'd like to use for our decomposition.
+
+freqs = np.geomspace(5, 200, 25)
+
+# %%
+# We compute the wavelet transform on our LFP data (only during the example interval).
+
+cwt_rem = nap.compute_wavelet_transform(tsd_rem.restrict(ep_ex_rem), fs=FS, freqs=freqs)
+
+# %%
+# ***
+# Now let's plot the calculated wavelet scalogram.
+
+
+# Define wavelet decomposition plotting function
+def plot_timefrequency(freqs, powers, ax=None):
+    im = ax.imshow(np.abs(powers), aspect="auto")
+    ax.invert_yaxis()
+    ax.set_xlabel("Time (s)")
+    ax.set_ylabel("Frequency (Hz)")
+    ax.get_xaxis().set_visible(False)
+    ax.set(yticks=np.arange(len(freqs))[::2], yticklabels=np.rint(freqs[::2]))
+    ax.grid(False)
+    return im
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+fig.suptitle("Wavelet Decomposition")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3])
+
+ax0 = plt.subplot(gs[0, 0])
+im = plot_timefrequency(freqs, np.transpose(cwt_rem[:, :].values), ax=ax0)
+cbar = fig.colorbar(im, ax=ax0, orientation="vertical")
+
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(tsd_rem.restrict(ep_ex_rem))
+ax1.set_ylabel("LFP (a.u.)")
+ax1.set_xlabel("Time (s)")
+ax1.margins(0)
+
+
+# %%
+# ***
+# Filtering Theta
+# ---------------
+#
+# As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.apply_bandpass_filter`.
+
+theta_band = nap.apply_bandpass_filter(tsd_rem, cutoff=(6.0, 10.0), fs=FS)
+
+# %%
+# We can plot the original signal and the filtered signal.
+
+plt.figure(constrained_layout=True, figsize=(12, 3))
+plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5)
+plt.plot(theta_band.restrict(ep_ex_rem))
+plt.xlabel("Time (s)")
+plt.show()
+
+
+# %%
+# ***
+# Computing phase
+# ---------------
+#
+# From the filtered signal, it is easy to get the phase using the Hilbert transform. Here we use scipy's Hilbert method.
+from scipy import signal
+
+theta_phase = nap.Tsd(t=theta_band.t, d=np.angle(signal.hilbert(theta_band)))
+
+# %%
+# Let's plot the phase.
+
+plt.figure(constrained_layout=True, figsize=(12, 3))
+plt.subplot(211)
+plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5)
+plt.plot(theta_band.restrict(ep_ex_rem))
+plt.subplot(212)
+plt.plot(theta_phase.restrict(ep_ex_rem), color='r')
+plt.ylabel("Phase (rad)")
+plt.xlabel("Time (s)")
+plt.show()
+
+
+
+# %%
+# ***
+# Finding Phase of Spikes
+# -----------------------
+# Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences
+# of each of the units using the `compute_1d_tuning_curves` function.
+#
+# We will start by throwing away cells which do not have a high enough firing rate during our interval.
+spikes = spikes[spikes.rate > 5.0]
+
+# %%
+# The feature is the theta phase during REM sleep.
+
+phase_modulation = nap.compute_1d_tuning_curves(
+    group=spikes, feature=theta_phase, nb_bins=61, minmax=(-np.pi, np.pi)
+)
+
+# %%
+# Let's plot the first 3 neurons.
+
+plt.figure(constrained_layout=True, figsize = (12, 3))
+for i in range(3):
+    plt.subplot(1,3,i+1)
+    plt.plot(phase_modulation.iloc[:,i])
+    plt.xlabel("Phase (rad)")
+    plt.ylabel("Firing rate (Hz)")
+plt.show()
+
+# %%
+# There is clearly a strong modulation for the third neuron.
+# Finally, we can use the function `value_from` to align each spike to the corresponding phase position and overlay
+# it with the LFP.
+
+spike_phase = spikes[spikes.index[3]].value_from(theta_phase)
+
+# %%
+# Let's plot it.
+plt.figure(constrained_layout=True, figsize=(12, 3))
+plt.subplot(211)
+plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5)
+plt.plot(theta_band.restrict(ep_ex_rem))
+plt.subplot(212)
+plt.plot(theta_phase.restrict(ep_ex_rem), alpha=0.5)
+plt.plot(spike_phase.restrict(ep_ex_rem), 'o')
+plt.ylabel("Phase (rad)")
+plt.xlabel("Time (s)")
+plt.show()
+
+
diff --git a/docs/examples/tutorial_wavelet_decomposition.py b/docs/examples/tutorial_wavelet_decomposition.py
new file mode 100644
index 00000000..530540ec
--- /dev/null
+++ b/docs/examples/tutorial_wavelet_decomposition.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+"""
+Wavelet Transform
+============
+This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis.
+We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/).
+
+Specifically, we will examine Local Field Potential data from a period of active traversal of a linear track.
+
+This tutorial was made by [Kipp Freud](https://kippfreud.com/).
+
+"""
+
+
+# %%
+# !!! warning
+#     This tutorial uses matplotlib for displaying the figure
+#
+#     You can install all with `pip install matplotlib requests tqdm seaborn`
+#
+#
+# First, import the necessary libraries:
+#
+# mkdocs_gallery_thumbnail_number = 6
+
+import math
+import os
+
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
+import numpy as np
+import requests
+import seaborn
+import tqdm
+
+seaborn.set_theme()
+
+import pynapple as nap
+
+# %%
+# ***
+# Downloading the data
+# ------------------
+# Let's download the data and save it locally.
+
+path = "Achilles_10252013_EEG.nwb"
+if path not in os.listdir("."):
+    r = requests.get("https://osf.io/2dfvp/download", stream=True)
+    block_size = 1024 * 1024
+    with open(path, "wb") as f:
+        for data in tqdm.tqdm(
+            r.iter_content(block_size),
+            unit="MB",
+            unit_scale=True,
+            total=math.ceil(int(r.headers.get("content-length", 0)) // block_size),
+        ):
+            f.write(data)
+# Let's load and print the full dataset.
+data = nap.load_file(path)
+print(data)
+
+
+# %%
+# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`.
+#
+# The `time_support` of the object `data['position']` contains the interval for which the rat was running along the linear track. We will call it `wake_ep`.
+#
+
+FS = 1250
+
+eeg = data["eeg"]
+
+wake_ep = data["position"].time_support
+
+# %%
+# ***
+# Selecting example
+# -----------------------------------
+# We will consider a single run of the experiment, where the rodent completes a full traversal of the linear track,
+# followed by 4 seconds of post-traversal activity.
+
+forward_ep = data["forward_ep"]
+RUN_interval = nap.IntervalSet(forward_ep.start[7], forward_ep.end[7] + 4.0)
+
+eeg_example = eeg.restrict(RUN_interval)[:, 0]
+pos_example = data["position"].restrict(RUN_interval)
+
+# %%
+# ***
+# Plotting the LFP and Behavioural Activity
+# -----------------------------------
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+axd = fig.subplot_mosaic(
+    [["ephys"], ["pos"]],
+    height_ratios=[1, 0.4],
+)
+axd["ephys"].plot(eeg_example, label="CA1")
+axd["ephys"].set_title("EEG (1250 Hz)")
+axd["ephys"].set_ylabel("LFP (a.u.)")
+axd["ephys"].set_xlabel("time (s)")
+axd["ephys"].margins(0)
+axd["ephys"].legend()
+axd["pos"].plot(pos_example, color="black")
+axd["pos"].margins(0)
+axd["pos"].set_xlabel("time (s)")
+axd["pos"].set_ylabel("Linearized Position")
+axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1])
+
+# %%
+# In the top panel, we can see the LFP trace as a function of time, and on the bottom the animal's position on the linear
+# track as a function of time. Positions 0 and 1 correspond to the start and end of the trial respectively.
+
+# %%
+# ***
+# Getting the LFP Spectrogram
+# -----------------------------------
+# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies during exploration (`wake_ep`).
+
+
+power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep, norm=True)
+
+print(power)
+
+# %%
+# ***
+# The returned object is a pandas dataframe which uses frequencies as the index and spectral power as values.
+#
+# Let's plot the power between 1 and 100 Hz.
+
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4))
+ax.plot(
+    np.abs(power[(power.index >= 1.0) & (power.index <= 100)]),
+    alpha=0.5,
+    label="LFP Frequency Power",
+)
+ax.axvspan(6, 10, color="red", alpha=0.1)
+ax.set_xlabel("Freq (Hz)")
+ax.set_ylabel("Frequency Power")
+ax.set_title("LFP Fourier Decomposition")
+ax.legend()
+
+
+# %%
+# The red area outlines the theta rhythm (6-10 Hz), which is prominent in hippocampal LFP.
+# Hippocampal theta rhythm appears mostly when the animal is running [1].
+# We can check it here by separating the wake epochs (`wake_ep`) into run epochs (`run_ep`) and rest epochs (`rest_ep`).
+
+# The run epoch is the portion of the data for which we have position data
+run_ep = data["position"].dropna().find_support(1)
+# The rest epoch is the data at all points where we do not have position data
+rest_ep = wake_ep.set_diff(run_ep)
+
+# %%
+# `run_ep` and `rest_ep` are `IntervalSet` objects with discontinuous epochs.
+#
+# The function `nap.compute_power_spectral_density` takes a signal with a single epoch to avoid artefacts from jumps between epochs.
+#
+# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density`, which averages the FFT over multiple epochs of the same duration. The parameter `interval_size` controls the duration of those epochs.
+#
+# In this case, `interval_size` is equal to 1.5 seconds.
+
+power_run = nap.compute_mean_power_spectral_density(
+    eeg, 1.5, fs=FS, ep=run_ep, norm=True
+)
+power_rest = nap.compute_mean_power_spectral_density(
+    eeg, 1.5, fs=FS, ep=rest_ep, norm=True
+)
+
+# %%
+# `power_run` and `power_rest` are the power spectral density when the animal is running and resting, respectively.
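+
+# %%
+# Before plotting, a quick numerical check. The cell below is a sketch (it
+# assumes `power_run` and `power_rest` are the pandas objects returned above):
+# we compare the mean power within the 6-10 Hz theta band across the two conditions.
+
+theta_run = np.abs(power_run[(power_run.index >= 6) & (power_run.index <= 10)]).values.mean()
+theta_rest = np.abs(power_rest[(power_rest.index >= 6) & (power_rest.index <= 10)]).values.mean()
+print(f"Mean 6-10 Hz power: run={theta_run:.2e}, rest={theta_rest:.2e}")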
+
+fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4))
+ax.plot(
+    np.abs(power_run[(power_run.index >= 3.0) & (power_run.index <= 30)]),
+    alpha=1,
+    label="Run",
+    linewidth=2,
+)
+ax.plot(
+    np.abs(power_rest[(power_rest.index >= 3.0) & (power_rest.index <= 30)]),
+    alpha=1,
+    label="Rest",
+    linewidth=2,
+)
+ax.axvspan(6, 10, color="red", alpha=0.1)
+ax.set_xlabel("Freq (Hz)")
+ax.set_ylabel("Frequency Power")
+ax.set_title("LFP Fourier Decomposition")
+ax.legend()
+
+
+# %%
+# ***
+# Getting the Wavelet Decomposition
+# -----------------------------------
+# Overall, the prominent frequencies in the data vary over time. The LFP characteristics may be different when the animal is running along the track, and when it is finished.
+# Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time.
+
+# We must define the frequency set that we'd like to use for our decomposition
+freqs = np.geomspace(3, 250, 100)
+
+# %%
+# Compute and print the wavelet transform on our LFP data
+
+mwt_RUN = nap.compute_wavelet_transform(eeg_example, fs=FS, freqs=freqs)
+
+
+# %%
+# `mwt_RUN` is a TsdFrame with each column being the convolution with one wavelet at a particular frequency.
+print(mwt_RUN)
+
+# %%
+# ***
+# Now let's plot it.
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+gs = plt.GridSpec(3, 1, figure=fig, height_ratios=[1.0, 0.5, 0.1])
+
+ax0 = plt.subplot(gs[0, 0])
+pcmesh = ax0.pcolormesh(mwt_RUN.t, freqs, np.transpose(np.abs(mwt_RUN)))
+ax0.grid(False)
+ax0.set_yscale("log")
+ax0.set_title("Wavelet Decomposition")
+ax0.set_ylabel("Frequency (Hz)")
+cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical")
+cbar.set_label("Amplitude")
+
+ax1 = plt.subplot(gs[1, 0], sharex=ax0)
+ax1.plot(eeg_example)
+ax1.set_ylabel("LFP (a.u.)")
+
+ax1 = plt.subplot(gs[2, 0], sharex=ax0)
+ax1.plot(pos_example, color="black")
+ax1.set_xlabel("Time (s)")
+ax1.set_ylabel("Pos.")
+
+
+# %%
+# ***
+# Visualizing Theta Band Power
+# -----------------------------------
+# There seems to be a strong theta frequency present in the data during the maze traversal.
+# Let's plot the estimated 6-10Hz component of the wavelet decomposition on top of our data, and see how well they match up.
+
+theta_freq_index = np.logical_and(freqs > 6, freqs < 10)
+
+
+# Extract its real component, as well as its power envelope
+theta_band_reconstruction = np.mean(mwt_RUN[:, theta_freq_index], 1)
+theta_band_power_envelope = np.abs(theta_band_reconstruction)
+
+
+# %%
+# ***
+# Now let's visualise the theta band component of the signal over time.
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.9])
+ax0 = plt.subplot(gs[0, 0])
+ax0.plot(eeg_example, label="CA1")
+ax0.set_title("EEG (1250 Hz)")
+ax0.set_ylabel("LFP (a.u.)")
+ax0.set_xlabel("time (s)")
+ax0.legend()
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(np.real(theta_band_reconstruction), label="6-10 Hz oscillations")
+ax1.plot(theta_band_power_envelope, label="6-10 Hz power envelope")
+ax1.set_xlabel("time (s)")
+ax1.set_ylabel("Wavelet transform")
+ax1.legend()
+
+# %%
+# ***
+# We observe that the theta power is far stronger during the first 4 seconds of the dataset, during which the rat
+# is traversing the linear track.
+
+# %%
+# ***
+# Visualizing High Frequency Oscillation
+# -----------------------------------
+# There also seem to be peaks in the 200Hz frequency power after traversal of the maze is complete. Here we use the interval (18356, 18357.5) seconds to zoom in.
+
+zoom_ep = nap.IntervalSet(18356.0, 18357.5)
+
+mwt_zoom = mwt_RUN.restrict(zoom_ep)
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5])
+ax0 = plt.subplot(gs[0, 0])
+pcmesh = ax0.pcolormesh(mwt_zoom.t, freqs, np.transpose(np.abs(mwt_zoom)))
+ax0.grid(False)
+ax0.set_yscale("log")
+ax0.set_title("Wavelet Decomposition")
+ax0.set_ylabel("Frequency (Hz)")
+cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical")
+cbar.set_label("Amplitude")
+
+ax1 = plt.subplot(gs[1, 0], sharex=ax0)
+ax1.plot(eeg_example.restrict(zoom_ep))
+ax1.set_ylabel("LFP (a.u.)")
+ax1.set_xlabel("Time (s)")
+
+# %%
+# Those events are called sharp-wave ripples [2].
+#
+# Among other methods, we can use the wavelet decomposition to isolate them. In this case, we will look at the power of the wavelets for frequencies between 150 and 250 Hz.
+
+ripple_freq_index = np.logical_and(freqs > 150, freqs < 250)
+
+# %%
+# We can compute the mean power for this frequency band.
+
+ripple_power = np.mean(np.abs(mwt_RUN[:, ripple_freq_index]), 1)
+
+
+# %%
+# Now let's visualise the 150-250 Hz mean amplitude of the wavelet decomposition over time
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 5))
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5])
+ax0 = plt.subplot(gs[0, 0])
+ax0.plot(eeg_example.restrict(zoom_ep), label="CA1")
+ax0.set_ylabel("LFP (a.u.)")
+ax0.set_title("EEG (1250 Hz)")
+ax1 = plt.subplot(gs[1, 0])
+ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz")
+ax1.legend()
+ax1.set_ylabel("Mean Amplitude")
+ax1.set_xlabel("Time (s)")
+
+
+# %%
+# It is then easy to isolate ripple times by using the pynapple functions `smooth` and `threshold`. In the following lines, `ripple_power` is smoothed with a gaussian kernel of size 0.005 second and thresholded with a value of 100.
+
+smoothed_ripple_power = ripple_power.smooth(0.005)
+
+threshold_ripple_power = smoothed_ripple_power.threshold(100)
+
+# %%
+# `threshold_ripple_power` contains all the time points above 100. The ripple epochs are contained in the `time_support` of the thresholded time series. Here we call it `rip_ep`.
+
+rip_ep = threshold_ripple_power.time_support
+
+
+# %%
+# Now let's plot the ripple epochs as well as the smoothed ripple power.
# %%
# Now let's plot the ripple epochs as well as the smoothed ripple power.
#
# We can also plot `rip_ep` as vertical boxes to see if the detection is accurate.

fig = plt.figure(constrained_layout=True, figsize=(10, 5))
gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5])
ax0 = plt.subplot(gs[0, 0])
ax0.plot(eeg_example.restrict(zoom_ep), label="CA1")
for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values):
    ax0.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None)
ax0.set_ylabel("LFP (a.u.)")
ax0.set_title("EEG (1250 Hz)")
ax1 = plt.subplot(gs[1, 0])
ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz")
ax1.plot(smoothed_ripple_power.restrict(zoom_ep), label="Smoothed")
for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values):
    ax1.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None)
ax1.legend()
ax1.set_ylabel("Mean Amplitude")
ax1.set_xlabel("Time (s)")


# %%
# Finally, let's zoom in on each of our isolated ripples.

fig = plt.figure(constrained_layout=True, figsize=(10, 5))
gs = plt.GridSpec(2, 2, figure=fig, height_ratios=[1.0, 1.0])
buffer = 0.075
plt.suptitle("Isolated Sharp-Wave Ripples")
for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values):
    ax = plt.subplot(gs[int(i / 2), i % 2])
    ax.plot(eeg_example.restrict(nap.IntervalSet(s - buffer, e + buffer)))
    ax.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None)
    ax.set_xlim(s - buffer, e + buffer)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("LFP (a.u.)")


# %%
# ***
# References
# -----------------------------------
#
# [1] Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666.
#
# [2] Buzsáki, G. (2015). Hippocampal sharp wave-ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188.
diff --git a/docs/index.md b/docs/index.md
index 0df91139..d517eeb1 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -30,6 +30,15 @@ To ask any questions or get support for using pynapple, please consider joining
 New releases :fire:
 ------------------
 
+### pynapple >= 0.7
+
+Pynapple now implements signal processing. For example, to filter a 1250 Hz sampled time series between 10 Hz and 20 Hz:
+
+```python
+nap.apply_bandpass_filter(signal, (10, 20), fs=1250)
+```
+New functions include power spectral density and Morlet wavelet decomposition. See the [documentation](https://pynapple-org.github.io/pynapple/reference/process/) for more details.
+
 ### pynapple >= 0.6
 
 Starting with 0.6, [`IntervalSet`](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) objects are behaving as immutable numpy ndarray. Before 0.6, you could select an interval within an `IntervalSet` object with:
@@ -44,8 +53,6 @@ With pynapple>=0.6, the slicing is similar to numpy and it returns an `IntervalS
 new_intervalset = intervalset[0]
 ```
 
-See the [documentation](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) for more details.
-
 ### pynapple >= 0.4
 
 Starting with 0.4, pynapple rely on the [numpy array container](https://numpy.org/doc/stable/user/basics.dispatch.html) approach instead of Pandas for the time series. Pynapple builtin functions will remain the same except for functions inherited from Pandas.
@@ -54,7 +61,6 @@ This allows for a better handling of returned objects.
 
 Additionaly, it is now possible to define time series objects with more than 2 dimensions with `TsdTensor`.
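For instance, a stack of movie frames can be held in a single time-indexed object. Here is a minimal sketch (the shapes and random data are purely illustrative):

```python
import numpy as np
import pynapple as nap

# 100 time points, each carrying a 5 x 5 data frame
movie = nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 5))
```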
You can also look at this [notebook](https://pynapple-org.github.io/pynapple/generated/gallery/tutorial_pynapple_numpy/) for a demonstration of numpy compatibilities. - Getting Started --------------- diff --git a/draft_pynapple_fastplotlib.py b/draft_pynapple_fastplotlib.py deleted file mode 100644 index a3dd423d..00000000 --- a/draft_pynapple_fastplotlib.py +++ /dev/null @@ -1,172 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Fastplotlib -=========== - -Working with calcium data. - -For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. - -The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it. - -See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. - -This tutorial was made by Sofia Skromne Carrasco and Guillaume Viejo. - -""" -# %% -# %gui qt - -import pynapple as nap -import numpy as np -import fastplotlib as fpl - -import sys -# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' - -def get_memory_map(filepath, nChannels, frequency=20000): - n_channels = int(nChannels) - f = open(filepath, 'rb') - startoffile = f.seek(0, 0) - endoffile = f.seek(0, 2) - bytes_size = 2 - n_samples = int((endoffile-startoffile)/n_channels/bytes_size) - duration = n_samples/frequency - interval = 1/frequency - f.close() - fp = np.memmap(filepath, np.int16, 'r', shape = (n_samples, n_channels)) - timestep = np.arange(0, n_samples)/frequency - - return fp, timestep - - -#### LFP -data_array, time_array = get_memory_map("your/path/to/MyProject/sub-A2929/A2929-200711/A2929-200711.dat", 16) -lfp = nap.TsdFrame(t=time_array, d=data_array) - -lfp2 = lfp.get(0, 20)[:,14] -lfp2 = np.vstack((lfp2.t, lfp2.d)).T - -#### NWB -nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") -units = nwb['units']#.getby_category("location")['adn'] -tmp = units.to_tsd().get(0, 20) -tmp = np.vstack((tmp.index.values, tmp.values)).T - - - -fig = fpl.Figure(canvas="glfw", shape=(2,1)) -fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn") -fig[1,0].add_scatter(tmp) -fig.show(maintain_aspect=False) -# fpl.run() - - - - -# grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = ['lfp', 'wavelet']) -# grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d) - - -import numpy as np -import fastplotlib as fpl - -fig = fpl.Figure(canvas="glfw")#, shape=(2,1), controller_ids="sync") -fig[0,0].add_line(data=np.random.randn(1000)) -fig.show(maintain_aspect=False) - -fig2 = fpl.Figure(canvas="glfw", controllers=fig.controllers)#, shape=(2,1), controller_ids="sync") -fig2[0,0].add_line(data=np.random.randn(1000)*1000) -fig2.show(maintain_aspect=False) - - - -# Not sure about this : -fig[1,0].controller.controls["mouse1"] = "pan", "drag", (1.0, 0.0) - -fig[1,0].controller.controls.pop("mouse2") -fig[1,0].controller.controls.pop("mouse4") -fig[1,0].controller.controls.pop("wheel") - -import pygfx - -controller = pygfx.PanZoomController() -controller.controls.pop("mouse1") -controller.add_camera(fig[0, 0].camera) -controller.register_events(fig[0, 0].viewport) - -controller2 = pygfx.PanZoomController() -controller2.add_camera(fig[1, 0].camera) -controller2.controls.pop("mouse1") 
-controller2.register_events(fig[1, 0].viewport) - - - - - - - - - - - - - - - - -sys.exit() - -################################################################################################# - - -nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") -units = nwb['units']#.getby_category("location")['adn'] -tmp = units.to_tsd() -tmp = np.vstack((tmp.index.values, tmp.values)).T - -# Example 1 - -fplot = fpl.Plot() -fplot.add_scatter(tmp) -fplot.graphics[0].cmap = "jet" -fplot.graphics[0].cmap.values = tmp[:, 1] -fplot.show(maintain_aspect=False) - -# Example 2 - -names = [['raster'], ['position']] -grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names) -grid_plot['raster'].add_scatter(tmp) -grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T) -grid_plot.show(maintain_aspect=False) -grid_plot['raster'].auto_scale(maintain_aspect=False) - - -# Example 3 -#frames = iio.imread("/Users/gviejo/pynapple/A0670-221213_filtered.avi") -#frames = frames[:,:,:,0] -frames = np.random.randn(10, 100, 100) - -iw = fpl.ImageWidget(frames, cmap="gnuplot2") - -#iw.show() - -# Example 4 - -from PyQt6 import QtWidgets - - -mainwidget = QtWidgets.QWidget() - -hlayout = QtWidgets.QHBoxLayout(mainwidget) - -iw.widget.setParent(mainwidget) - -hlayout.addWidget(iw.widget) - -grid_plot.widget.setParent(mainwidget) - -hlayout.addWidget(grid_plot.widget) - -mainwidget.show() diff --git a/mkdocs.yml b/mkdocs.yml index 063d2549..b2e35d8d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,4 +65,5 @@ markdown_extensions: pygments_lang_class: true - pymdownx.inlinehilite - pymdownx.snippets - - pymdownx.superfences \ No newline at end of file + - pymdownx.superfences + - admonition diff --git a/pynapple/__init__.py b/pynapple/__init__.py index 06b74c00..f55f6194 100644 --- a/pynapple/__init__.py +++ b/pynapple/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.6" +__version__ = "0.7.0" from .core import ( IntervalSet, Ts, diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py index a7f67d4d..33cb2a0a 100644 --- a/pynapple/core/_core_functions.py +++ b/pynapple/core/_core_functions.py @@ -11,7 +11,7 @@ import numpy as np from scipy import signal -from ._jitted_functions import ( +from ._jitted_functions import ( # pjitconvolve, jitbin_array, jitcount, jitremove_nan, @@ -19,7 +19,6 @@ jitrestrict_with_count, jitthreshold, jitvaluefrom, - pjitconvolve, ) from .utils import get_backend @@ -28,13 +27,13 @@ def _restrict(time_array, starts, ends): return jitrestrict(time_array, starts, ends) -def _count(time_array, starts, ends, bin_size=None): +def _count(time_array, starts, ends, bin_size=None, dtype=None): if isinstance(bin_size, (float, int)): - return jitcount(time_array, starts, ends, bin_size) + t, d = jitcount(time_array, starts, ends, bin_size, dtype) else: - _, d = jitrestrict_with_count(time_array, starts, ends) + _, d = jitrestrict_with_count(time_array, starts, ends, dtype) t = starts + (ends - starts) / 2 - return t, d + return t, d def _value_from(time_array, time_target_array, data_target_array, starts, ends): @@ -99,36 +98,37 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): return convolve(time_array, data_array, starts, ends, array, trim) else: - if data_array.ndim == 1: - new_data_array = np.zeros(data_array.shape) - k = array.shape[0] - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - - t = idx_e - 
idx_s - if trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) - else: - cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) - # scipy is actually faster for Tsd - new_data_array[idx_s:idx_e] = signal.convolve( - data_array[idx_s:idx_e], array - )[cut[0] : cut[1]] - - return new_data_array - else: - new_data_array = np.zeros(data_array.shape) - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - new_data_array[idx_s:idx_e] = pjitconvolve( - data_array[idx_s:idx_e], array, trim=trim - ) - - return new_data_array + # reshape to 2d + shape = data_array.shape + data_array = np.reshape(data_array, (shape[0], -1)) + + kshape = array.shape + k = kshape[0] + array = array.reshape(k, -1) + + new_data_array = np.zeros((shape[0], int(np.prod(shape[1:])), *array.shape[1:])) + + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + + t = idx_e - idx_s + if trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) + else: + cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) + + for i in range(data_array.shape[1]): + for j in range(array.shape[1]): + new_data_array[idx_s:idx_e, i, j] = signal.convolve( + data_array[idx_s:idx_e, i], array[:, j] + )[cut[0] : cut[1]] + + new_data_array = new_data_array.reshape((*shape, *kshape[1:])) + + return new_data_array def _bin_average(time_array, data_array, starts, ends, bin_size): diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index 669343c6..9269a245 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -1,5 +1,5 @@ import numpy as np -from numba import jit, njit, prange +from numba import jit # , njit, prange ################################ @@ -44,11 +44,11 @@ def jitrestrict(time_array, starts, ends): @jit(nopython=True) -def jitrestrict_with_count(time_array, starts, ends): +def jitrestrict_with_count(time_array, starts, ends, dtype=np.int64): n = len(time_array) m = len(starts) ix = np.zeros(n, dtype=np.int64) - count = np.zeros(m, dtype=np.int64) + count = np.zeros(m, dtype=dtype) k = 0 t = 0 @@ -118,7 +118,7 @@ def jitvaluefrom(time_array, time_target_array, count, count_target, starts, end @jit(nopython=True) -def jitcount(time_array, starts, ends, bin_size): +def jitcount(time_array, starts, ends, bin_size, dtype): idx, countin = jitrestrict_with_count(time_array, starts, ends) time_array = time_array[idx] @@ -133,7 +133,7 @@ def jitcount(time_array, starts, ends, bin_size): nb = np.sum(nb_bins) bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros(nb, dtype=np.int64) + cnt = np.zeros(nb, dtype=dtype) k = 0 t = 0 @@ -322,7 +322,6 @@ def jitbin_array(time_array, data_array, starts, ends, bin_size): @jit(nopython=True) def _jitbin_array(countin, time_array, data_array, starts, ends, bin_size): - m = starts.shape[0] f = data_array.shape[1:] @@ -375,33 +374,33 @@ def _jitbin_array(countin, time_array, data_array, starts, ends, bin_size): return (new_time_array, new_data_array) -@jit(nopython=True) -def jitconvolve(d, a): - return np.convolve(d, a) +# @jit(nopython=True) +# def jitconvolve(d, a): +# return np.convolve(d, a) -@njit(parallel=True) -def pjitconvolve(data_array, array, trim="both"): - shape = data_array.shape - t = shape[0] - k = array.shape[0] +# @njit(parallel=True) +# def pjitconvolve(data_array, array, trim="both"): +# shape = data_array.shape 
+# t = shape[0] +# k = array.shape[0] - data_array = data_array.reshape(t, -1) - new_data_array = np.zeros(data_array.shape) +# data_array = data_array.reshape(t, -1) +# new_data_array = np.zeros(data_array.shape) - if trim == "both": - cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) - elif trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) +# if trim == "both": +# cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) +# elif trim == "left": +# cut = (k - 1, t + k - 1) +# elif trim == "right": +# cut = (0, t) - for i in prange(data_array.shape[1]): - new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] +# for i in prange(data_array.shape[1]): +# new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] - new_data_array = new_data_array.reshape(shape) +# new_data_array = new_data_array.reshape(shape) - return new_data_array +# return new_data_array ################################ diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index da8c91ce..8436c222 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -11,7 +11,7 @@ from ._core_functions import _count, _restrict, _value_from from .interval_set import IntervalSet from .time_index import TsIndex -from .utils import convert_to_numpy_array +from .utils import check_filename, convert_to_numpy_array class Base(abc.ABC): @@ -23,7 +23,6 @@ class Base(abc.ABC): _initialized = False def __init__(self, t, time_units="s", time_support=None): - if isinstance(t, TsIndex): self.index = t else: @@ -45,7 +44,7 @@ def __init__(self, t, time_units="s", time_support=None): self.time_support.values[:, 1] - self.time_support.values[:, 0] ) else: - self.rate = np.NaN + self.rate = np.nan self.time_support = IntervalSet(start=[], end=[]) @property @@ -204,7 +203,7 @@ def value_from(self, data, ep=None): return t, d, time_support, kwargs - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -232,6 +231,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. 
        Returns
        -------

@@ -291,6 +292,14 @@
         if isinstance(a, IntervalSet):
             ep = a
 
+        if dtype is None:
+            dtype = np.dtype(np.int64)
+        else:
+            try:
+                dtype = np.dtype(dtype)
+            except Exception:
+                raise ValueError(f"{dtype} is not a valid numpy dtype.")
+
         starts = ep.start
         ends = ep.end
 
@@ -299,7 +308,7 @@
 
         time_array = self.index.values
 
-        t, d = _count(time_array, starts, ends, bin_size)
+        t, d = _count(time_array, starts, ends, bin_size, dtype=dtype)
 
         return t, d, ep
 
@@ -335,6 +344,7 @@
         0    0.0  500.0
 
         """
+        assert isinstance(iset, IntervalSet), "Argument should be IntervalSet"
 
         time_array = self.index.values
 
@@ -403,47 +413,254 @@
        end : float or int or None
            The end
        """
-        assert isinstance(start, Number), "start should be a float or int"
-        time_array = self.index.values
+        sl = self.get_slice(start, end, time_units)
 
         if end is None:
-            start = TsIndex.format_timestamps(np.array([start]), time_units)[0]
-            idx = int(np.searchsorted(time_array, start))
-            if idx == 0:
-                return self[idx]
-            elif idx >= self.shape[0]:
-                return self[-1]
-            else:
-                if start - time_array[idx - 1] < time_array[idx] - start:
-                    return self[idx - 1]
-                else:
-                    return self[idx]
+            sl = sl.start
+
+        return self[sl]
+
+    def get_slice(self, start, end=None, time_unit="s"):
+        """
+        Get a slice object from the time series data based on the start and end values such that all the timestamps satisfy `start<=t<=end`.
+        If `end` is None, only the timepoint closest to `start` is returned.
+
+        By default, the time support doesn't change. If you want to change the time support, use the `restrict` function.
+
+        This function is equivalent to calling the `get` method.
+
+        Parameters
+        ----------
+        start : int or float
+            The starting value for the slice.
+        end : int or float, optional
+            The ending value for the slice. Defaults to None.
+        time_unit : str, optional
+            The time unit for the start and end values. Defaults to "s" (seconds).
+
+        Returns
+        -------
+        slice : slice
+            A slice determining the start and end indices, with unit step.
+            Slicing the array will be equivalent to calling get: `ts[s].t == ts.get(start, end).t` with `s` being the slice object.
+
+
+        Raises
+        ------
+        ValueError
+            - If start or end is not a number.
+            - If start is greater than end.
+
+        Examples
+        --------
+        >>> import pynapple as nap
+
+        >>> ts = nap.Ts(t = [0, 1, 2, 3])
+
+        >>> # slice over a range
+        >>> start, end = 1.2, 2.6
+        >>> print(ts.get_slice(start, end))  # returns `slice(2, 3, None)`
+        >>> start, end = 1., 2.
+        >>> print(ts.get_slice(start, end))  # returns `slice(1, 3, None)`
+
+        >>> # slice a single value
+        >>> start = 1.2
+        >>> print(ts.get_slice(start))  # returns `slice(1, 2, None)`
+        >>> start = 2.
+        >>> print(ts.get_slice(start))  # returns `slice(2, 3, None)`
+        """
+        mode = "closest_t" if end is None else "restrict"
+        return self._get_slice(
+            start, end=end, mode=mode, n_points=None, time_unit=time_unit
+        )
+
+    def _get_slice(
+        self, start, end=None, mode="closest_t", n_points=None, time_unit="s"
+    ):
+        """
+        Get a slice from the time series data based on the start and end values with the specified mode.
+
+        For a given time t, mode `before_t` means you want the timepoint right before t to start the slice.
+        Mode `after_t` means you want the timepoint right after t to start the slice.
+
+        Parameters
+        ----------
+        start : int or float
+            The starting value for the slice.
+ end : int or float, optional + The ending value for the slice. Defaults to None. + mode : str, optional + The mode for slicing. Can be "after_t", "before_t", "restrict", or "closest_t". Defaults to "closest_t". + time_unit : str, optional + The time unit for the start and end values. Defaults to "s" (seconds). + n_points : int, optional + Number of time point that will result from applying the slice. This parameter is used to + calculate a step size for the slice. + + Returns + ------- + slice : slice + If end is not provided: + - For mode == "before_t": + - An empty slice for start < self.t[0] + - slice(idx, idx+1) with self.t[idx] <= start < self.t[idx+1] + - For mode == "after_t": + - An empty slice for start >= self.t[-1] + - slice(idx, idx+1) with self.t[idx-1] < start <= self.t[idx] + - For mode == "closest_t": + - slice(idx, idx+1) with the closest index to start + - For mode == "restrict": + - slice the indices such that start <= self.t[idx] <= end + If end is provided: + - For mode == "before_t": + - An empty slice if end < self.t[0] + - slice(idx_start, idx_end) with self.t[idx_start] <= start < self.t[idx_start+1] and + self.t[idx_end] <= end < self.t[idx_end+1] + - For mode == "after_t": + - An empty slice if start > self.t[-1] + - slice(idx_start, idx_end) with self.t[idx_start-1] <= start < self.t[idx_start] and + self.t[idx_end-1] <= end < self.t[idx_end] + - For mode == "closest": + - slice(idx_start, idx_end) with the closest indices to start and end + - For mode == "restrict": + - An empty slice if start > self.t[-1] or end < self.t[0] + - slice(idx_start, idx_end) with self.t[idx_start] <= start <= self.t[idx_start+1] and + self.t[idx_end] <= end <= self.t[idx_end+1] + + Raises + ------ + ValueError + - If start or end is not a number. + - If start is greater than end. + + """ + if not isinstance(start, Number): + raise ValueError( + f"'start' must be an int or a float. Type {type(start)} provided instead!" + ) + + if n_points is not None and not isinstance(n_points, int): + raise TypeError( + f"'n_points' must be of type int or None. Type {type(n_points)} provided instead!" + ) + + if end is None and n_points: + raise ValueError("'n_points' can be used only when 'end' is specified!") + + if mode not in ["before_t", "after_t", "closest_t", "restrict"]: + raise ValueError( + "'mode' only accepts 'before_t', 'after_t', 'closest_t' or 'restrict'." + ) + + if mode == "restrict" and n_points: + raise ValueError( + "Fixing the number of time points is incompatible with 'restrict' mode." + ) + + # convert and get index for start + start = TsIndex.format_timestamps(np.array([start]), time_unit)[0] + + # check end + if end is not None and not isinstance(end, Number): + raise ValueError( + f"'end' must be an int or a float. Type {type(end)} provided instead!" 
+ ) + + # get index of preceding time value + idx_start = np.searchsorted(self.t, start, side="left") + if idx_start == len(self.t) and mode != "restrict": + idx_start -= 1 # make sure the index is not out of bound + + if mode == "before_t": + # in order to get the index preceding start + # subtract one except if self.t[idx_start] is exactly equal to start + idx_start -= self.t[idx_start] > start + elif mode == "closest_t": + # subtract 1 if start is closer to the previous index + di = self.t[idx_start] - start > np.abs(self.t[idx_start - 1] - start) + idx_start -= di + + if end is None: + if idx_start < 0: # happens only on backwards if start < self.t[0] + return slice(0, 0) + elif ( + idx_start == len(self.t) - 1 and mode == "after_t" + ): # happens only on forward if start >= self.t[-1] + return slice(idx_start, idx_start) + return slice(idx_start, idx_start + 1) else: - assert isinstance(end, Number), "end should be a float or int" - assert start < end, "Start should not precede end" - start, end = TsIndex.format_timestamps(np.array([start, end]), time_units) - idx_start = np.searchsorted(time_array, start) - idx_end = np.searchsorted(time_array, end, side="right") - return self[idx_start:idx_end] - - # def find_gaps(self, min_gap, time_units='s'): - # """ - # finds gaps in a tsd larger than min_gap. Return an IntervalSet. - # Epochs are defined by adding and removing 1 microsecond to the time index. - - # Parameters - # ---------- - # min_gap : float - # The minimum interval size considered to be a gap (default is second). - # time_units : str, optional - # Time units of min_gap ('us', 'ms', 's' [default]) - # """ - # min_gap = format_timestamps(np.array([min_gap]), time_units)[0] - - # time_array = self.index - # starts = self.time_support.start - # ends = self.time_support.end - - # s, e = jitfind_gaps(time_array, starts, ends, min_gap) - - # return nap.IntervalSet(s, e) + idx_start = max([0, idx_start]) # if taking a range set slice index to 0 + + # convert and get index for end + end = TsIndex.format_timestamps(np.array([end]), time_unit)[0] + if start > end: + raise ValueError("'start' should not precede 'end'.") + + idx_end = np.searchsorted(self.t, end, side="left") + add_if_forward = 0 + if idx_end == len(self.t): + idx_end -= 1 # make sure the index is not out of bound + add_if_forward = 1 # add back the index if forward + + if mode == "before_t": + # remove 1 if self.t[idx_end] is larger than end, except if idx_end is 0 + idx_end -= (self.t[idx_end] > end) - int(idx_end == 0) + elif mode == "closest_t": + # subtract 1 if end is closer to self.t[idx_end - 1] + di = self.t[idx_end] - end > np.abs(self.t[idx_end - 1] - end) + idx_end -= di + elif mode == "after_t" and idx_end == len(self.t) - 1: + idx_end += add_if_forward # add one if idx_start < len(self.t) + elif mode == "restrict": + idx_end += int(self.t[idx_end] <= end) + + step = None + if n_points: + tot_tps = idx_end - idx_start + if tot_tps > n_points: + rounding = tot_tps % n_points + step = tot_tps // n_points + idx_end -= rounding + + return slice(idx_start, idx_end, step) + + def _get_filename(self, filename): + """Check if the filename is valid and return the path + + Parameters + ---------- + filename : str or Path + The filename + + Returns + ------- + Path + The path to the file + + Raises + ------ + RuntimeError + If the filename is a directory or the parent does not exist + """ + + return check_filename(filename) + + @classmethod + def _from_npz_reader(cls, file): + """Load a time series object from a npz 
file interface. + + Parameters + ---------- + file : NPZFile object + opened npz file interface. + + Returns + ------- + out : Ts or Tsd or TsdFrame or TsdTensor + The time series object + """ + kwargs = { + key: file[key] for key in file.keys() if key not in ["start", "end", "type"] + } + iset = IntervalSet(start=file["start"], end=file["end"]) + return cls(time_support=iset, **kwargs) diff --git a/pynapple/core/config.py b/pynapple/core/config.py index 97eaafaa..fbc59ebb 100644 --- a/pynapple/core/config.py +++ b/pynapple/core/config.py @@ -98,7 +98,6 @@ def backend(self, backend): self.set_backend(backend) def set_backend(self, backend): - assert backend in ["numba", "jax"], "Options for backend are 'jax' or 'numba'" # Try to import pynajax diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 3f65f802..ef74b18b 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -40,7 +40,6 @@ """ import importlib -import os import warnings from numbers import Number @@ -61,6 +60,7 @@ from .utils import ( _get_terminal_size, _IntervalSetSliceHelper, + check_filename, convert_to_numpy_array, is_array_like, ) @@ -80,7 +80,7 @@ class IntervalSet(NDArrayOperatorsMixin): A class representing a (irregular) set of time intervals in elapsed time, with relative operations """ - def __init__(self, start, end=None, time_units="s", **kwargs): + def __init__(self, start, end=None, time_units="s"): """ IntervalSet initializer @@ -94,8 +94,13 @@ def __init__(self, start, end=None, time_units="s", **kwargs): Parameters ---------- - start : numpy.ndarray or number or pandas.DataFrame or pandas.Series - Beginning of intervals + start : numpy.ndarray or number or pandas.DataFrame or pandas.Series or iterable of (start, end) pairs + Beginning of intervals. Alternatively, the `end` argument can be left out and `start` can be one of the + following: + - IntervalSet + - pandas.DataFrame with columns ["start", "end"] + - iterable of (start, end) pairs + - a single (start, end) pair end : numpy.ndarray or number or pandas.Series, optional Ends of intervals time_units : str, optional @@ -108,8 +113,8 @@ def __init__(self, start, end=None, time_units="s", **kwargs): """ if isinstance(start, IntervalSet): - end = start.values[:, 1].astype(np.float64) - start = start.values[:, 0].astype(np.float64) + end = start.end.astype(np.float64) + start = start.start.astype(np.float64) elif isinstance(start, pd.DataFrame): assert ( @@ -125,7 +130,15 @@ def __init__(self, start, end=None, time_units="s", **kwargs): start = start["start"].values.astype(np.float64) else: - assert end is not None, "Missing end argument when initializing IntervalSet" + if end is None: + # Require iterable of (start, end) tuples + try: + start_end_array = np.array(list(start)).reshape(-1, 2) + start, end = zip(*start_end_array) + except (TypeError, ValueError): + raise ValueError( + "Unable to Interpret the input. Please provide a list of start-end pairs." + ) args = {"start": start, "end": end} @@ -228,6 +241,9 @@ def __str__(self): def __len__(self): return len(self.values) + # def __iter__(self): + # pass + def __setitem__(self, key, value): raise RuntimeError( "IntervalSet is immutable. Starts and ends have been already sorted." @@ -332,8 +348,13 @@ def starts(self): Ts The starts of the IntervalSet """ + warnings.warn( + "starts is a deprecated function. 
It will be removed in future versions",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
         time_series = importlib.import_module(".time_series", "pynapple.core")
-        return time_series.Ts(t=self.values[:, 0], time_support=self)
+        return time_series.Ts(t=self.values[:, 0])
 
     @property
     def ends(self):
@@ -344,8 +365,13 @@
         Ts
             The ends of the IntervalSet
         """
+        warnings.warn(
+            "ends is a deprecated function. It will be removed in future versions",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
         time_series = importlib.import_module(".time_series", "pynapple.core")
-        return time_series.Ts(t=self.values[:, 1], time_support=self)
+        return time_series.Ts(t=self.values[:, 1])
 
     @property
     def loc(self):
@@ -354,6 +380,25 @@
         """
         return _IntervalSetSliceHelper(self)
 
+    @classmethod
+    def _from_npz_reader(cls, file):
+        """Load an IntervalSet object from a npz file.
+
+        The file should contain the keys 'start', 'end' and 'type'.
+        The 'type' key should be 'IntervalSet'.
+
+        Parameters
+        ----------
+        file : NPZFile object
+            opened npz file interface.
+
+        Returns
+        -------
+        IntervalSet
+            The IntervalSet object
+        """
+        return cls(start=file["start"], end=file["end"])
+
     def time_span(self):
         """
         Time span of the interval set.
@@ -643,29 +688,87 @@
         RuntimeError
             If filename is not str, path does not exist or filename is a directory.
         """
-        if not isinstance(filename, str):
-            raise RuntimeError("Invalid type; please provide filename as string")
-
-        if os.path.isdir(filename):
-            raise RuntimeError(
-                "Invalid filename input. {} is directory.".format(filename)
-            )
-
-        if not filename.lower().endswith(".npz"):
-            filename = filename + ".npz"
-
-        dirname = os.path.dirname(filename)
-
-        if len(dirname) and not os.path.exists(dirname):
-            raise RuntimeError(
-                "Path {} does not exist.".format(os.path.dirname(filename))
-            )
-
         np.savez(
-            filename,
+            check_filename(filename),
             start=self.values[:, 0],
             end=self.values[:, 1],
             type=np.array(["IntervalSet"], dtype=np.str_),
         )
 
         return
+
+    def split(self, interval_size, time_units="s"):
+        """Split `IntervalSet` into a new `IntervalSet` with each interval being of size `interval_size`.
+
+        Used mostly for chunking very large datasets or looping through multiple epochs of the same duration.
+
+        This function skips the epochs that are shorter than `interval_size`.
+
+        Note that intervals are strictly non-overlapping in pynapple. One microsecond is removed from contiguous intervals.
+ + Parameters + ---------- + interval_size : Number + Description + time_units : str, optional + time units for the `interval_size` ('us', 'ms', 's' [default]) + + Returns + ------- + IntervalSet + New `IntervalSet` with equal sized intervals + + Raises + ------ + IOError + If `interval_size` is not a Number or is below 0 + If `time_units` is not a string + """ + if not isinstance(interval_size, Number): + raise IOError("Argument interval_size should of type float or int") + + if not interval_size > 0: + raise IOError("Argument interval_size should be strictly larger than 0") + + if not isinstance(time_units, str): + raise IOError("Argument time_units should be of type str") + + if len(self) == 0: + return IntervalSet(start=[], end=[]) + + interval_size = TsIndex.format_timestamps( + np.array((interval_size,), dtype=np.float64).ravel(), time_units + )[0] + + interval_size = np.round(interval_size, nap_config.time_index_precision) + + durations = np.round(self.end - self.start, nap_config.time_index_precision) + + idxs = np.where(durations > interval_size)[0] + size_tmp = ( + np.ceil((self.end[idxs] - self.start[idxs]) / interval_size) + ).astype(int) + 1 + new_starts = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan) + new_ends = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan) + i0 = 0 + for cnt, idx in enumerate(idxs): + new_starts[i0 : i0 + size_tmp[cnt] - 1] = np.arange( + self.start[idx], self.end[idx], interval_size + ) + new_ends[i0 : i0 + size_tmp[cnt] - 2] = new_starts[ + i0 + 1 : i0 + size_tmp[cnt] - 1 + ] + new_ends[i0 + size_tmp[cnt] - 2] = self.end[idx] + i0 += size_tmp[cnt] - 1 + new_starts = np.round(new_starts, nap_config.time_index_precision) + new_ends = np.round(new_ends, nap_config.time_index_precision) + + durations = np.round(new_ends - new_starts, nap_config.time_index_precision) + tokeep = durations >= interval_size + new_starts = new_starts[tokeep] + new_ends = new_ends[tokeep] + + # Removing 1 microsecond to have strictly non-overlapping intervals for intervals coming from the same epoch + new_ends -= 1e-6 + + return IntervalSet(new_starts, new_ends) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 95ae2c37..2af7f269 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -17,7 +17,6 @@ import abc import importlib -import os import warnings from numbers import Number @@ -307,7 +306,7 @@ def value_from(self, data, ep=None): t, d, time_support, kwargs = super().value_from(data, ep) return data.__class__(t=t, d=d, time_support=time_support, **kwargs) - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -335,6 +334,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. Returns ------- @@ -362,7 +363,7 @@ def count(self, *args, **kwargs): start end 0 100.0 800.0 """ - t, d, ep = super().count(*args, **kwargs) + t, d, ep = super().count(*args, dtype=dtype, **kwargs) return Tsd(t=t, d=d, time_support=ep) def bin_average(self, bin_size, ep=None, time_units="s"): @@ -475,8 +476,8 @@ def convolve(self, array, ep=None, trim="both"): Parameters ---------- array : array-like - One dimensional input array-like. - + 1-D or 2-D array with kernel(s) to be used for convolution. 
+ First dimension is assumed to be time. ep : None, optional The epochs to apply the convolution trim : str, optional @@ -487,15 +488,19 @@ def convolve(self, array, ep=None, trim="both"): Tsd, TsdFrame or TsdTensor The convolved time series """ - assert is_array_like( - array - ), "Input should be a numpy array (or jax array if pynajax is installed)." - assert array.ndim == 1, "Input should be a one dimensional array." - assert trim in [ - "both", - "left", - "right", - ], "Unknow argument. trim should be 'both', 'left' or 'right'." + if not is_array_like(array): + raise IOError( + "Input should be a numpy array (or jax array if pynajax is installed)." + ) + + if len(array) == 0: + raise IOError("Input array is length 0") + + if array.ndim > 2: + raise IOError("Array should be 1 or 2 dimension.") + + if trim not in ["both", "left", "right"]: + raise IOError("Unknow argument. trim should be 'both', 'left' or 'right'.") time_array = self.index.values data_array = self.values @@ -505,7 +510,8 @@ def convolve(self, array, ep=None, trim="both"): starts = ep.start ends = ep.end else: - assert isinstance(ep, IntervalSet) + if not isinstance(ep, IntervalSet): + raise IOError("ep should be an object of type IntervalSet") starts = ep.start ends = ep.end idx = _restrict(time_array, starts, ends) @@ -514,7 +520,14 @@ def convolve(self, array, ep=None, trim="both"): new_data_array = _convolve(time_array, data_array, starts, ends, array, trim) - return self.__class__(t=time_array, d=new_data_array, time_support=ep) + kwargs_dict = dict(time_support=ep) + + nap_class = _get_class(new_data_array) + + if isinstance(self, TsdFrame) and array.ndim == 1: # keep columns + kwargs_dict["columns"] = self.columns + + return nap_class(t=time_array, d=new_data_array, **kwargs_dict) def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=True): """Smooth a time series with a gaussian kernel. @@ -569,18 +582,21 @@ def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=Tru Time series convolved with a gaussian kernel """ - assert isinstance(std, (int, float)), "std should be type int or float" - assert isinstance(size_factor, int), "size_factor should be of type int" - assert isinstance(norm, bool), "norm should be of type boolean" - assert isinstance(time_units, str), "time_units should be of type str" + if not isinstance(std, (int, float)): + raise IOError("std should be type int or float") + if not isinstance(size_factor, int): + raise IOError("size_factor should be of type int") + if not isinstance(norm, bool): + raise IOError("norm should be of type boolean") + if not isinstance(time_units, str): + raise IOError("time_units should be of type str") std = TsIndex.format_timestamps(np.array([std]), time_units)[0] std_size = int(self.rate * std) if windowsize is not None: - assert isinstance( - windowsize, (int, float) - ), "windowsize should be type int or float" + if not isinstance(windowsize, Number): + raise IOError("windowsize should be type int or float") windowsize = TsIndex.format_timestamps(np.array([windowsize]), time_units)[ 0 ] @@ -588,6 +604,9 @@ def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=Tru else: M = std_size * size_factor + if M % 2 == 0: + M += 1 + window = signal.windows.gaussian(M=M, std=std_size) if norm: @@ -611,12 +630,22 @@ def interpolate(self, ts, ep=None, left=None, right=None): right : None, optional Value to return for ts > tsd[-1], default is tsd[-1]. 
""" - assert isinstance( - ts, Base - ), "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" + if not isinstance(ts, Base): + raise IOError( + "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" + ) - if not isinstance(ep, IntervalSet): + if left is not None and not isinstance(left, Number): + raise IOError("Argument left should be of type float or int") + + if right is not None and not isinstance(right, Number): + raise IOError("Argument right should be of type float or int") + + if ep is None: ep = self.time_support + else: + if not isinstance(ep, IntervalSet): + raise IOError("ep should be an object of type IntervalSet") new_t = ts.restrict(ep).index @@ -805,23 +834,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = self._get_filename(filename) np.savez( filename, @@ -997,10 +1010,9 @@ def __getitem__(self, key, *args, **kwargs): if all(is_array_like(a) for a in [index, output]): if output.shape[0] == index.shape[0]: - - if isinstance(columns, pd.Index): - if not pd.api.types.is_integer_dtype(columns): - kwargs["columns"] = columns + # if isinstance(columns, pd.Index): + # if not pd.api.types.is_integer_dtype(columns): + kwargs["columns"] = columns return _get_class(output)( t=index, d=output, time_support=self.time_support, **kwargs @@ -1086,23 +1098,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = self._get_filename(filename) cols_name = self.columns if cols_name.dtype == np.dtype("O"): @@ -1413,24 +1409,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. 
{} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) - + filename = self._get_filename(filename) np.savez( filename, t=self.index.values, @@ -1591,7 +1570,7 @@ def value_from(self, data, ep=None): return data.__class__(t, d, time_support=time_support, **kwargs) - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -1619,6 +1598,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. Returns ------- @@ -1643,10 +1624,10 @@ def count(self, *args, **kwargs): And bincount automatically inherit ep as time support: >>> bincount.time_support - >>> start end - >>> 0 100.0 800.0 + start end + 0 100.0 800.0 """ - t, d, ep = super().count(*args, **kwargs) + t, d, ep = super().count(*args, dtype=dtype, **kwargs) return Tsd(t=t, d=d, time_support=ep) def fillna(self, value): @@ -1706,23 +1687,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = self._get_filename(filename) np.savez( filename, diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 6c052733..2052f696 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -4,7 +4,6 @@ """ -import os import warnings from collections import UserDict from collections.abc import Hashable @@ -21,7 +20,7 @@ from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like -from .utils import _get_terminal_size, convert_to_numpy_array +from .utils import _get_terminal_size, check_filename, convert_to_numpy_array def _union_intervals(i_sets): @@ -76,9 +75,9 @@ def __init__( Parameters ---------- - data : dict - Dictionary containing Ts/Tsd objects, keys should contain integer values or should be convertible - to integer. + data : dict or iterable + Dictionary or iterable of Ts/Tsd objects. The keys should be integer-convertible; if a non-dict iterator is + passed, its values will be used to create a dict with integer keys. time_support : IntervalSet, optional The time support of the TsGroup. Ts/Tsd objects will be restricted to the time support if passed. If no time support is specified, TsGroup will merge time supports from all the Ts/Tsd objects in data. @@ -101,15 +100,33 @@ def __init__( - If the converted keys are not unique, i.e. {1: ts_2, "2": ts_2} is valid, {1: ts_2, "1": ts_2} is invalid. 
""" + # Check input type + if time_units not in ["s", "ms", "us"]: + raise ValueError("Argument time_units should be 's', 'ms' or 'us'") + if not isinstance(bypass_check, bool): + raise TypeError("Argument bypass_check should be of type bool") + passed_time_support = False + + if isinstance(time_support, IntervalSet): + passed_time_support = True + else: + if time_support is not None: + raise TypeError("Argument time_support should be of type IntervalSet") + else: + passed_time_support = False + self._initialized = False + if not isinstance(data, dict): + data = dict(enumerate(data)) + # convert all keys to integer try: keys = [int(k) for k in data.keys()] except Exception: raise ValueError("All keys must be convertible to integer.") - # check that there were no floats with decimal points in keys.i + # check that there were no floats with decimal points in keys. # i.e. 0.5 is not a valid key if not all(np.allclose(keys[j], float(k)) for j, k in enumerate(data.keys())): raise ValueError("All keys must have integer value!}") @@ -121,6 +138,8 @@ def __init__( data = {keys[j]: data[k] for j, k in enumerate(data.keys())} self.index = np.sort(keys) + # Make sure data dict and index are ordered the same + data = {k: data[k] for k in self.index} self._metadata = pd.DataFrame(index=self.index, columns=["rate"], dtype="float") @@ -141,7 +160,7 @@ def __init__( ) # If time_support is passed, all elements of data are restricted prior to init - if isinstance(time_support, IntervalSet): + if passed_time_support: self.time_support = time_support if not bypass_check: data = {k: data[k].restrict(self.time_support) for k in self.index} @@ -187,6 +206,10 @@ def __getattr__(self, name): AttributeError If the requested attribute is not a metadata column. """ + # avoid infinite recursion when pickling due to + # self._metadata.column having attributes '__reduce__', '__reduce_ex__' + if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): + raise AttributeError(name) # Check if the requested attribute is part of the metadata if name in self._metadata.columns: return self._metadata[name] @@ -588,7 +611,7 @@ def value_from(self, tsd, ep=None): cols = self._metadata.columns.drop("rate") return TsGroup(newgr, time_support=ep, **self._metadata[cols]) - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -616,6 +639,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. 
Returns ------- @@ -684,6 +709,12 @@ def count(self, *args, **kwargs): if isinstance(a, IntervalSet): ep = a + if dtype: + try: + dtype = np.dtype(dtype) + except Exception: + raise ValueError(f"{dtype} is not a valid numpy dtype.") + starts = ep.start ends = ep.end @@ -694,20 +725,28 @@ def count(self, *args, **kwargs): # Call it on first element to pre-allocate the array if len(self) >= 1: time_index, d = _count( - self.data[self.index[0]].index.values, starts, ends, bin_size + self.data[self.index[0]].index.values, + starts, + ends, + bin_size, + dtype=dtype, ) - count = np.zeros((len(time_index), len(self.index)), dtype=np.int64) + count = np.zeros((len(time_index), len(self.index)), dtype=dtype) count[:, 0] = d for i in range(1, len(self.index)): count[:, i] = _count( - self.data[self.index[i]].index.values, starts, ends, bin_size + self.data[self.index[i]].index.values, + starts, + ends, + bin_size, + dtype=dtype, )[1] return TsdFrame(t=time_index, d=count, time_support=ep, columns=self.index) else: - time_index, _ = _count(np.array([]), starts, ends, bin_size) + time_index, _ = _count(np.array([]), starts, ends, bin_size, dtype=dtype) return TsdFrame( t=time_index, d=np.empty((len(time_index), 0)), time_support=ep ) @@ -1306,23 +1345,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = check_filename(filename) dicttosave = {"type": np.array(["TsGroup"], dtype=np.str_)} for k in self._metadata.columns: @@ -1364,3 +1387,59 @@ def save(self, filename): np.savez(filename, **dicttosave) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load a Tsd object from a npz file. 
+ + Parameters + ---------- + file : str + The opened npz file + + Returns + ------- + Tsd + The Tsd object + """ + + times = file["t"] + index = file["index"] + has_data = "d" in file.keys() + time_support = IntervalSet(file["start"], file["end"]) + + if has_data: + data = file["data"] + + if "keys" in file.keys(): + keys = file["keys"] + else: + keys = np.unique(index) + + group = {} + for key in keys: + filtering_index = index == key + t = times[filtering_index] + + if has_data: + group[key] = Tsd( + t=t, + d=data[filtering_index], + time_support=time_support, + ) + else: + group[key] = Ts(t=t, time_support=time_support) + + tsgroup = cls(group, time_support=time_support, bypass_check=True) + + metainfo = {} + not_info_keys = {"start", "end", "t", "index", "d", "rate", "keys"} + + for k in set(file.keys()) - not_info_keys: + tmp = file[k] + if len(tmp) == len(tsgroup): + metainfo[k] = tmp + + tsgroup.set_info(**metainfo) + return tsgroup diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 17d90546..711aa0c8 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -6,6 +6,7 @@ import warnings from itertools import combinations from numbers import Number +from pathlib import Path import numpy as np @@ -122,12 +123,21 @@ def is_array_like(obj): has_ndim = hasattr(obj, "ndim") # Check for indexability (try to access the first element) + try: obj[0] is_indexable = True except Exception: is_indexable = False + if not is_indexable: + if hasattr(obj, "__len__"): + try: + if len(obj) == 0: + is_indexable = True # Could be an empty array + except Exception: + is_indexable = False + # Check for iterable property try: iter(obj) @@ -394,3 +404,35 @@ def __getitem__(self, key): raise IndexError else: raise IndexError + + +def check_filename(filename): + """Check if the filename is valid and return the path + + Parameters + ---------- + filename : str or Path + The filename + + Returns + ------- + Path + The path to the file + + Raises + ------ + RuntimeError + If the filename is a directory or the parent does not exist + """ + filename = Path(filename).resolve() + + if filename.is_dir(): + raise RuntimeError("Invalid filename input. {} is directory.".format(filename)) + + filename = filename.with_suffix(".npz") + + parent_folder = filename.parent + if not parent_folder.exists(): + raise RuntimeError("Path {} does not exist.".format(parent_folder)) + + return filename diff --git a/pynapple/io/cnmfe.py b/pynapple/io/cnmfe.py index c3266362..d6ad293e 100644 --- a/pynapple/io/cnmfe.py +++ b/pynapple/io/cnmfe.py @@ -11,7 +11,7 @@ # @Last Modified by: gviejo # @Last Modified time: 2023-11-16 13:14:54 -import os +from pathlib import Path from pynwb import NWBHDF5IO @@ -43,13 +43,14 @@ def __init__(self, path): path : str The path to the data. 
""" - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) - self.load_cnmfe_nwb(path) + self.load_cnmfe_nwb() - def load_cnmfe_nwb(self, path): + def load_cnmfe_nwb(self): """ Load the calcium transient and spatial footprint from nwb @@ -58,13 +59,6 @@ def load_cnmfe_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() @@ -110,13 +104,14 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) - self.load_cnmfe_nwb(path) + self.load_cnmfe_nwb() - def load_cnmfe_nwb(self, path): + def load_cnmfe_nwb(self): """ Load the calcium transient and spatial footprint from nwb @@ -125,13 +120,6 @@ def load_cnmfe_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() @@ -178,13 +166,14 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) - self.load_cnmfe_nwb(path) + self.load_cnmfe_nwb() - def load_cnmfe_nwb(self, path): + def load_cnmfe_nwb(self): """ Load the calcium transient and spatial footprint from nwb @@ -193,13 +182,6 @@ def load_cnmfe_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() diff --git a/pynapple/io/folder.py b/pynapple/io/folder.py index 60042bd9..a35af18d 100644 --- a/pynapple/io/folder.py +++ b/pynapple/io/folder.py @@ -1,21 +1,12 @@ -#!/usr/bin/env python - -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-05-15 15:32:24 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-08-06 17:37:23 - """ The Folder class helps to navigate a hierarchical data tree. """ - import json -import os import string from collections import UserDict from datetime import datetime +from pathlib import Path from rich.console import Console # , ConsoleOptions, RenderResult from rich.panel import Panel @@ -30,27 +21,29 @@ def _find_files(path, extension=".npz"): Parameters ---------- - path : TYPE - Description + path : str or Path + The directory path where files will be searched. extension : str, optional - Description + The file extension to look for, default is ".npz". Returns ------- - TYPE - Description + dict + Dictionary with filenames (without extension and whitespace) as keys + and NPZFile or NWBFile objects as values. """ + extension = extension if extension.startswith(".") else "." 
+ extension + path = Path(path) # Ensure path is a Path object files = {} - for f in os.scandir(path): - if f.is_file() and f.name.endswith(extension): - if extension == "npz": - filename = os.path.splitext(os.path.basename(f.path))[0] - filename.translate({ord(c): None for c in string.whitespace}) - files[filename] = NPZFile(f.path) - elif extension == "nwb": - filename = os.path.splitext(os.path.basename(f.path))[0] - filename.translate({ord(c): None for c in string.whitespace}) - files[filename] = NWBFile(f.path) + extensions_dict = {".npz": NPZFile, ".nwb": NWBFile} + assert extension in extensions_dict.keys(), f"Extension {extension} not supported" + + for f in path.iterdir(): + if f.is_file() and f.suffix == extension: + filename = f.stem + filename = filename.translate({ord(c): None for c in string.whitespace}) + files[filename] = extensions_dict[extension](f) + return files @@ -108,9 +101,9 @@ def __init__(self, path): # , exclude=(), max_depth=4): path : str Path to the folder """ - path = path.rstrip("/") + path = Path(path) self.path = path - self.name = os.path.basename(path) + self.name = self.path.name self._basic_view = Tree( ":open_file_folder: {}".format(self.name), guide_style="blue" ) @@ -118,16 +111,15 @@ def __init__(self, path): # , exclude=(), max_depth=4): # Search sub-folders subfolds = [ - f.path - for f in os.scandir(path) - if f.is_dir() and not f.name.startswith(".") + p for p in path.iterdir() if p.is_dir() and not p.name.startswith(".") ] + subfolds.sort() self.subfolds = {} for s in subfolds: - sub = os.path.basename(s) + sub = s.name self.subfolds[sub] = Folder(s) self._basic_view.add(":open_file_folder: [blue]" + sub) @@ -244,14 +236,14 @@ def save(self, name, obj, description=""): description : str, optional Metainformation added as a json sidecar. """ - filepath = os.path.join(self.path, name) + filepath = self.path / (name + ".npz") obj.save(filepath) - self.npz_files[name] = NPZFile(filepath + ".npz") + self.npz_files[name] = NPZFile(filepath) self.data[name] = obj metadata = {"time": str(datetime.now()), "info": str(description)} - with open(os.path.join(self.path, name + ".json"), "w") as ff: + with open(self.path / (name + ".json"), "w") as ff: json.dump(metadata, ff, indent=2) # regenerate the tree view @@ -295,19 +287,18 @@ def metadata(self, name): Name of the npz file """ # Search for json first - json_filename = os.path.join(self.path, name + ".json") - if os.path.isfile(json_filename): + json_filename = self.path / (name + ".json") + title = self.path / (name + ".npz") + if json_filename.exists(): with open(json_filename, "r") as ff: metadata = json.load(ff) text = "\n".join([" : ".join(it) for it in metadata.items()]) - panel = Panel.fit( - text, border_style="green", title=os.path.join(self.path, name + ".npz") - ) + panel = Panel.fit(text, border_style="green", title=str(title)) else: panel = Panel.fit( "No metadata", border_style="red", - title=os.path.join(self.path, name + ".npz"), + title=str(title), ) with Console() as console: console.print(panel) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 4795e0b2..cedb779b 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -4,20 +4,44 @@ # @Author: Guillaume Viejo # @Date: 2023-07-05 16:03:25 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-02 14:32:25 +# @Last Modified time: 2024-08-02 11:16:07 -""" -File classes help to validate and load pynapple objects or NWB files. -Data are always lazy-loaded. 
-Both classes behaves like dictionary. -""" -import os +from pathlib import Path import numpy as np from .. import core as nap +# +EXPECTED_ENTRIES = { + "TsGroup": {"t", "start", "end", "index"}, + "TsdFrame": {"t", "d", "start", "end", "columns"}, + "TsdTensor": {"t", "d", "start", "end"}, + "Tsd": {"t", "d", "start", "end"}, + "Ts": {"t", "start", "end"}, + "IntervalSet": {"start", "end"}, +} + + +def _find_class_from_variables(file_variables, data_ndims=None): + if data_ndims is not None: + + assert EXPECTED_ENTRIES["Tsd"].issubset(file_variables) + + if data_ndims == 1: + return "Tsd" + elif data_ndims == 2: + return "TsdFrame" + else: + return "TsdTensor" + + for possible_type, expected_variables in EXPECTED_ENTRIES.items(): + if expected_variables.issubset(file_variables): + return possible_type + + return "npz" + class NPZFile(object): """Class that points to a NPZ file that can be loaded as a pynapple object. @@ -44,37 +68,25 @@ def __init__(self, path): path : str Valid path to a NPZ file """ + path = Path(path) self.path = path - self.name = os.path.basename(path) + self.name = path.name self.file = np.load(self.path, allow_pickle=True) - self.type = "" - - # First check if type is explicitely defined - possible = ["Ts", "Tsd", "TsdFrame", "TsdTensor", "TsGroup", "IntervalSet"] - if "type" in self.file.keys(): - if len(self.file["type"]) == 1: - if isinstance(self.file["type"][0], np.str_): - if self.file["type"] in possible: - self.type = self.file["type"][0] - - # Second check manually - if self.type == "": - k = set(self.file.keys()) - if {"t", "start", "end", "index"}.issubset(k): - self.type = "TsGroup" - elif {"t", "d", "start", "end", "columns"}.issubset(k): - self.type = "TsdFrame" - elif {"t", "d", "start", "end"}.issubset(k): - if self.file["d"].ndim == 1: - self.type = "Tsd" - else: - self.type = "TsdTensor" - elif {"t", "start", "end"}.issubset(k): - self.type = "Ts" - elif {"start", "end"}.issubset(k): - self.type = "IntervalSet" - else: - self.type = "npz" + type_ = "" + + # First check if type is explicitely defined in the file: + try: + type_ = self.file["type"][0] + assert type_ in EXPECTED_ENTRIES.keys() + + # if not, use heuristics: + except (KeyError, IndexError, AssertionError): + file_variables = set(self.file.keys()) + data_ndims = self.file["d"].ndim if "d" in file_variables else None + + type_ = _find_class_from_variables(file_variables, data_ndims) + + self.type = type_ def load(self): """Load the NPZ file @@ -86,73 +98,5 @@ def load(self): """ if self.type == "npz": return self.file - else: - time_support = nap.IntervalSet(self.file["start"], self.file["end"]) - if self.type == "TsGroup": - - times = self.file["t"] - index = self.file["index"] - has_data = False - if "d" in self.file.keys(): - data = self.file["data"] - has_data = True - - if "keys" in self.file.keys(): - keys = self.file["keys"] - else: - keys = np.unique(index) - - group = {} - for k in keys: - if has_data: - group[k] = nap.Tsd( - t=times[index == k], - d=data[index == k], - time_support=time_support, - ) - else: - group[k] = nap.Ts( - t=times[index == k], time_support=time_support - ) - - tsgroup = nap.TsGroup( - group, time_support=time_support, bypass_check=True - ) - - metainfo = {} - for k in set(self.file.keys()) - { - "start", - "end", - "t", - "index", - "d", - "rate", - "keys", - }: - tmp = self.file[k] - if len(tmp) == len(tsgroup): - metainfo[k] = tmp - tsgroup.set_info(**metainfo) - return tsgroup - - elif self.type == "TsdFrame": - return nap.TsdFrame( - 
t=self.file["t"], - d=self.file["d"], - time_support=time_support, - columns=self.file["columns"], - ) - elif self.type == "TsdTensor": - return nap.TsdTensor( - t=self.file["t"], d=self.file["d"], time_support=time_support - ) - elif self.type == "Tsd": - return nap.Tsd( - t=self.file["t"], d=self.file["d"], time_support=time_support - ) - elif self.type == "Ts": - return nap.Ts(t=self.file["t"], time_support=time_support) - elif self.type == "IntervalSet": - return time_support - else: - return self.file + + return getattr(nap, self.type)._from_npz_reader(self.file) diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index 70f33956..e94cfcb4 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -15,6 +15,7 @@ import warnings from collections import UserDict from numbers import Number +from pathlib import Path import numpy as np import pynwb @@ -386,22 +387,21 @@ def __init__(self, file, lazy_loading=True): RuntimeError If file is not an instance of NWBFile """ - if isinstance(file, str): - if os.path.exists(file): - self.path = file - self.name = os.path.basename(file).split(".")[0] - self.io = NWBHDF5IO(file, "r") - self.nwb = self.io.read() - else: - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file) - elif isinstance(file, pynwb.file.NWBFile): + # TODO: do we really need to have instantiation from file and object in the same place? + + if isinstance(file, pynwb.file.NWBFile): self.nwb = file self.name = self.nwb.session_id - else: - raise RuntimeError( - "unrecognized argument. Please provide path to a valid NWB file or open NWB file." - ) + path = Path(file) + + if path.exists(): + self.path = path + self.name = path.stem + self.io = NWBHDF5IO(path, "r") + self.nwb = self.io.read() + else: + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file) self.data = _extract_compatible_data_from_nwbfile(self.nwb) self.key_to_id = {k: self.data[k]["id"] for k in self.data.keys()} @@ -480,3 +480,7 @@ def __getitem__(self, key): return self.data[key] else: raise KeyError("Can't find key {} in group index.".format(key)) + + def close(self): + """Close the NWB file""" + self.io.close() diff --git a/pynapple/io/loader.py b/pynapple/io/loader.py index e08c1b95..8e03cc8a 100644 --- a/pynapple/io/loader.py +++ b/pynapple/io/loader.py @@ -11,6 +11,7 @@ """ import os import warnings +from pathlib import Path import pandas as pd from pynwb import NWBHDF5IO, TimeSeries @@ -56,23 +57,28 @@ class BaseLoader(object): """ def __init__(self, path=None): - self.path = path + self.path = Path(path) - file_found = False # Check if a pynapplenwb folder exist - if self.path is not None: - nwb_path = os.path.join(self.path, "pynapplenwb") - if os.path.exists(nwb_path): - files = os.listdir(nwb_path) - if len([f for f in files if f.endswith(".nwb")]): - file_found = True - self.load_data(path) - - # Starting the GUI - if not file_found: + nwb_path = self.path / "pynapplenwb" + files = list(nwb_path.glob("*.nwb")) + + if len(files) > 0: + self.load_data() + else: raise RuntimeError(get_error_text(path)) - def load_data(self, path): + @property + def nwbfilepath(self): + try: + nwbfilepath = next(self.path.glob("pynapplenwb/*nwb")) + except StopIteration: + raise FileNotFoundError( + "No NWB file found in {}".format(self.path / "pynapplenwb") + ) + return nwbfilepath + + def load_data(self): """ Load NWB data saved with pynapple in the pynapplenwb folder @@ -81,11 +87,6 @@ def load_data(self, path): path : str Path to the session folder 
""" - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) io = NWBHDF5IO(self.nwbfilepath, "r+") nwbfile = io.read() diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index a3bab79e..03128adc 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -4,7 +4,8 @@ Various io functions """ -import os +import warnings +from pathlib import Path from xml.dom import minidom import numpy as np @@ -22,7 +23,7 @@ from .suite2p import Suite2P -def load_file(path): +def load_file(path, lazy_loading=None): """Load file. Current format supported is (npz,nwb,) .npz -> If the file is compatible with a pynapple format, the function will return a pynapple object. @@ -34,6 +35,9 @@ def load_file(path): ---------- path : str Path to the file + lazy_loading : bool, optional default True + Lazy loading of the data. If not specified, the function will use the defaults + True. Works only with NWB files. Returns ------- @@ -45,15 +49,23 @@ def load_file(path): FileNotFoundError If file is missing """ - if os.path.isfile(path): - if path.endswith(".npz"): - return NPZFile(path).load() - elif path.endswith(".nwb"): - return NWBFile(path) - else: - raise RuntimeError("File format not supported") + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File {path} does not exist") + + if path.suffix == ".npz": + if lazy_loading: + warnings.warn("Lazy loading is not supported for NPZ files") + return NPZFile(path).load() + + elif path.suffix == ".nwb": + # preserves class init default: + kwargs_for_lazyloading = ( + {} if lazy_loading is None else {"lazy_loading": lazy_loading} + ) + return NWBFile(path, **kwargs_for_lazyloading) else: - raise FileNotFoundError("File {} does not exist".format(path)) + raise RuntimeError("File format not supported") def load_folder(path): @@ -76,13 +88,13 @@ def load_folder(path): RuntimeError If folder is missing """ - if os.path.isdir(path): - return Folder(path) - else: - raise RuntimeError("Folder {} does not exist".format(path)) + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Folder {path} does not exist") + return Folder(path) -def load_session(path=None, session_type=None): +def load_session(path, session_type=None): """ %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % WARNING : THIS FUNCTION IS DEPRECATED % @@ -112,9 +124,8 @@ def load_session(path=None, session_type=None): A class holding all the data from the session. 
""" - if path: - if not os.path.isdir(path): - raise RuntimeError("Path {} is not found.".format(path)) + path = Path(path) + assert path.exists(), f"Folder {path} does not exist" if isinstance(session_type, str): session_type = session_type.lower() @@ -184,13 +195,14 @@ def load_eeg( """ # Need to check if a xml file exists - path = os.path.dirname(filepath) - basename = os.path.basename(filepath).split(".")[0] - listdir = os.listdir(path) + filepath = Path(filepath) + path = filepath.parent + basename = filepath.name.split(".")[0] + listdir = list(path.glob("*")) if frequency is None or n_channels is None: if basename + ".xml" in listdir: - xmlpath = os.path.join(path, basename + ".xml") + xmlpath = path / (basename + ".xml") xmldoc = minidom.parse(xmlpath) else: raise RuntimeError( @@ -268,18 +280,12 @@ def append_NWB_LFP(path, lfp, channel=None): If no channel is specify when passing a Tsd """ - new_path = os.path.join(path, "pynapplenwb") + path = Path(path) + new_path = path / "pynapplenwb" nwb_path = "" - if os.path.exists(new_path): - nwbfilename = [f for f in os.listdir(new_path) if f.endswith(".nwb")] - if len(nwbfilename): - nwb_path = os.path.join(path, "pynapplenwb", nwbfilename[0]) - else: - nwbfilename = [f for f in os.listdir(path) if f.endswith(".nwb")] - if len(nwbfilename): - nwb_path = os.path.join(path, "pynapplenwb", nwbfilename[0]) - - if len(nwb_path) == 0: + try: + nwb_path = next(new_path.glob("*.nwb")) + except StopIteration: raise RuntimeError("Can't find nwb file in {}".format(path)) if isinstance(lfp, nap.TsdFrame): diff --git a/pynapple/io/neurosuite.py b/pynapple/io/neurosuite.py index 952853a5..ca05ded2 100755 --- a/pynapple/io/neurosuite.py +++ b/pynapple/io/neurosuite.py @@ -1,10 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-02-02 20:45:09 -# @Last Modified by: gviejo -# @Last Modified time: 2023-11-16 13:21:34 - """ > :warning: **DEPRECATED**: This will be removed in version 1.0.0. Check [nwbmatic](https://github.com/pynapple-org/nwbmatic) or [neuroconv](https://github.com/catalystneuro/neuroconv) instead. @@ -12,8 +5,9 @@ @author: Guillaume Viejo """ -import os + import sys +from pathlib import Path import numpy as np import pandas as pd @@ -37,14 +31,15 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name self.time_support = None super().__init__(path) - self.load_nwb_spikes(path) + self.load_nwb_spikes() - def load_nwb_spikes(self, path): + def load_nwb_spikes(self): """ Read the NWB spikes to extract the spike times. 
@@ -58,11 +53,6 @@ def load_nwb_spikes(self, path):
         TYPE
             Description
         """
-        self.nwb_path = os.path.join(path, "pynapplenwb")
-        if not os.path.exists(self.nwb_path):
-            raise RuntimeError("Path {} does not exist.".format(self.nwb_path))
-        self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0]
-        self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename)
 
         io = NWBHDF5IO(self.nwbfilepath, "r")
         nwbfile = io.read()
@@ -129,16 +119,16 @@ def load_lfp(
             The lfp in a time series format
         """
         if filename is not None:
-            filepath = os.path.join(self.path, filename)
+            filepath = self.path / filename
         else:
-            listdir = os.listdir(self.path)
-            eegfile = [f for f in listdir if f.endswith(extension)]
+            eegfile = list(self.path.glob(f"*{extension}"))
+
             if not len(eegfile):
                 raise RuntimeError(
                     "Path {} contains no {} files;".format(self.path, extension)
                 )
 
-            filepath = os.path.join(self.path, eegfile[0])
+            filepath = eegfile[0]
 
         self.load_neurosuite_xml(self.path)
 
@@ -196,13 +186,14 @@ def read_neuroscope_intervals(self, name=None, path2file=None):
             Contains two columns corresponding to the start and end of the intervals.
 
         """
-        if name:
-            isets = self.load_nwb_intervals(name)
-            if isinstance(isets, nap.IntervalSet):
-                return isets
+        # if name:
+        #     isets = self.load_nwb_intervals(name)
+        #     if isinstance(isets, nap.IntervalSet):
+        #         return isets
+
         if name is not None and path2file is None:
-            path2file = os.path.join(self.path, self.basename + "." + name + ".evt")
-        if path2file is not None:
+            path2file = self.path / (self.basename + "." + name + ".evt")
+        if path2file is not None:  # TODO maybe useless conditional?
             try:
                 # df = pd.read_csv(path2file, delimiter=' ', usecols = [0], header = None)
                 tmp = np.genfromtxt(path2file)[:, 0]
@@ -213,7 +204,7 @@
             if name is None:
                 name = path2file.split(".")[-2]
                 print("*** saving file in the nwb as", name)
-            self.save_nwb_intervals(isets, name)
+            # self.save_nwb_intervals(isets, name)
         else:
             raise ValueError("specify a valid path")
         return isets
@@ -244,7 +235,7 @@
             )
         ).T.flatten()
 
-        evt_file = os.path.join(self.path, self.basename + extension)
+        evt_file = self.path / (self.basename + extension)
 
         f = open(evt_file, "w")
         for t, n in zip(datatowrite, texttowrite):
@@ -281,7 +272,7 @@ def load_mean_waveforms(self, epoch=None, waveform_window=None, spike_count=1000
         waveform_window = nap.IntervalSet(start=-0.5, end=1, time_units="ms")
 
         spikes = self.spikes
-        if not os.path.exists(self.path):  # check if path exists
+        if not self.path.exists():  # check if path exists
            print("The path " + self.path + " doesn't exist; Exiting ...")
            sys.exit()
 
@@ -304,11 +295,12 @@
         epend = int(epoch.as_units("s")["end"].values[0] * fs)
 
         # Find dat file
-        files = os.listdir(self.path)
-        dat_files = np.sort([f for f in files if "dat" in f and f[0] != "."])
+        # files = os.listdir(self.path)
+        # dat_files = np.sort([f for f in files if "dat" in f and f[0] != "."])
 
         # Need n_samples collected in the entire recording from dat file to load
-        file = os.path.join(self.path, dat_files[0])
+        # file = self.path / dat_files[0]
+        file = next(self.path.glob("[!.]*.dat"))  # first non-hidden .dat file; glob has no regex anchors
        f = open(
            file, "rb"
        )  # open file to get number of samples collected in the entire recording
diff --git a/pynapple/io/phy.py b/pynapple/io/phy.py
index ae0b79a8..92568463 100644
--- a/pynapple/io/phy.py
+++ b/pynapple/io/phy.py
@@ 
-6,7 +6,7 @@ @author: Sara Mahallati, Guillaume Viejo """ -import os +from pathlib import Path import numpy as np from pynwb import NWBHDF5IO @@ -29,14 +29,16 @@ def __init__(self, path): path : str or Path object The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + + self.basename = path.name self.time_support = None super().__init__(path) - self.load_nwb_spikes(path) + self.load_nwb_spikes() - def load_nwb_spikes(self, path): + def load_nwb_spikes(self): """Read the NWB spikes to extract the spike times. Returns @@ -44,11 +46,6 @@ def load_nwb_spikes(self, path): TYPE Description """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() diff --git a/pynapple/io/suite2p.py b/pynapple/io/suite2p.py index 3a5fb045..c3242241 100644 --- a/pynapple/io/suite2p.py +++ b/pynapple/io/suite2p.py @@ -13,7 +13,7 @@ """ -import os +from pathlib import Path import numpy as np import pandas as pd @@ -60,7 +60,8 @@ def __init__(self, path): path : str The path of the session """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) @@ -75,13 +76,6 @@ def load_suite2p_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 2e1af412..b7d9576c 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -4,6 +4,13 @@ compute_eventcorrelogram, ) from .decoding import decode_1d, decode_2d +from .filtering import ( + apply_bandpass_filter, + apply_bandstop_filter, + apply_highpass_filter, + apply_lowpass_filter, + get_filter_frequency_response, +) from .perievent import ( compute_event_trigger_average, compute_perievent, @@ -15,6 +22,10 @@ shift_timestamps, shuffle_ts_intervals, ) +from .spectrum import ( + compute_mean_power_spectral_density, + compute_power_spectral_density, +) from .tuning_curves import ( compute_1d_mutual_info, compute_1d_tuning_curves, @@ -24,3 +35,4 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) +from .wavelets import compute_wavelet_transform, generate_morlet_filterbank diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 9d9ea5ff..b528a52d 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -227,7 +227,6 @@ def _perievent_trigger_average( def _perievent_continuous( time_array, data_array, time_target_array, starts, ends, windowsize ): - idx, slice_idx, N_target, w_starts = _jitcontinuous_perievent( time_array, time_target_array, starts, ends, windowsize ) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py new file mode 100644 index 00000000..dc13f967 --- /dev/null +++ b/pynapple/process/filtering.py @@ -0,0 +1,515 @@ +"""Filtering module.""" + +import inspect +from collections.abc import 
Iterable +from functools import wraps +from numbers import Number + +import numpy as np +import pandas as pd +from scipy.signal import butter, sosfiltfilt, sosfreqz + +from .. import core as nap + + +def _validate_filtering_inputs(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Validate each positional argument + sig = inspect.signature(func) + kwargs = sig.bind_partial(*args, **kwargs).arguments + + cutoff = kwargs["cutoff"] + filter_type = kwargs["filter_type"] + if filter_type in ["lowpass", "highpass"] and not isinstance(cutoff, Number): + raise ValueError( + f"{filter_type} filter require a single number. {cutoff} provided instead." + ) + if filter_type in ["bandpass", "bandstop"]: + if ( + not isinstance(cutoff, Iterable) + or len(cutoff) != 2 + or not all(isinstance(fq, Number) for fq in cutoff) + ): + raise ValueError( + f"{filter_type} filter require a tuple of two numbers. {cutoff} provided instead." + ) + + if "fs" in kwargs: + if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number): + raise ValueError( + "Invalid value for 'fs'. Parameter 'fs' should be of type float or int" + ) + + if "order" in kwargs: + if not isinstance(kwargs["order"], int): + raise ValueError( + "Invalid value for 'order': Parameter 'order' should be of type int" + ) + + if "transition_bandwidth" in kwargs: + if not isinstance(kwargs["transition_bandwidth"], float): + raise ValueError( + "Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float" + ) + + # Call the original function with validated inputs + return func(**kwargs) + + return wrapper + + +def _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order=4): + """Calls scipy butter""" + return butter(order, cutoff, btype=filter_type, fs=sampling_frequency, output="sos") + + +def _compute_butterworth_filter( + data, cutoff, sampling_frequency=None, filter_type="bandpass", order=4 +): + """ + Apply a Butterworth filter to the provided signal. + """ + sos = _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order) + + if nap.utils.get_backend() == "jax": + from pynajax.jax_process_filtering import jax_sosfiltfilt + + out = jax_sosfiltfilt( + sos, + data.index.values, + data.values, + data.time_support.start, + data.time_support.end, + ) + + else: + out = np.zeros_like(data.d) + for ep in data.time_support: + slc = data.get_slice(start=ep.start[0], end=ep.end[0]) + out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) + + kwargs = dict(t=data.t, d=out, time_support=data.time_support) + if isinstance(data, nap.TsdFrame): + kwargs["columns"] = data.columns + return data.__class__(**kwargs) + + +def _compute_spectral_inversion(kernel): + """ + Compute the spectral inversion. + Parameters + ---------- + kernel: ndarray + + Returns + ------- + ndarray + """ + kernel *= -1.0 + kernel[len(kernel) // 2] = 1.0 + kernel[len(kernel) // 2] + return kernel + + +def _get_windowed_sinc_kernel( + fc, filter_type, sampling_frequency, transition_bandwidth=0.02 +): + """ + Get the windowed-sinc kernel. + Smith, S. (2003). Digital signal processing: a practical guide for engineers and scientists. + Chapter 16, equation 16-4 + + Parameters + ---------- + fc: float or tuple of float + Cutting frequency in Hz. Single float for 'lowpass' and 'highpass'. Tuple of float for + 'bandpass' and 'bandstop'. + filter_type: str + Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'. + sampling_frequency: float + Sampling frequency in Hz. 
+ transition_bandwidth: float + Percentage between 0 and 0.5 + Returns + ------- + np.ndarray + """ + M = int(np.rint(4.0 / transition_bandwidth)) + x = np.arange(-(M // 2), 1 + (M // 2)) + fc = np.transpose(np.atleast_2d(fc / sampling_frequency)) + kernel = np.sinc(2 * fc * x) + kernel = kernel * np.blackman(len(x)) + kernel = np.transpose(kernel) + kernel = kernel / kernel.sum(0) + + if filter_type == "lowpass": + return kernel.flatten() + elif filter_type == "highpass": + return _compute_spectral_inversion(kernel.flatten()) + elif filter_type == "bandstop": + kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) + kernel = np.sum(kernel, axis=1) + return kernel + elif filter_type == "bandpass": + kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) + kernel = _compute_spectral_inversion(np.sum(kernel, axis=1)) + return kernel + else: + raise ValueError + + +def _compute_windowed_sinc_filter( + data, freq, filter_type, sampling_frequency, transition_bandwidth=0.02 +): + """ + Apply a windowed-sinc filter to the provided signal. + + Parameters + ---------- + data: Tsd, TsdFrame or TsdTensor + + freq: float or tuple of float + Cutting frequency in Hz. Single float for 'lowpass' and 'highpass'. Tuple of float for + 'bandpass' and 'bandstop'. + sampling_frequency: float + Sampling frequency in Hz. + filter_type: str + Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'. + transition_bandwidth: float + Percentage between 0 and 0.5 + Returns + ------- + Tsd, TsdFrame or TsdTensor + """ + kernel = _get_windowed_sinc_kernel( + freq, filter_type, sampling_frequency, transition_bandwidth + ) + return data.convolve(kernel) + + +@_validate_filtering_inputs +def _compute_filter( + data, + cutoff, + fs=None, + mode="butter", + order=4, + transition_bandwidth=0.02, + filter_type="bandpass", +): + """ + Filter the signal. + """ + if not isinstance(data, nap.time_series.BaseTsd): + raise ValueError( + f"Invalid value: {data}. First argument should be of type Tsd, TsdFrame or TsdTensor" + ) + + if np.any(np.isnan(data)): + raise ValueError( + "The input signal contains NaN values, which are not supported for filtering. " + "Please remove or handle NaNs before applying the filter. " + "You can use the `dropna()` method to drop all NaN values." + ) + + if fs is None: + fs = data.rate + + cutoff = np.array(cutoff, dtype=float) + + if mode == "butter": + return _compute_butterworth_filter( + data, cutoff, fs, filter_type=filter_type, order=order + ) + if mode == "sinc": + return _compute_windowed_sinc_filter( + data, cutoff, filter_type, fs, transition_bandwidth=transition_bandwidth + ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + + +def apply_bandpass_filter( + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a band-pass filter to the provided signal. + Mode can be : + + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + cutoff : (Numeric, Numeric) + Cutoff frequencies in Hz. + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. 
Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `cutoff` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="bandpass", + ) + + +def apply_bandstop_filter( + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a band-stop filter to the provided signal. + Mode can be : + + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + cutoff : (Numeric, Numeric) + Cutoff frequencies in Hz. + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `cutoff` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="bandstop", + ) + + +def apply_highpass_filter( + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a high-pass filter to the provided signal. + Mode can be : + + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + cutoff : Numeric + Cutoff frequency in Hz. 
+ fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `cutoff` is not a number. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="highpass", + ) + + +def apply_lowpass_filter( + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a low-pass filter to the provided signal. + Mode can be : + + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + cutoff : Numeric + Cutoff frequency in Hz. + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `cutoff` is not a number. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="lowpass", + ) + + +@_validate_filtering_inputs +def get_filter_frequency_response( + cutoff, fs, filter_type, mode, order=4, transition_bandwidth=0.02 +): + """ + Utility function to evaluate the frequency response of a particular type of filter. The arguments are the same + as the function `apply_lowpass_filter`, `apply_highpass_filter`, `apply_bandpass_filter` and + `apply_bandstop_filter`. 
+
+    This function returns a pandas Series object with the index as frequencies.
+
+    Parameters
+    ----------
+    cutoff : Numeric or tuple of Numeric
+        Cutoff frequency in Hz.
+    fs : float
+        The sampling frequency of the signal in Hz.
+    filter_type: str
+        Can be "lowpass", "highpass", "bandpass" or "bandstop".
+    mode: str
+        Can be "butter" or "sinc".
+    order : int, optional
+        The order of the Butterworth filter. Higher values result in sharper frequency cutoffs.
+        Default is 4.
+    transition_bandwidth : float, optional
+        The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency.
+        The smaller the transition bandwidth, the larger the windowed-sinc kernel.
+        Default is 0.02.
+
+    Returns
+    -------
+    pandas.Series
+    """
+    cutoff = np.array(cutoff)
+
+    if mode == "butter":
+        sos = _get_butter_coefficients(cutoff, filter_type, fs, order)
+        w, h = sosfreqz(sos, worN=1024, fs=fs)
+        return pd.Series(index=w, data=np.abs(h))
+    if mode == "sinc":
+        kernel = _get_windowed_sinc_kernel(
+            cutoff, filter_type, fs, transition_bandwidth
+        )
+        fft_result = np.fft.fft(kernel)
+        fft_result = np.fft.fftshift(fft_result)
+        fft_freq = np.fft.fftfreq(n=len(kernel), d=1 / fs)
+        fft_freq = np.fft.fftshift(fft_freq)
+        return pd.Series(
+            index=fft_freq[fft_freq >= 0], data=np.abs(fft_result[fft_freq >= 0])
+        )
+    else:
+        raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'")
diff --git a/pynapple/process/spectrum.py b/pynapple/process/spectrum.py
new file mode 100644
index 00000000..30668c06
--- /dev/null
+++ b/pynapple/process/spectrum.py
@@ -0,0 +1,207 @@
+"""
+# Power spectral density
+
+This module contains functions to compute power spectral density and mean power spectral density.
+
+"""
+
+from numbers import Number
+
+import numpy as np
+import pandas as pd
+from scipy import signal
+
+from .. import core as nap
+
+
+def compute_power_spectral_density(
+    sig, fs=None, ep=None, full_range=False, norm=False, n=None
+):
+    """
+    Performs numpy fft on sig and returns the output, assuming a constant sampling rate for the signal.
+
+    Parameters
+    ----------
+    sig : pynapple.Tsd or pynapple.TsdFrame
+        Time series.
+    fs : float, optional
+        Sampling rate, in Hz. If None, will be calculated from the given signal
+    ep : None or pynapple.IntervalSet, optional
+        The epoch to calculate the fft on. Must be length 1.
+    full_range : bool, optional
+        If true, will return full fft frequency range, otherwise will return only positive values
+    norm: bool, optional
+        Whether the FFT result is divided by the length of the signal to normalize the amplitude
+    n: int, optional
+        Length of the transformed axis of the output. If n is smaller than the length of the input,
+        the input is cropped. If it is larger, the input is padded with zeros. If n is not given,
+        the length of the input along the axis specified by axis is used.
+
+    Returns
+    -------
+    pandas.DataFrame
+        Time frequency representation of the input signal, indexes are frequencies, values
+        are powers.
+
+    Notes
+    -----
+    compute_power_spectral_density computes the fft on only a single epoch of data. This epoch can be
+    given with the ep parameter, otherwise it defaults to sig.time_support, which must contain only a
+    single epoch.
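+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import pynapple as nap
+    >>> # Usage sketch mirroring the example of compute_mean_power_spectral_density
+    >>> # below: a pure 50 Hz sine sampled at 1 kHz should show a single peak at 50 Hz.
+    >>> t = np.arange(0, 1, 1/1000)
+    >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t)
+    >>> psd = nap.compute_power_spectral_density(signal, fs=1000)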
+ """ + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + if len(ep) != 1: + raise ValueError("Given epoch (or signal time_support) must have length 1") + if fs is None: + fs = sig.rate + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + + fft_result = np.fft.fft(sig.restrict(ep).values, n=n, axis=0) + if n is None: + n = len(sig.restrict(ep)) + fft_freq = np.fft.fftfreq(n, 1 / fs) + + if norm: + fft_result = fft_result / fft_result.shape[0] + + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + + if not full_range: + return ret.loc[ret.index >= 0] + return ret + + +def compute_mean_power_spectral_density( + sig, + interval_size, + fs=None, + ep=None, + full_range=False, + norm=False, + time_unit="s", +): + """ + Compute mean power spectral density by averaging FFT over epochs of same size. + + The parameter `interval_size` controls the duration of the epochs. + + To imporve frequency resolution, the signal is multiplied by a Hamming window. + + Note that this function assumes a constant sampling rate for `sig`. + + Parameters + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Signal with equispaced samples + interval_size : Number + Epochs size to compute to average the FFT across + fs : None, optional + Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` + ep : None or pynapple.IntervalSet, optional + The `IntervalSet` to calculate the fft on. Can be any length. + full_range : bool, optional + If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude + time_unit : str, optional + Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') + + Returns + ------- + pandas.DataFrame + Power spectral density. + + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.arange(0, 1, 1/1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> mpsd = nap.compute_mean_power_spectral_density(signal, 0.1) + + Raises + ------ + RuntimeError + If splitting the epoch with `interval_size` results in an empty set. + TypeError + If `ep` or `sig` are not respectively pynapple time series or interval set. 
+ """ + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") + if fs is None: + fs = sig.rate + + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") + + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + + # Split the ep + interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ + 0 + ] + split_ep = ep.split(interval_size) + + if len(split_ep) == 0: + raise RuntimeError( + f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size" + ) + + # Get the slices of each ep + slices = np.zeros((len(split_ep), 2), dtype=int) + + for i in range(len(split_ep)): + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) + slices[i, 0] = sl.start + slices[i, 1] = sl.stop + + # Check what is the signal length + N = np.min(np.diff(slices, 1)) + + if N == 0: + raise RuntimeError( + "One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed." + ) + + # Get the freqs + fft_freq = np.fft.fftfreq(N, 1 / fs) + + # Get the Hamming window + window = signal.windows.hamming(N) + if sig.ndim == 2: + window = window[:, np.newaxis] + + # Compute the fft + fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) + + for i in range(len(slices)): + tmp = sig[slices[i, 0] : slices[i, 1]].values[0:N] * window + fft_result += np.fft.fft(tmp, axis=0) + + if norm: + fft_result = fft_result / (float(N) * float(len(slices))) + + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + if not full_range: + return ret.loc[ret.index >= 0] + return ret diff --git a/pynapple/process/wavelets.py b/pynapple/process/wavelets.py new file mode 100644 index 00000000..e8aa601c --- /dev/null +++ b/pynapple/process/wavelets.py @@ -0,0 +1,236 @@ +""" +# Wavelets decomposition + +The main function for doing wavelet decomposition is `nap.compute_wavelet_transform` + +For now, pynapple only implements Morlet wavelets. To check the shape and quality of the wavelets, check out +the function `nap.generate_morlet_filterbank` to plot the wavelets. + +""" + +import numpy as np + +from .. import core as nap + + +def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): + """ + Defines the complex Morlet wavelet kernel. + + Parameters + ---------- + M : int + Length of the wavelet + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. + precision: int. + Precision of wavelet to use. Default is 8 + + Returns + ------- + np.ndarray + Morelet wavelet kernel + """ + x = np.linspace(-precision, precision, M) + return ( + ((np.pi * gaussian_width) ** (-0.25)) + * np.exp(-(x**2) / gaussian_width) + * np.exp(1j * 2 * np.pi * window_length * x) + ) + + +def compute_wavelet_transform( + sig, freqs, fs=None, gaussian_width=1.5, window_length=1.0, precision=16, norm="l1" +): + """ + Compute the time-frequency representation of a signal using Morlet wavelets. + + Parameters + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame or pynapple.TsdTensor + Time series. 
+ freqs : 1d array + Frequency values to estimate with Morlet wavelets. + fs : float or None + Sampling rate, in Hz. Defaults to `sig.rate` if None is given. + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. Default is 1.5. + window_length : float + The length of window to be used for wavelet creation. Default is 1.0. + precision: int. + Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. + Default is 16. + norm : {None, 'l1', 'l2'}, optional + Normalization method: + - None - no normalization + - 'l1' - (default) divide by the sum of amplitudes + - 'l2' - divide by the square root of the sum of amplitudes + + Returns + ------- + pynapple.TsdFrame or pynapple.TsdTensor + Time frequency representation of the input signal. + + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.arange(0, 1, 1/1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> freqs = np.linspace(10, 100, 10) + >>> mwt = nap.compute_wavelet_transform(signal, fs=1000, freqs=freqs) + + Notes + ----- + This computes the continuous wavelet transform at specified frequencies across time. + """ + + if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): + raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") + + if not isinstance(freqs, np.ndarray): + raise TypeError("`freqs` must be a ndarray") + if len(freqs) == 0: + raise ValueError("Given list of freqs cannot be empty.") + if np.min(freqs) <= 0: + raise ValueError("All frequencies in freqs must be strictly positive") + + if fs is not None and not isinstance(fs, (int, float, np.number)): + raise TypeError("`fs` must be of type float or int or None") + + if norm is not None and norm not in ["l1", "l2"]: + raise ValueError("norm parameter must be 'l1', 'l2', or None.") + + if fs is None: + fs = sig.rate + + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = np.reshape(sig, (sig.shape[0], -1)) + + filter_bank = generate_morlet_filterbank( + freqs, fs, gaussian_width, window_length, precision + ) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + + if norm == "l1": + coef = convolved / (fs / freqs) + elif norm == "l2": + coef = convolved / (fs / np.sqrt(freqs)) + else: + coef = convolved + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + + if len(output_shape) == 2: + return nap.TsdFrame( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + + return nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + + +def generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 +): + """ + Generates a Morlet filterbank using the given frequencies and parameters. + + This function can be used purely for visualization, or to convolve with a pynapple Tsd, + TsdFrame, or TsdTensor as part of a wavelet decomposition process. + + Parameters + ---------- + freqs : 1d array + frequency values to estimate with Morlet wavelets. + fs : float or int + Sampling rate, in Hz. + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. + precision: int. + Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. 
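+        Default is 16.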
+
+    Returns
+    -------
+    filter_bank : pynapple.TsdFrame
+        List of Morlet wavelet filters at the frequencies given
+
+    Notes
+    -----
+    This algorithm first computes a single, finely sampled wavelet using the provided hyperparameters.
+    Wavelets of different frequencies are generated by resampling this mother wavelet with an appropriate step size.
+    The step size is determined based on the desired frequency and the sampling rate.
+    """
+    if not isinstance(freqs, np.ndarray):
+        raise TypeError("`freqs` must be a ndarray")
+    if len(freqs) == 0:
+        raise ValueError("Given list of freqs cannot be empty.")
+    if np.min(freqs) <= 0:
+        raise ValueError("All frequencies in freqs must be strictly positive")
+
+    if not isinstance(fs, (int, float, np.number)):
+        raise TypeError("`fs` must be of type float or int")
+
+    if isinstance(gaussian_width, (int, float, np.number)):
+        if gaussian_width <= 0:
+            raise ValueError("gaussian_width must be a positive number.")
+    else:
+        raise TypeError("gaussian_width must be a float or int instance.")
+
+    if isinstance(window_length, (int, float, np.number)):
+        if window_length <= 0:
+            raise ValueError("window_length must be a positive number.")
+    else:
+        raise TypeError("window_length must be a float or int instance.")
+
+    if isinstance(precision, int):
+        if precision <= 0:
+            raise ValueError("precision must be a positive number.")
+    else:
+        raise TypeError("precision must be an int instance.")
+
+    # Initialize filter bank and parameters
+    filter_bank = []
+    cutoff = 8  # Define cutoff for wavelet
+    # Compute a single, finely sampled Morlet wavelet
+    morlet_f = np.conj(
+        _morlet(
+            int(2**precision),
+            gaussian_width=gaussian_width,
+            window_length=window_length,
+        )
+    )
+    x = np.linspace(-cutoff, cutoff, int(2**precision))
+    max_len = -1  # Track maximum length of wavelet
+    for freq in freqs:
+        scale = window_length / (freq / fs)
+        # Calculate the indices for subsampling the wavelet to achieve the right frequency
+        # After the slicing the size will be reduced, therefore we will pad with 0s. 
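+        # With a step of 1 / (scale * dx), sample k of the subsampled wavelet sits at
+        # x[0] + k / scale on the mother wavelet's axis. Since the mother wavelet
+        # completes `window_length` cycles per unit of x, the subsampled wavelet
+        # oscillates at window_length / scale = freq / fs cycles per sample, i.e. at
+        # `freq` Hz for a signal sampled at `fs`.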
+ j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) + j = np.ceil(j).astype(int) # Ceil the values to get integer indices + if j[-1] >= morlet_f.size: + j = np.extract(j < morlet_f.size, j) + scaled_morlet = morlet_f[j][::-1] # Scale and reverse wavelet + if len(scaled_morlet) > max_len: + max_len = len(scaled_morlet) + time = np.linspace( + -cutoff * window_length / freq, cutoff * window_length / freq, max_len + ) + filter_bank.append(scaled_morlet) + # Pad wavelets to ensure all are of the same length + filter_bank = [ + np.pad( + arr, + ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), + constant_values=0.0, + ) + for arr in filter_bank + ] + # Return filter bank as a TsdFrame + return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) diff --git a/pyproject.toml b/pyproject.toml index 52d8682d..73afed77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pynapple" -version = "0.6.6" +version = "0.7.0" description = "PYthon Neural Analysis Package Pour Laboratoires d’Excellence" readme = "README.md" authors = [{ name = "Guillaume Viejo", email = "guillaume.viejo@gmail.com" }] @@ -36,8 +36,8 @@ requires-python = ">=3.8" include = ["pynapple", "pynapple.*"] [project.urls] -homepage = "https://github.com/pynapple-org/pynapple" -documentation = "https://pynapple-org.github.io/pynapple/" +homepage = "http://pynapple.org/" +documentation = "http://pynapple.org/" repository = "https://github.com/pynapple-org/pynapple" ########################################################################## diff --git a/setup.py b/setup.py index b4fbb0c3..d8fc430f 100644 --- a/setup.py +++ b/setup.py @@ -59,8 +59,8 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/pynapple-org/pynapple', - version='v0.6.6', + version='v0.7.0', zip_safe=False, long_description_content_type='text/markdown', - download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.6.6.tar.gz' + download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.7.0.tar.gz' ) diff --git a/tests/test_filtering.py b/tests/test_filtering.py new file mode 100644 index 00000000..19461760 --- /dev/null +++ b/tests/test_filtering.py @@ -0,0 +1,301 @@ +import pytest +import pynapple as nap +import numpy as np +from scipy import signal +import pandas as pd +import warnings +from contextlib import nullcontext as does_not_raise + + +# @pytest.fixture +def sample_data(): + # Create a sample Tsd data object + t = np.linspace(0, 1, 500) + d = np.sin(2 * np.pi * 10 * t) + np.random.normal(0, 0.5, t.shape) + time_support = nap.IntervalSet(start=[0], end=[1]) + return nap.Tsd(t=t, d=d, time_support=time_support) + + +def sample_data_with_nan(): + # Create a sample Tsd data object + t = np.linspace(0, 1, 500) + d = np.sin(2 * np.pi * 10 * t) + np.random.normal(0, 0.5, t.shape) + d[10] = np.nan + time_support = nap.IntervalSet(start=[0], end=[1]) + return nap.Tsd(t=t, d=d, time_support=time_support) + + +def compare_scipy(tsd, ep, order, freq, fs, btype): + sos = signal.butter(order, freq, btype=btype, fs=fs, output="sos") + out_sci = [] + for iset in ep: + out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0)) + out_sci = np.concatenate(out_sci, axis=0) + return out_sci + +def compare_sinc(tsd, ep, transition_bandwidth, freq, fs, ftype): + + kernel = nap.process.filtering._get_windowed_sinc_kernel(freq, ftype, fs, transition_bandwidth) + return tsd.convolve(kernel, ep).d + + 
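+# The parametrized tests below build a Tsd, TsdFrame, or TsdTensor (depending on
+# `shape`) over single- and multi-epoch supports, apply the pynapple filter, and
+# compare the output against the scipy.signal (butter) or windowed-sinc (sinc)
+# reference computed by the helpers above.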
+@pytest.mark.parametrize("freq", [10]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_low_pass(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + + out = nap.apply_lowpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "lowpass") + np.testing.assert_array_almost_equal(out.d, out_sci) + + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "lowpass") + np.testing.assert_array_almost_equal(out.d, out_sinc) + + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + +@pytest.mark.parametrize("freq", [10]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_high_pass(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + + out = nap.apply_highpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "highpass") + np.testing.assert_array_almost_equal(out.d, out_sinc) + + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "highpass") + np.testing.assert_array_almost_equal(out.d, out_sci) + + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + +@pytest.mark.parametrize("freq", [[10, 30]]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) 
+@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_bandpass(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + + out = nap.apply_bandpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandpass") + np.testing.assert_array_almost_equal(out.d, out_sinc) + + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "bandpass") + np.testing.assert_array_almost_equal(out.d, out_sci) + + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + +@pytest.mark.parametrize("freq", [[10, 30]]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [2, 4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_bandstop(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + + out = nap.apply_bandstop_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandstop") + np.testing.assert_array_almost_equal(out.d, out_sinc) + + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "bandstop") + np.testing.assert_array_almost_equal(out.d, out_sci) + + + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + +######################################################################## +# Errors +######################################################################## +@pytest.mark.parametrize("func, freq", [ + (nap.apply_lowpass_filter, 10), + (nap.apply_highpass_filter, 10), + (nap.apply_bandpass_filter, [10, 20]), + (nap.apply_bandstop_filter, [10, 20]), +]) +@pytest.mark.parametrize("data, fs, mode, order, transition_bandwidth, expected_exception", [ + (sample_data(), None, "butter", "a", 0.02, pytest.raises(ValueError,match="Invalid 
value for 'order': Parameter 'order' should be of type int")),
+    ("invalid_data", None, "butter", 4, 0.02, pytest.raises(ValueError,match="Invalid value: invalid_data. First argument should be of type Tsd, TsdFrame or TsdTensor")),
+    (sample_data(), None, "invalid_mode", 4, 0.02, pytest.raises(ValueError,match="Unrecognized filter mode. Choose either 'butter' or 'sinc'")),
+    (sample_data(), "invalid_fs", "butter", 4, 0.02, pytest.raises(ValueError,match="Invalid value for 'fs'. Parameter 'fs' should be of type float or int")),
+    (sample_data(), None, "sinc", 4, "a", pytest.raises(ValueError,match="Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float")),
+    (sample_data_with_nan(), None, "sinc", 4, 0.02, pytest.raises(ValueError,match="The input signal contains NaN values, which are not supported for filtering")),
+    (sample_data_with_nan(), None, "butter", 4, 0.02, pytest.raises(ValueError, match="The input signal contains NaN values, which are not supported for filtering"))
+])
+def test_compute_filtered_signal_raise_errors(func, freq, data, fs, mode, order, transition_bandwidth, expected_exception):
+    with expected_exception:
+        func(data, freq, fs=fs, mode=mode, order=order, transition_bandwidth=transition_bandwidth)
+
+@pytest.mark.parametrize("func, freq, expected_exception", [
+    (nap.apply_lowpass_filter, "a", pytest.raises(ValueError,match=r"lowpass filter require a single number. a provided instead.")),
+    (nap.apply_highpass_filter, "b", pytest.raises(ValueError,match=r"highpass filter require a single number. b provided instead.")),
+    (nap.apply_bandpass_filter, [10, "b"], pytest.raises(ValueError,match=r"bandpass filter require a tuple of two numbers. \[10, 'b'\] provided instead.")),
+    (nap.apply_bandstop_filter, [10, 20, 30], pytest.raises(ValueError,match=r"bandstop filter require a tuple of two numbers. \[10, 20, 30\] provided instead."))
+])
+def test_compute_filtered_signal_bad_freq(func, freq, expected_exception):
+    with expected_exception:
+        func(sample_data(), freq)
+
+
+#################################################################
+# Test with edge-case frequencies close to the Nyquist frequency
+@pytest.mark.parametrize("nyquist_fraction", [0.99, 0.999])
+@pytest.mark.parametrize("order", [2, 4])
+def test_filtering_nyquist_edge_case(nyquist_fraction, order):
+    data = sample_data()
+    nyquist_freq = 0.5 * data.rate
+    freq = nyquist_freq * nyquist_fraction
+
+    out = nap.filtering.apply_lowpass_filter(data, freq, order=order)
+    assert isinstance(out, type(data))
+    np.testing.assert_allclose(out.t, data.t)
+    np.testing.assert_allclose(out.time_support, data.time_support)
+
+#################################################################
+# Test windowed-sinc kernel
+
+@pytest.mark.parametrize("tb", [0.2, 0.3])
+def test_get_odd_kernel(tb):
+    kernel = nap.process.filtering._get_windowed_sinc_kernel(1, "lowpass", 4, transition_bandwidth=tb)
+    assert len(kernel) % 2 != 0
+
+@pytest.mark.parametrize("filter_type, expected_exception", [
+    ("a", pytest.raises(ValueError)),
+])
+def test_get_kernel_error(filter_type, expected_exception):
+    with expected_exception:
+        nap.process.filtering._get_windowed_sinc_kernel(1, filter_type, 4)
+
+def test_get_error():
+    with pytest.raises(TypeError, match=r"apply_lowpass_filter\(\) missing 1 required positional argument: 'data'"):
+        nap.apply_lowpass_filter(cutoff=0.25)
+
+
+def test_compare_sinc_kernel():
+    kernel = nap.process.filtering._get_windowed_sinc_kernel(1, "lowpass", 4)
+    x = np.arange(-(len(kernel)//2), 1+len(kernel)//2)
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # sinc lobe sin(2*pi*f*x)/x with f = 0.25; the x = 0 division warning
+        # is silenced and the limit value (2*pi*f) is patched in just below.
+        kernel2 = np.sin(2*np.pi*x*0.25)/x
+    kernel2[len(kernel)//2] = 0.25*2*np.pi
+    kernel2 = kernel2 * np.blackman(len(kernel2))
+    kernel2 /= kernel2.sum()
+    np.testing.assert_allclose(kernel, kernel2)
+
+    ikernel = nap.process.filtering._compute_spectral_inversion(kernel)
+    ikernel2 = kernel2 * -1.0
+    ikernel2[len(ikernel2) // 2] = 1.0 + ikernel2[len(kernel2) // 2]
+    np.testing.assert_allclose(ikernel, ikernel2)
+
+@pytest.mark.parametrize("cutoff, fs, filter_type, mode, order, tb", [
+    (250, 1000, "lowpass", "butter", 4, 0.02),
+    (250, 1000, "lowpass", "sinc", 4, 0.02),
+])
+def test_get_filter_frequency_response(cutoff, fs, filter_type, mode, order, tb):
+    output = nap.get_filter_frequency_response(cutoff, fs, filter_type, mode, order, tb)
+    assert isinstance(output, pd.Series)
+    if mode == "butter":
+        sos = nap.process.filtering._get_butter_coefficients(cutoff, filter_type, fs, order)
+        w, h = signal.sosfreqz(sos, worN=1024, fs=fs)
+        np.testing.assert_array_almost_equal(w, output.index.values)
+        np.testing.assert_array_almost_equal(np.abs(h), output.values)
+    if mode == "sinc":
+        kernel = nap.process.filtering._get_windowed_sinc_kernel(cutoff, filter_type, fs, tb)
+        fft_result = np.fft.fft(kernel)
+        fft_result = np.fft.fftshift(fft_result)
+        fft_freq = np.fft.fftfreq(n=len(kernel), d=1 / fs)
+        fft_freq = np.fft.fftshift(fft_freq)
+        np.testing.assert_array_almost_equal(fft_freq[fft_freq >= 0], output.index.values)
+        np.testing.assert_array_almost_equal(np.abs(fft_result[fft_freq >= 0]), output.values)
+
+def test_get_filter_frequency_response_error():
+    with pytest.raises(ValueError, match="Unrecognized filter mode. 
Choose either 'butter' or 'sinc'"): + nap.get_filter_frequency_response(250, 1000, "lowpass", "a", 4, 0.02) diff --git a/tests/test_folder.py b/tests/test_folder.py index 21c2da72..4037eb71 100644 --- a/tests/test_folder.py +++ b/tests/test_folder.py @@ -11,25 +11,19 @@ import pandas as pd import pytest import warnings -import os +from pathlib import Path +import shutil import json # look for tests folder -path = os.getcwd() -if os.path.basename(path) == 'pynapple': - path = os.path.join(path, "tests") - -path = os.path.join(path, "npzfilestest") -if not os.path.isdir(path): - os.mkdir(path) -# path2 = os.path.join(path, "sub") -# if not os.path.isdir(path): -# os.mkdir(path2) - -# Cleaning -for root, dirs, files in os.walk(path): - for f in files: - os.remove(os.path.join(root, f)) +path = Path(__file__).parent +if path.name == 'pynapple': + path = path / "tests" +path = path / "npzfilestest" + +# Recursively remove the folder: +shutil.rmtree(path, ignore_errors=True) +path.mkdir(exist_ok=True, parents=True) # Populate the folder data = { @@ -45,9 +39,7 @@ } for k, d in data.items(): - d.save(os.path.join(path, k+".npz")) -# for k, d in data.items(): -# d.save(os.path.join(path, "sub", k+".npz")) + d.save(path / (k+".npz")) @pytest.mark.parametrize("path", [path]) def test_load_folder(path): @@ -78,11 +70,11 @@ def test_save(folder): assert isinstance(folder['tsd2'], nap.Tsd) - files = os.listdir(folder.path) + files = [f.name for f in path.iterdir()] assert "tsd2.json" in files # check json - metadata = json.load(open(os.path.join(path, "tsd2.json"), "r")) + metadata = json.load(open(path / "tsd2.json", "r")) assert "time" in metadata.keys() assert "info" in metadata.keys() assert "Test description" == metadata["info"] @@ -102,18 +94,3 @@ def test_load(path): folder.load() for k in data.keys(): assert type(folder[k]) == type(data[k]) - - - - - - - - - - - - - - - diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index bd4b5631..ec7105d0 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -6,6 +6,7 @@ import pytest import warnings from .mock import MockArray +from pathlib import Path def test_create_iset(): @@ -78,6 +79,23 @@ def test_create_iset_from_mock_array(): np.testing.assert_array_almost_equal(ep.start, start) np.testing.assert_array_almost_equal(ep.end, end) +def test_create_iset_from_tuple(): + start = 0 + end = 5 + ep = nap.IntervalSet((start, end)) + assert isinstance(ep, nap.core.interval_set.IntervalSet) + np.testing.assert_array_almost_equal(start, ep.start[0]) + np.testing.assert_array_almost_equal(end, ep.end[0]) + +def test_create_iset_from_tuple_iter(): + start = [0, 10, 16, 25] + end = [5, 15, 20, 40] + pairs = zip(start, end) + ep = nap.IntervalSet(pairs) + assert isinstance(ep, nap.core.interval_set.IntervalSet) + np.testing.assert_array_almost_equal(start, ep.start) + np.testing.assert_array_almost_equal(end, ep.end) + def test_create_iset_from_unknown_format(): with pytest.raises(RuntimeError) as e: nap.IntervalSet(start="abc", end=[1, 2]) @@ -259,12 +277,16 @@ def test_tot_length(): def test_as_units(): ep = nap.IntervalSet(start=0, end=100) - df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"]) - pd.testing.assert_frame_equal(df, ep.as_units("s")) - pd.testing.assert_frame_equal(df * 1e3, ep.as_units("ms")) + df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"], dtype=np.float64) + np.testing.assert_array_almost_equal(df.values, ep.as_units("s").values.astype(np.float64)) + 
np.testing.assert_array_almost_equal(df * 1e3, ep.as_units("ms").values.astype(np.float64)) tmp = df * 1e6 - np.testing.assert_array_almost_equal(tmp.values, ep.as_units("us").values) + np.testing.assert_array_almost_equal(tmp.values, ep.as_units("us").values.astype(np.float64)) +def test_as_dataframe(): + ep = nap.IntervalSet(start=0, end=100) + df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"], dtype=np.float64) + np.testing.assert_array_almost_equal(df.values, ep.as_dataframe().values) def test_intersect(): ep = nap.IntervalSet(start=[0, 30], end=[10, 70]) @@ -474,48 +496,86 @@ def test_str_(): assert isinstance(ep.__str__(), str) def test_save_npz(): - import os start = np.around(np.array([0, 10, 16], dtype=np.float64), 9) end = np.around(np.array([5, 15, 20], dtype=np.float64), 9) ep = nap.IntervalSet(start=start,end=end) - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: ep.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: ep.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: ep.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) ep.save("ep.npz") - os.listdir('.') - assert "ep.npz" in os.listdir(".") + assert "ep.npz" in [f.name for f in Path('.').iterdir()] ep.save("ep2") - os.listdir('.') - assert "ep2.npz" in os.listdir(".") + assert "ep2.npz" in [f.name for f in Path('.').iterdir()] - file = np.load("ep.npz") + with np.load("ep.npz") as file: - keys = list(file.keys()) - assert 'start' in keys - assert 'end' in keys + keys = list(file.keys()) + assert 'start' in keys + assert 'end' in keys - np.testing.assert_array_almost_equal(file['start'], start) - np.testing.assert_array_almost_equal(file['end'], end) + np.testing.assert_array_almost_equal(file['start'], start) + np.testing.assert_array_almost_equal(file['end'], end) # Cleaning - os.remove("ep.npz") - os.remove("ep2.npz") - + Path("ep.npz").unlink() + Path("ep2.npz").unlink() + +def test_split(): + np.random.seed(0) + start = np.round(np.random.uniform(0, 10)) + end = np.round(np.random.uniform(90, 100)) + tmp = np.linspace(start, end, 100) + interval_size = np.round(tmp[1] - tmp[0], 9) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ep0 = nap.IntervalSet(tmp[0:-1], tmp[1:]) + ep = nap.IntervalSet(tmp[0], tmp[-1]) + ep1 = ep.split(interval_size) + np.testing.assert_array_almost_equal(ep0, ep1) + + # Test with a smaller epochs + start = np.hstack((tmp[0:-1], np.array([200]))) + end = np.hstack((tmp[1:], np.array([200+0.9*interval_size]))) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ep2 = nap.IntervalSet(start, end) + + ep = nap.IntervalSet([start[0], 200], end[-2:]) + ep1 = ep.split(interval_size) + np.testing.assert_array_almost_equal(ep0, ep1) + + # Empty intervalset + ep = nap.IntervalSet([], []) + assert len(ep.split(1)) == 0 + +def test_split_errors(): + start = [0, 10, 16, 25] + end = [5, 15, 20, 40] + ep = nap.IntervalSet(start=start, end=end) + with pytest.raises(IOError, match="Argument interval_size should of type float or int"): + ep.split('a') + with pytest.raises(IOError) as e: + ep.split(0) + assert str(e.value) 
== "Argument interval_size should be strictly larger than 0" + + with pytest.raises(IOError) as e: + ep.split(1, time_units=1) + assert str(e.value) == "Argument time_units should be of type str" + diff --git a/tests/test_jitted.py b/tests/test_jitted.py index f97096d5..5b8a4aa1 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -52,9 +52,9 @@ def restrict(ep, tsd): ) ) ix3 = np.vstack((ix, ix2)).T - # ix[np.floor(ix / 2) * 2 != ix] = np.NaN + # ix[np.floor(ix / 2) * 2 != ix] = np.nan # ix = np.floor(ix/2) - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN + ix3[np.floor(ix3 / 2) * 2 != ix3] = np.nan ix3 = np.floor(ix3 / 2) ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] @@ -71,8 +71,9 @@ def test_jitrestrict(): tsd2 = restrict(ep, tsd) ix = nap.core._jitted_functions.jitrestrict(tsd.index, ep.start, ep.end) - tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) - pd.testing.assert_series_equal(tsd2, tsd3) + tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) def test_jitrestrict_with_count(): for i in range(100): @@ -81,13 +82,15 @@ def test_jitrestrict_with_count(): tsd2 = restrict(ep, tsd) ix, count = nap.core._jitted_functions.jitrestrict_with_count(tsd.index, ep.start, ep.end) tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) - pd.testing.assert_series_equal(tsd2, tsd3) + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) + bins = ep.values.ravel() ix = np.array(pd.cut(tsd.index, bins, labels=np.arange(len(bins) - 1, dtype=np.float64))) ix2 = np.array(pd.cut(tsd.index,bins,labels=np.arange(len(bins) - 1, dtype=np.float64),right=False,)) ix3 = np.vstack((ix, ix2)).T - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN + ix3[np.floor(ix3 / 2) * 2 != ix3] = np.nan ix3 = np.floor(ix3 / 2) ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] ix = ix3[:,0] @@ -146,8 +149,10 @@ def test_jitvalue_from(): tsd2.append(tsd.restrict(ep[j]).as_series().reindex(ix, method="nearest").fillna(0.0)) tsd2 = pd.concat(tsd2) + + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) - pd.testing.assert_series_equal(tsd2, tsd3) def test_jitcount(): for i in range(10): @@ -157,7 +162,7 @@ def test_jitcount(): starts = ep.start ends = ep.end bin_size = 1.0 - t, d = nap.core._jitted_functions.jitcount(time_array, starts, ends, bin_size) + t, d = nap.core._jitted_functions.jitcount(time_array, starts, ends, bin_size, np.int64) tsd3 = nap.Tsd(t=t, d=d, time_support = ep) tsd2 = [] @@ -166,15 +171,15 @@ def test_jitcount(): idx = np.digitize(ts.restrict(ep[j]).index, bins)-1 tmp = np.array([np.sum(idx==j) for j in range(len(bins)-1)]) tmp = nap.Tsd(t = bins[0:-1] + np.diff(bins)/2, d = tmp) - tmp = tmp.restrict(ep[j]) - - # pd.testing.assert_series_equal(tmp, tsd3.restrict(ep.loc[[j]])) + tmp = tmp.restrict(ep[j]) tsd2.append(tmp.as_series()) tsd2 = pd.concat(tsd2) - - pd.testing.assert_series_equal(tsd3.as_series(), tsd2) + + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) + def test_jitbin(): for i in range(10): @@ -210,8 +215,10 @@ def test_jitbin(): tsd2 = pd.concat(tsd2) # tsd2 = nap.Tsd(tsd2) tsd2 = tsd2.fillna(0.0) + + np.testing.assert_array_almost_equal(tsd2.values, 
tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) - pd.testing.assert_series_equal(tsd3, tsd2) def test_jitbin_array(): for i in range(10): @@ -248,7 +255,8 @@ def test_jitbin_array(): tsd2 = pd.concat(tsd2) # tsd2 = nap.TsdFrame(tsd2) - pd.testing.assert_frame_equal(tsd3, tsd2) + np.testing.assert_array_almost_equal(tsd3.values, tsd2.values) + np.testing.assert_array_almost_equal(tsd3.index.values, tsd2.index.values) def test_jitintersect(): for i in range(10): @@ -280,7 +288,6 @@ def test_jitintersect(): ep4 = nap.IntervalSet(start, end) - # pd.testing.assert_frame_equal(ep3, ep4) np.testing.assert_array_almost_equal(ep3, ep4) def test_jitunion(): @@ -313,8 +320,7 @@ def test_jitunion(): stop = df["time"][ix_stop] ep4 = nap.IntervalSet(start, stop) - - # pd.testing.assert_frame_equal(ep3, ep4) + np.testing.assert_array_almost_equal(ep3, ep4) def test_jitdiff(): @@ -353,8 +359,7 @@ def test_jitdiff(): idx = start != end ep4 = nap.IntervalSet(start[idx], end[idx]) - - # pd.testing.assert_frame_equal(ep3, ep4) + np.testing.assert_array_almost_equal(ep3, ep4) def test_jitunion_isets(): @@ -388,8 +393,7 @@ def test_jitunion_isets(): stop = df["time"][ix_stop] ep5 = nap.IntervalSet(start, stop) - - # pd.testing.assert_frame_equal(ep5, ep6) + np.testing.assert_array_almost_equal(ep5, ep6) @@ -413,7 +417,7 @@ def test_jitin_interval(): ) ) ix3 = np.vstack((ix, ix2)).T - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN + ix3[np.floor(ix3 / 2) * 2 != ix3] = np.nan ix3 = np.floor(ix3 / 2) ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] inep2 = ix3[:, 0] diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 96c30bc9..b140bf0e 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -1,4 +1,3 @@ -import os.path import warnings from contextlib import nullcontext as does_not_raise from pathlib import Path @@ -22,18 +21,18 @@ (np.arange(12), "not_an_array", pytest.raises(TypeError, match="Data should be array-like")) ] ) -def test_lazy_load_hdf5_is_array(time, data, expectation): - file_path = Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - h5_data = h5py.File(file_path, 'r')["data"] +def test_lazy_load_hdf5_is_array(time, data, expectation, tmp_path): + file_path = tmp_path / Path('data.h5') + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + with h5py.File(file_path, 'r') as h5_data: with expectation: - nap.Tsd(t=time, d=h5_data, load_array=False) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + nap.Tsd(t=time, d=h5_data['data'], load_array=False) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize( @@ -43,47 +42,47 @@ def test_lazy_load_hdf5_is_array(time, data, expectation): ] ) @pytest.mark.parametrize("convert_flag", [True, False]) -def test_lazy_load_hdf5_is_array(time, data, convert_flag): - file_path = Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - tsd = nap.Tsd(t=time, d=h5_data, load_array=convert_flag) - if convert_flag: - assert not isinstance(tsd.d, h5py.Dataset) - else: - assert isinstance(tsd.d, h5py.Dataset) - finally: - # delete file - if file_path.exists(): - file_path.unlink() +def test_lazy_load_hdf5_is_array(time, data, convert_flag, tmp_path): + file_path = tmp_path / Path('data.h5') + # try: + with 
h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + tsd = nap.Tsd(t=time, d=h5_data, load_array=convert_flag) + if convert_flag: + assert not isinstance(tsd.d, h5py.Dataset) + else: + assert isinstance(tsd.d, h5py.Dataset) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @pytest.mark.parametrize("cls", [nap.Tsd, nap.TsdFrame, nap.TsdTensor]) @pytest.mark.parametrize("func", [np.exp, lambda x: x*2]) -def test_lazy_load_hdf5_apply_func(time, data, func,cls): +def test_lazy_load_hdf5_apply_func(time, data, func,cls, tmp_path): """Apply a unary function to a lazy loaded array.""" - file_path = Path('data.h5') - try: - if cls is nap.TsdFrame: - data = data[:, None] - elif cls is nap.TsdTensor: - data = data[:, None, None] - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - res = func(cls(t=time, d=h5_data, load_array=False)) - assert isinstance(res, cls) - assert not isinstance(res.d, h5py.Dataset) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + file_path = tmp_path / Path('data.h5') + # try: + if cls is nap.TsdFrame: + data = data[:, None] + elif cls is nap.TsdTensor: + data = data[:, None, None] + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + res = func(cls(t=time, d=h5_data, load_array=False)) + assert isinstance(res, cls) + assert not isinstance(res.d, h5py.Dataset) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @@ -102,26 +101,26 @@ def test_lazy_load_hdf5_apply_func(time, data, func,cls): ("get", [2, 7]) ] ) -def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls): - file_path = Path('data.h5') - try: - if cls is nap.TsdFrame: - data = data[:, None] - elif cls is nap.TsdTensor: - data = data[:, None, None] - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = cls(t=time, d=h5_data, load_array=False) - func = getattr(tsd, method_name) - out = func(*args) - assert not isinstance(out.d, h5py.Dataset) - finally: - # delete file - if file_path.exists(): - file_path.unlink() +def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls, tmp_path): + file_path = tmp_path / Path('data.h5') + # try: + if cls is nap.TsdFrame: + data = data[:, None] + elif cls is nap.TsdTensor: + data = data[:, None, None] + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = cls(t=time, d=h5_data, load_array=False) + func = getattr(tsd, method_name) + out = func(*args) + assert not isinstance(out.d, h5py.Dataset) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @@ -134,21 +133,21 @@ def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls): ("to_tsgroup", [], nap.TsGroup) ] ) -def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, 
expected_out_type): - file_path = Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = nap.Tsd(t=time, d=h5_data, load_array=False) - func = getattr(tsd, method_name) - assert isinstance(func(*args), expected_out_type) - finally: - # delete file - if file_path.exists(): - file_path.unlink() +def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, expected_out_type, tmp_path): + file_path = tmp_path / Path('data.h5') + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = nap.Tsd(t=time, d=h5_data, load_array=False) + func = getattr(tsd, method_name) + assert isinstance(func(*args), expected_out_type) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @@ -158,40 +157,40 @@ def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, ("as_dataframe", [], pd.DataFrame), ] ) -def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, args, expected_out_type): - file_path = Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data[:, None]) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = nap.TsdFrame(t=time, d=h5_data, load_array=False) - func = getattr(tsd, method_name) - assert isinstance(func(*args), expected_out_type) - finally: - # delete file - if file_path.exists(): - file_path.unlink() - - -def test_lazy_load_hdf5_tsdframe_loc(): - file_path = Path('data.h5') +def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, args, expected_out_type, tmp_path): + file_path = tmp_path / Path('data.h5') + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data[:, None]) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = nap.TsdFrame(t=time, d=h5_data, load_array=False) + func = getattr(tsd, method_name) + assert isinstance(func(*args), expected_out_type) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() + + +def test_lazy_load_hdf5_tsdframe_loc(tmp_path): + file_path = tmp_path / Path('data.h5') data = np.arange(10).reshape(5, 2) - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = nap.TsdFrame(t=np.arange(data.shape[0]), d=h5_data, load_array=False).loc[1] - assert isinstance(tsd, nap.Tsd) - assert all(tsd.d == np.array([1, 3, 5, 7, 9])) - - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = nap.TsdFrame(t=np.arange(data.shape[0]), d=h5_data, load_array=False).loc[1] + assert isinstance(tsd, nap.Tsd) + assert all(tsd.d == np.array([1, 3, 5, 7, 9])) + + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize( "lazy", @@ -206,68 +205,81 @@ def test_lazy_load_nwb(lazy): except: nwb = nap.NWBFile("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", 
lazy_loading=lazy) - tsd = nwb["z"] - if lazy: - assert isinstance(tsd.d, h5py.Dataset) - else: - assert not isinstance(tsd.d, h5py.Dataset) - nwb.io.close() + assert isinstance(nwb["z"].d, h5py.Dataset) is lazy + nwb.close() -@pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))]) -def test_lazy_load_nwb_no_warnings(data): - file_path = Path('data.h5') - - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - time_series = mock_TimeSeries(name="TimeSeries", data=f["data"]) - nwbfile = mock_NWBFile() - nwbfile.add_acquisition(time_series) - nwb = nap.NWBFile(nwbfile) - - with warnings.catch_warnings(record=True) as w: - tsd = nwb["TimeSeries"] - tsd.count(0.1) - assert isinstance(tsd.d, h5py.Dataset) - - if len(w): - if not str(w[0].message).startswith("Converting 'd' to"): - raise RuntimeError - - finally: - if file_path.exists(): - file_path.unlink() - - -def test_tsgroup_no_warnings(): - n_units = 2 +@pytest.mark.parametrize( + "lazy", + [ + (True), + (False), + ] +) +def test_lazy_load_function(lazy): try: - for k in range(n_units): - file_path = Path(f'data_{k}.h5') - with h5py.File(file_path, 'w') as f: - f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20))) - with warnings.catch_warnings(record=True) as w: - - nwbfile = mock_NWBFile() - - for k in range(n_units): - file_path = Path(f'data_{k}.h5') - spike_times = h5py.File(file_path, "r")['spks'] - nwbfile.add_unit(spike_times=spike_times) - - nwb = nap.NWBFile(nwbfile) - tsgroup = nwb["units"] - tsgroup.count(0.1) - + nwb = nap.load_file("tests/nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) + except: + nwb = nap.load_file("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) + + assert isinstance(nwb["z"].d, h5py.Dataset) is lazy + nwb.close() + + +@pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))]) +def test_lazy_load_nwb_no_warnings(data, tmp_path): # tmp_path is a default fixture creating a temporary folder + file_path = tmp_path / Path('data.h5') + + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + time_series = mock_TimeSeries(name="TimeSeries", data=f["data"]) + nwbfile = mock_NWBFile() + nwbfile.add_acquisition(time_series) + nwb = nap.NWBFile(nwbfile) + + with warnings.catch_warnings(record=True) as w: + tsd = nwb["TimeSeries"] + tsd.count(0.1) + assert isinstance(tsd.d, h5py.Dataset) + if len(w): if not str(w[0].message).startswith("Converting 'd' to"): raise RuntimeError + # finally: + # if file_path.exists(): + # file_path.unlink() + + +def test_tsgroup_no_warnings(tmp_path): # default fixture + n_units = 2 + # try: + for k in range(n_units): + file_path = tmp_path / Path(f'data_{k}.h5') + with h5py.File(file_path, 'w') as f: + f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20))) + with warnings.catch_warnings(record=True) as w: + + nwbfile = mock_NWBFile() - finally: for k in range(n_units): - file_path = Path(f'data_{k}.h5') - if file_path.exists(): - file_path.unlink() + file_path = tmp_path / Path(f'data_{k}.h5') + spike_times = h5py.File(file_path, "r")['spks'] + nwbfile.add_unit(spike_times=spike_times) + + nwb = nap.NWBFile(nwbfile) + tsgroup = nwb["units"] + tsgroup.count(0.1) + + if len(w): + if not str(w[0].message).startswith("Converting 'd' to"): + raise RuntimeError + + + # finally: + # for k in range(n_units): + # file_path = Path(f'data_{k}.h5') + # if file_path.exists(): + # file_path.unlink() 
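+    # (No manual cleanup is needed here: each test writes into its own
+    # `tmp_path` directory, which pytest creates fresh and prunes on
+    # subsequent runs.)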
diff --git a/tests/test_misc.py b/tests/test_misc.py index 47024e4b..61c12b56 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2023-07-10 12:26:20 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 16:05:24 +# @Last Modified time: 2024-07-31 11:17:59 """Tests of IO misc functions""" @@ -11,24 +11,27 @@ import pandas as pd import pytest import warnings -import os +from pathlib import Path +import shutil # look for tests folder -path = os.getcwd() -if os.path.basename(path) == 'pynapple': - path = os.path.join(path, "tests") +path = Path(__file__).parent +if path.name == 'pynapple': + path = path / "tests" +path = path / "npzfilestest" + +# Recursively remove the folder: +shutil.rmtree(path, ignore_errors=True) +path.mkdir(exist_ok=True, parents=True) + +path2 = path.parent / "sub" +path2.mkdir(exist_ok=True, parents=True) -path = os.path.join(path, "npzfilestest") -if not os.path.isdir(path): - os.mkdir(path) -path2 = os.path.join(path, "sub") -if not os.path.isdir(path): - os.mkdir(path2) @pytest.mark.parametrize("path", [path]) def test_load_file(path): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - file_path = os.path.join(path, "tsd.npz") + file_path = path / "tsd.npz" tsd.save(file_path) tsd2 = nap.load_file(file_path) @@ -37,7 +40,7 @@ def test_load_file(path): np.testing.assert_array_equal(tsd.values, tsd2.values) np.testing.assert_array_equal(tsd.time_support.values, tsd2.time_support.values) - os.remove(file_path) + # file_path.unlink() @pytest.mark.parametrize("path", [path]) def test_load_file_filenotfound(path): @@ -48,13 +51,13 @@ def test_load_file_filenotfound(path): @pytest.mark.parametrize("path", [path]) def test_load_wrong_format(path): - file_path = os.path.join(path, "test.npy") + file_path = path / "test.npy" np.save(file_path, np.random.rand(10)) with pytest.raises(RuntimeError) as e: nap.load_file(file_path) assert str(e.value) == "File format not supported" - os.remove(file_path) + # file_path.unlink() @pytest.mark.parametrize("path", [path]) def test_load_folder(path): @@ -62,7 +65,7 @@ def test_load_folder(path): assert isinstance(folder, nap.io.Folder) def test_load_folder_foldernotfound(): - with pytest.raises(RuntimeError) as e: + with pytest.raises(FileNotFoundError) as e: nap.load_folder("MissingFolder") assert str(e.value) == "Folder MissingFolder does not exist" diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index 74eec96f..f9de9158 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -11,24 +11,22 @@ import pandas as pd import pytest import warnings -import os +from pathlib import Path +import shutil # look for tests folder -path = os.getcwd() -if os.path.basename(path) == 'pynapple': - path = os.path.join(path, "tests") - -path = os.path.join(path, "npzfilestest") -if not os.path.isdir(path): - os.mkdir(path) -path2 = os.path.join(path, "sub") -if not os.path.isdir(path): - os.mkdir(path2) - -# Cleaning -for root, dirs, files in os.walk(path): - for f in files: - os.remove(os.path.join(root, f)) +path = Path(__file__).parent +if path.name == 'pynapple': + path = path / "tests" +path = path / "npzfilestest" + +# Recursively remove the folder: +shutil.rmtree(path, ignore_errors=True) +path.mkdir(exist_ok=True, parents=True) + +path2 = path.parent / "sub" +path2.mkdir(exist_ok=True, parents=True) + # Populate the folder data = { @@ -43,12 +41,12 @@ "iset":nap.IntervalSet(start=np.array([0.0, 5.0]), end=np.array([1.0, 6.0])) } for k, d in 
data.items(): - d.save(os.path.join(path, k+".npz")) + d.save(path / (k+".npz")) @pytest.mark.parametrize("path", [path]) def test_init(path): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - file_path = os.path.join(path, "tsd.npz") + file_path = path / "tsd.npz" tsd.save(file_path) file = nap.NPZFile(file_path) assert isinstance(file, nap.NPZFile) @@ -58,18 +56,18 @@ def test_init(path): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsd', 'ts', 'tsdframe', 'tsgroup', 'iset']) def test_load(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, type(data[k])) @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsgroup']) def test_load_tsgroup(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, type(data[k])) assert tmp.keys() == data[k].keys() assert np.all(tmp._metadata == data[k]._metadata) assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys()) @@ -79,10 +77,10 @@ def test_load_tsgroup(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsd']) def test_load_tsd(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, type(data[k])) assert np.all(tmp.d == data[k].d) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -91,10 +89,10 @@ def test_load_tsd(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['ts']) def test_load_ts(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, type(data[k])) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -103,10 +101,10 @@ def test_load_ts(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsdframe']) def test_load_tsdframe(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, type(data[k])) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) assert np.all(tmp.columns == data[k].columns) @@ -116,17 +114,12 @@ def test_load_tsdframe(path, k): @pytest.mark.parametrize("path", [path]) def test_load_non_npz(path): - file_path = os.path.join(path, "random.npz") + file_path = path / "random.npz" tmp = np.random.rand(100) - np.savez(file_path, a = tmp) + np.savez(file_path, a=tmp) file = nap.NPZFile(file_path) assert file.type == "npz" a = file.load() assert isinstance(a, np.lib.npyio.NpzFile) np.testing.assert_array_equal(tmp, a['a']) - - - - - diff --git a/tests/test_numpy_compatibility.py b/tests/test_numpy_compatibility.py index aecb5677..7b1b3777 100644 --- a/tests/test_numpy_compatibility.py +++ b/tests/test_numpy_compatibility.py @@ -11,7 +11,7 @@ # tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6)) -tsd.d[tsd.values>0.9] = np.NaN +tsd.d[tsd.values>0.9] = np.nan 
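+# (`np.NaN` is replaced by `np.nan` throughout the test suite: the
+# capitalized alias was removed in NumPy 2.0.)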
@pytest.mark.parametrize( diff --git a/tests/test_nwb.py b/tests/test_nwb.py index 943726e9..a438aa04 100644 --- a/tests/test_nwb.py +++ b/tests/test_nwb.py @@ -86,6 +86,7 @@ def test_NWBFile(): assert nwb.name == "A2929-200711" assert isinstance(nwb.io, pynwb.NWBHDF5IO) + nwb.close() def test_NWBFile_missing_file(): @@ -95,7 +96,7 @@ def test_NWBFile_missing_file(): def test_NWBFile_wrong_input(): - with pytest.raises(RuntimeError): + with pytest.raises(TypeError): nap.NWBFile(1) def test_wrong_key(): diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py new file mode 100644 index 00000000..fc76103c --- /dev/null +++ b/tests/test_power_spectral_density.py @@ -0,0 +1,321 @@ +import re +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pandas as pd +import pytest +from scipy import signal + +import pynapple as nap + + +############################################################ +# Test for power_spectral_density +############################################################ + +def get_sorted_fft(data,fs): + fft = np.fft.fft(data, axis=0) + fft_freq = np.fft.fftfreq(len(data), 1 / fs) + order = np.argsort(fft_freq) + if fft.ndim==1: + fft = fft[:,np.newaxis] + return fft_freq[order], fft[order] + +def test_compute_power_spectral_density(): + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a[a>=0]) + np.testing.assert_array_almost_equal(r.values, b[a>=0]) + + r = nap.compute_power_spectral_density(sig, norm=True) + np.testing.assert_array_almost_equal(r.values, b[a>=0]/len(sig)) + + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_power_spectral_density(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape == (500, 4) + + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a[a>=0]) + np.testing.assert_array_almost_equal(r.values, b[a>=0]) + + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_power_spectral_density(sig, full_range=True) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1000, 4) + + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a) + np.testing.assert_array_almost_equal(r.values, b) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig, ep=sig.time_support) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig, fs=1000) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + +@pytest.mark.parametrize( + "sig, fs, ep, full_range, norm, expectation", + [ + ( + nap.Tsd( + d=np.random.random(1000), + t=np.linspace(0, 1, 1000), + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ), + 1000, + None, + False, + False, + pytest.raises( + ValueError, + match=re.escape( + "Given epoch (or signal time_support) must have length 1" + ), + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + "not_ep", + False, + False, + pytest.raises( + TypeError, + match="ep param must be a pynapple IntervalSet object, or None", + ), + ), + ( + nap.Tsd(d=np.random.random(1000), 
t=np.linspace(0, 1, 1000)), + "a", + None, + False, + False, + pytest.raises( + TypeError, + match="fs must be of type float or int", + ), + ), + ( + "not_a_tsd", + 1000, + None, + False, + False, + pytest.raises( + TypeError, + match="sig must be either a Tsd or a TsdFrame object.", + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + None, + "a", + False, + pytest.raises( + TypeError, + match="full_range must be of type bool or None", + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + None, + False, + "a", + pytest.raises( + TypeError, + match="norm must be of type bool", + ), + ), + ], +) +def test_compute_power_spectral_density_raise_errors( + sig, fs, ep, full_range, norm, expectation +): + with expectation: + psd = nap.compute_power_spectral_density(sig, fs, ep, full_range, norm) + + +############################################################ +# Test for mean_power_spectral_density +############################################################ + + +def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): + t = np.arange(0, duration, 1 / fs) + d = np.cos(2 * np.pi * f * t) + sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) + tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T + # tmp = tmp[0:-1] + tmp = tmp*signal.windows.hamming(tmp.shape[0])[:,np.newaxis] + out = np.sum(np.fft.fft(tmp, axis=0), 1) + freq = np.fft.fftfreq(out.shape[0], 1 / fs) + order = np.argsort(freq) + out = out[order] + freq = freq[order] + return (sig, out, freq) + + +def test_compute_mean_power_spectral_density(): + sig, out, freq = get_signal_and_output() + psd = nap.compute_mean_power_spectral_density(sig, 10) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) + + # Full range + psd = nap.compute_mean_power_spectral_density(sig, 10, full_range=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out) + np.testing.assert_array_almost_equal(psd.index.values, freq) + + # Norm + psd = nap.compute_mean_power_spectral_density(sig, 10, norm=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(10000.0*10.0)) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) + + + # TsdFrame + sig2 = nap.TsdFrame( + t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support + ) + psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) + np.testing.assert_array_almost_equal(psd.index.values, freq) + + # TsdFrame + sig2 = nap.TsdFrame( + t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support + ) + psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True, fs=1000) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) + 
np.testing.assert_array_almost_equal(psd.index.values, freq) + + +@pytest.mark.parametrize( + "sig, out, freq, interval_size, fs, ep, full_range, norm, time_units, expectation", + [ + (*get_signal_and_output(), 10, None, None, False, False, "s", does_not_raise()), + ( + "a", *get_signal_and_output()[1:], + 10, + None, + None, + False, + False, + "s", + pytest.raises(TypeError, match="sig must be either a Tsd or a TsdFrame object."), + ), + ( + *get_signal_and_output(), + 10, + "a", + None, + False, + False, + "s", + pytest.raises(TypeError, match="fs must be of type float or int"), + ), + ( + *get_signal_and_output(), + 10, + None, + "a", + False, + False, + "s", + pytest.raises( + TypeError, + match="ep param must be a pynapple IntervalSet object, or None", + ), + ), + ( + *get_signal_and_output(), + 10, + None, + None, + "a", + False, + "s", + pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + ( + *get_signal_and_output(), + 10, + None, # FS + None, # Ep + "a", # full_range + False, # Norm + "s", # Time units + pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + ( + *get_signal_and_output(), + 10, + None, # FS + None, # Ep + False, # full_range + "a", # Norm + "s", # Time units + pytest.raises(TypeError, match="norm must be of type bool"), + ), + (*get_signal_and_output(), 10 * 1e3, None, None, False, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10 * 1e6, None, None, False, False, "us", does_not_raise()), + ( + *get_signal_and_output(), + 200, + None, + None, + False, + False, + "s", + pytest.raises( + RuntimeError, + match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size", + ), + ), + ( + *get_signal_and_output(), + 10, + None, + nap.IntervalSet([0, 200], [100, 300]), + False, + False, + "s", + pytest.raises( + RuntimeError, + match="One interval doesn't have any signal associated. 
Check the parameter ep or the time support if no epoch is passed.", + ), + ), + ], +) +def test_compute_mean_power_spectral_density_raise_errors( + sig, out, freq, interval_size, fs, ep, full_range, norm, time_units, expectation +): + with expectation: + psd = nap.compute_mean_power_spectral_density( + sig, interval_size, fs, ep, full_range, norm, time_units + ) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py new file mode 100644 index 00000000..c9c1495b --- /dev/null +++ b/tests/test_signal_processing.py @@ -0,0 +1,540 @@ +"""Tests of `signal_processing` for pynapple""" + +import re +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest + +import pynapple as nap + + +def test_generate_morlet_filterbank(): + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + # Check that peak freq matched expectation + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + # Check that peak freq matched expectation + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + # Check that peak freq matched expectation + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + # Check that peak freq matched expectation + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=3.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + # Check that peak freq matched expectation + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=3.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + # Check that peak freq matched expectation + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + gaussian_atol = 1e-4 + # Checking that the power spectra of the wavelets resemble correct Gaussians + fs = 2000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 1.0 + window_length = 1.0 + fz = power.index + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + 
+            power.iloc[:, i] / np.max(power.iloc[:, i]),
+            morlet_ft / np.max(morlet_ft),
+            atol=gaussian_atol,
+        ).all()
+
+    fs = 100
+    freqs = np.linspace(1, 10, 10)
+    fb = nap.generate_morlet_filterbank(
+        freqs, fs, gaussian_width=1.0, window_length=1.0, precision=24
+    )
+    power = np.abs(nap.compute_power_spectral_density(fb))
+    for i, f in enumerate(freqs):
+        gaussian_width = 1.0
+        window_length = 1.0
+        fz = power.index
+        factor = np.pi**0.25 * gaussian_width**0.25
+        morlet_ft = factor * np.exp(
+            -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2
+        )
+        assert np.isclose(
+            power.iloc[:, i] / np.max(power.iloc[:, i]),
+            morlet_ft / np.max(morlet_ft),
+            atol=gaussian_atol,
+        ).all()
+
+    fs = 100
+    freqs = np.linspace(1, 10, 10)
+    fb = nap.generate_morlet_filterbank(
+        freqs, fs, gaussian_width=4.0, window_length=1.0, precision=24
+    )
+    power = np.abs(nap.compute_power_spectral_density(fb))
+    for i, f in enumerate(freqs):
+        gaussian_width = 4.0
+        window_length = 1.0
+        fz = power.index
+        factor = np.pi**0.25 * gaussian_width**0.25
+        morlet_ft = factor * np.exp(
+            -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2
+        )
+        assert np.isclose(
+            power.iloc[:, i] / np.max(power.iloc[:, i]),
+            morlet_ft / np.max(morlet_ft),
+            atol=gaussian_atol,
+        ).all()
+
+    fs = 100
+    freqs = np.linspace(1, 10, 10)
+    fb = nap.generate_morlet_filterbank(
+        freqs, fs, gaussian_width=4.0, window_length=3.0, precision=24
+    )
+    power = np.abs(nap.compute_power_spectral_density(fb))
+    for i, f in enumerate(freqs):
+        gaussian_width = 4.0
+        window_length = 3.0
+        fz = power.index
+        factor = np.pi**0.25 * gaussian_width**0.25
+        morlet_ft = factor * np.exp(
+            -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2
+        )
+        assert np.isclose(
+            power.iloc[:, i] / np.max(power.iloc[:, i]),
+            morlet_ft / np.max(morlet_ft),
+            atol=gaussian_atol,
+        ).all()
+
+    fs = 1000
+    freqs = np.linspace(1, 10, 10)
+    fb = nap.generate_morlet_filterbank(
+        freqs, fs, gaussian_width=3.5, window_length=1.25, precision=24
+    )
+    power = np.abs(nap.compute_power_spectral_density(fb))
+    for i, f in enumerate(freqs):
+        gaussian_width = 3.5
+        window_length = 1.25
+        fz = power.index
+        factor = np.pi**0.25 * gaussian_width**0.25
+        morlet_ft = factor * np.exp(
+            -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2
+        )
+        assert np.isclose(
+            power.iloc[:, i] / np.max(power.iloc[:, i]),
+            morlet_ft / np.max(morlet_ft),
+            atol=gaussian_atol,
+        ).all()
+
+
+@pytest.mark.parametrize(
+    "freqs, fs, gaussian_width, window_length, precision, expectation",
+    [
+        (
+            np.linspace(0, 100, 11),
+            1000,
+            1.5,
+            1.0,
+            16,
+            pytest.raises(
+                ValueError, match="All frequencies in freqs must be strictly positive"
+            ),
+        ),
+        (
+            "a",
+            1000,
+            1.5,
+            1.0,
+            16,
+            pytest.raises(TypeError, match="`freqs` must be a ndarray"),
+        ),
+        (
+            np.array([]),
+            1000,
+            1.5,
+            1.0,
+            16,
+            pytest.raises(ValueError, match="Given list of freqs cannot be empty."),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            "a",
+            1.5,
+            1.0,
+            16,
+            pytest.raises(TypeError, match="`fs` must be of type float or int ndarray"),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            1000,
+            -1.5,
+            1.0,
+            16,
+            pytest.raises(
+                ValueError, match="gaussian_width must be a positive number."
+            ),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            1000,
+            "a",
+            1.0,
+            16,
+            pytest.raises(
+                TypeError, match="gaussian_width must be a float or int instance."
+            ),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            1000,
+            1.5,
+            -1.0,
+            16,
+            pytest.raises(ValueError, match="window_length must be a positive number."),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            1000,
+            1.5,
+            "a",
+            16,
+            pytest.raises(
+                TypeError, match="window_length must be a float or int instance."
+            ),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            1000,
+            1.5,
+            1.0,
+            -16,
+            pytest.raises(ValueError, match="precision must be a positive number."),
+        ),
+        (
+            np.linspace(1, 10, 1),
+            1000,
+            1.5,
+            1.0,
+            "a",
+            pytest.raises(
+                TypeError, match="precision must be a float or int instance."
+            ),
+        ),
+    ],
+)
+def test_generate_morlet_filterbank_raise_errors(
+    freqs, fs, gaussian_width, window_length, precision, expectation
+):
+    with expectation:
+        _ = nap.generate_morlet_filterbank(
+            freqs, fs, gaussian_width, window_length, precision
+        )
+
+
+############################################################
+# Test for compute_wavelet_transform
+############################################################
+
+
+def get_1d_signal(fs=1000, fc=50):
+    t = np.arange(0, 2, 1 / fs)
+    d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0])
+    return nap.Tsd(t, d, time_support=nap.IntervalSet(0, 2))
+
+
+def get_2d_signal(fs=1000, fc=50):
+    t = np.arange(0, 2, 1 / fs)
+    d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0])
+    return nap.TsdFrame(t, d[:, np.newaxis], time_support=nap.IntervalSet(0, 2))
+
+
+def get_output_1d(sig, wavelets):
+    T = sig.shape[0]
+    M, N = wavelets.shape
+    out = []
+    for n in range(N):
+        out.append(np.convolve(sig, wavelets[:, n], mode="full"))
+    out = np.array(out).T
+    cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2))
+    return out[cut[0] : cut[1]]
+
+
+def get_output_2d(sig, wavelets):
+    T, K = sig.shape
+    M, N = wavelets.shape
+    out = []
+    for k in range(K):
+        tmp = []
+        for n in range(N):
+            tmp.append(np.convolve(sig[:, k], wavelets[:, n], mode="full"))
+        out.append(np.array(tmp))
+    out = np.array(out).T
+    cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2))
+    return out[cut[0] : cut[1]]
+
+
+@pytest.mark.parametrize(
+    "func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt",
+    [
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), None, 1.5, 1.0, 16, None, 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 3.0, 1.0, 16, None, 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 2.0, 16, None, 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 20, None, 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, "l1", 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, "l2", 50, 1000),
+        (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 20, 1000),
+        (get_2d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 20, 1000),
+    ],
+)
+def test_compute_wavelet_transform(
+    func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt
+):
+    sig = func(1000, fc)
+    wavelets = nap.generate_morlet_filterbank(
+        freqs, 1000, gaussian_width, window_length, precision
+    )
+    if sig.ndim == 1:
+        output = get_output_1d(sig.d, wavelets.values)
+    if sig.ndim == 2:
+        output = get_output_2d(sig.d, wavelets.values)
+
+    if norm == "l1":
+        output = output / (1000 / freqs)
+    if norm == "l2":
+        output = output / (1000 / np.sqrt(freqs))
+
+    mwt = nap.compute_wavelet_transform(
+        sig,
+        freqs,
+        fs=fs,
+        gaussian_width=gaussian_width,
+        window_length=window_length,
+        precision=precision,
+        norm=norm,
+    )
+
+    np.testing.assert_array_almost_equal(output, mwt.values)
+    assert freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] == fc
+    assert (
+        np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0]
+        == maxt
+    )
+    np.testing.assert_array_almost_equal(
+        mwt.time_support.values, sig.time_support.values
+    )
+
+
+@pytest.mark.parametrize(
+    "sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation",
+    [
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            1.5,
+            1,
+            16,
+            None,
+            does_not_raise(),
+        ),
+        (
+            "a",
+            np.linspace(1, 10, 2),
+            1000,
+            1.5,
+            1,
+            16,
+            None,
+            pytest.raises(
+                TypeError,
+                match=re.escape(
+                    "`sig` must be instance of Tsd, TsdFrame, or TsdTensor"
+                ),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            "a",
+            1.5,
+            1,
+            16,
+            None,
+            pytest.raises(
+                TypeError,
+                match=re.escape("`fs` must be of type float or int or None"),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            -1.5,
+            1,
+            16,
+            None,
+            pytest.raises(
+                ValueError,
+                match=re.escape("gaussian_width must be a positive number."),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            "a",
+            1,
+            16,
+            None,
+            pytest.raises(
+                TypeError,
+                match=re.escape("gaussian_width must be a float or int instance."),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            1.5,
+            -1,
+            16,
+            None,
+            pytest.raises(
+                ValueError,
+                match=re.escape("window_length must be a positive number."),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            1.5,
+            "a",
+            16,
+            None,
+            pytest.raises(
+                TypeError,
+                match=re.escape("window_length must be a float or int instance."),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            1.5,
+            1,
+            16,
+            "a",
+            pytest.raises(
+                ValueError,
+                match=re.escape("norm parameter must be 'l1', 'l2', or None."),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            "a",
+            1000,
+            1.5,
+            1,
+            16,
+            None,
+            pytest.raises(
+                TypeError,
+                match=re.escape("`freqs` must be a ndarray"),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.array([]),
+            1000,
+            1.5,
+            1,
+            16,
+            None,
+            pytest.raises(
+                ValueError,
+                match=re.escape("Given list of freqs cannot be empty."),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.array([-1]),
+            1000,
+            1.5,
+            1,
+            16,
+            None,
+            pytest.raises(
+                ValueError,
+                match=re.escape("All frequencies in freqs must be strictly positive"),
+            ),
+        ),
+        (
+            get_1d_signal(),
+            np.linspace(1, 10, 2),
+            1000,
+            1.5,
+            1,
+            16,
+            1,
+            pytest.raises(
+                ValueError,
+                match=re.escape("norm parameter must be 'l1', 'l2', or None."),
+            ),
+        ),
+    ],
+)
+def test_compute_wavelet_transform_raise_errors(
+    sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation
+):
+    with expectation:
+        _ = nap.compute_wavelet_transform(
+            sig, freqs, fs, gaussian_width, window_length, precision, norm
+        )
diff --git a/tests/test_time_series.py b/tests/test_time_series.py
index 0d48c727..7e098fe4 100755
--- a/tests/test_time_series.py
+++ b/tests/test_time_series.py
@@ -1,10 +1,13 @@
 """Tests of time series for `pynapple` package."""
 
-import pynapple as nap
+import pickle
+
 import numpy as np
 import pandas as pd
 import pytest
+from pathlib import Path
+from contextlib import nullcontext as does_not_raise
+
+import pynapple as nap
 
 # tsd1 = nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s")
 # tsd2 = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), columns = ['a', 'b', 'c'])
@@ -385,7 +388,7 @@ def test_dropna(self, tsd):
         np.testing.assert_array_equal(tsd.values, new_tsd.values)
 
         tmp = np.random.rand(*tsd.shape)
-        tmp[tmp>0.9] = np.NaN
+        tmp[tmp>0.9] = np.nan
         tsd = tsd.__class__(t=tsd.t, d=tmp)
         new_tsd = tsd.dropna()
@@ -403,12 +406,36 @@ def test_dropna(self, tsd):
         np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
         np.testing.assert_array_equal(new_tsd.time_support, tsd.time_support)
 
-        tsd = tsd.__class__(t=tsd.t, d=np.ones(tsd.shape)*np.NaN)
+        tsd = tsd.__class__(t=tsd.t, d=np.ones(tsd.shape)*np.nan)
         new_tsd = tsd.dropna()
         assert len(new_tsd) == 0
         assert len(new_tsd.time_support) == 0
 
-    def test_convolve(self, tsd):
+    def test_convolve_raise_errors(self, tsd):
+        if not isinstance(tsd, nap.Ts):
+
+            with pytest.raises(IOError) as e_info:
+                tsd.convolve([1,2,3])
+            assert str(e_info.value) == "Input should be a numpy array (or jax array if pynajax is installed)."
+
+            with pytest.raises(IOError) as e_info:
+                tsd.convolve(np.array([]))
+            assert str(e_info.value) == "Input array is length 0"
+
+            with pytest.raises(IOError) as e_info:
+                tsd.convolve(np.ones(3), trim='a')
+            assert str(e_info.value) == "Unknow argument. trim should be 'both', 'left' or 'right'."
+
+            with pytest.raises(IOError) as e_info:
+                tsd.convolve(np.ones((2,3,4)))
+            assert str(e_info.value) == "Array should be 1 or 2 dimension."
+
+            with pytest.raises(IOError) as e_info:
+                tsd.convolve(np.ones(3), ep=[1,2,3,4])
+            assert str(e_info.value) == "ep should be an object of type IntervalSet"
+
+
+    def test_convolve_1d_kernel(self, tsd):
         array = np.random.randn(10)
         if not isinstance(tsd, nap.Ts):
             tsd2 = tsd.convolve(array)
@@ -421,14 +448,6 @@ def test_convolve(self, tsd):
                 tsd2.values.reshape(tsd2.shape[0], -1)
             )
 
-            with pytest.raises(AssertionError) as e_info:
-                tsd.convolve([1,2,3])
-            assert str(e_info.value) == "Input should be a numpy array (or jax array if pynajax is installed)."
-
-            with pytest.raises(AssertionError) as e_info:
-                tsd.convolve(np.random.rand(2,3))
-            assert str(e_info.value) == "Input should be a one dimensional array."
-
             ep = nap.IntervalSet(start=[0, 60], end=[40,100])
             tsd3 = tsd.convolve(array, ep)
@@ -456,23 +475,55 @@ def test_convolve(self, tsd):
                 tsd2.values.reshape(tsd2.shape[0], -1)
             )
 
-            with pytest.raises(AssertionError) as e_info:
-                tsd.convolve(array, trim='a')
-            assert str(e_info.value) == "Unknow argument. trim should be 'both', 'left' or 'right'."
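+
+    # Reference check for the 2D-kernel case below: every (signal column,
+    # kernel column) pair is convolved with np.convolve in 'full' mode, and
+    # the [4:-5] slice trims the result back to the input length for the
+    # 10-sample kernels used in this test.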
+    def test_convolve_2d_kernel(self, tsd):
+        array = np.random.randn(10, 3)
+        if not isinstance(tsd, nap.Ts):
+            # no epochs
+            tsd2 = tsd.convolve(array)
+            tmp = tsd.values.reshape(tsd.shape[0], -1)
+
+            output = []
+
+            for i in range(tmp.shape[1]):
+                for j in range(array.shape[1]):
+                    output.append(
+                        np.convolve(tmp[:,i], array[:,j], mode='full')[4:-5]
+                    )
+
+            output = np.array(output).T
+            np.testing.assert_array_almost_equal(output,tsd2.values.reshape(tsd.shape[0], -1))
+
+            # epochs
+            ep = nap.IntervalSet(start=[0, 60], end=[40,100])
+            tsd2 = tsd.convolve(array, ep)
+
+            for k in range(len(ep)):
+                tmp = tsd.restrict(ep[k])
+                tmp2 = tmp.values.reshape(tmp.shape[0], -1)
+                output = []
+                for i in range(tmp2.shape[1]):
+                    for j in range(array.shape[1]):
+                        output.append(
+                            np.convolve(tmp2[:,i], array[:,j], mode='full')[4:-5]
+                        )
+                output = np.array(output).T
+                np.testing.assert_array_almost_equal(
+                    output,tsd2.restrict(ep[k]).values.reshape(tmp.shape[0], -1)
+                )
 
     def test_smooth(self, tsd):
         if not isinstance(tsd, nap.Ts):
             from scipy import signal
-            tsd2 = tsd.smooth(1)
+            tsd2 = tsd.smooth(1, size_factor=10)
             tmp = tsd.values.reshape(tsd.shape[0], -1)
             tmp2 = np.zeros_like(tmp)
             std = int(tsd.rate * 1)
-            M = std*100
+            M = std*11
             window = signal.windows.gaussian(M, std=std)
             window = window / window.sum()
             for i in range(tmp.shape[-1]):
-                tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1]
+                tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2-1]
 
             np.testing.assert_array_almost_equal(
                 tmp2, tsd2.values.reshape(tsd2.shape[0], -1)
@@ -490,10 +541,10 @@ def test_smooth(self, tsd):
             tmp = tsd.values.reshape(tsd.shape[0], -1)
             tmp2 = np.zeros_like(tmp)
             std = int(tsd.rate * 1)
-            M = std*200
+            M = std*201
             window = signal.windows.gaussian(M, std=std)
             for i in range(tmp.shape[-1]):
-                tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1]
+                tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2-1]
 
             np.testing.assert_array_almost_equal(
                 tmp2, tsd2.values.reshape(tsd2.shape[0], -1)
@@ -503,10 +554,10 @@ def test_smooth(self, tsd):
             tmp = tsd.values.reshape(tsd.shape[0], -1)
             tmp2 = np.zeros_like(tmp)
             std = int(tsd.rate * 1)
-            M = int(tsd.rate * 10)
+            M = int(tsd.rate * 11)
             window = signal.windows.gaussian(M, std=std)
             for i in range(tmp.shape[-1]):
-                tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1]
+                tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2-1]
 
             np.testing.assert_array_almost_equal(
                 tmp2, tsd2.values.reshape(tsd2.shape[0], -1)
@@ -514,23 +565,23 @@ def test_smooth(self, tsd):
 
     def test_smooth_raise_error(self, tsd):
         if not isinstance(tsd, nap.Ts):
-            with pytest.raises(AssertionError) as e_info:
+            with pytest.raises(IOError) as e_info:
                 tsd.smooth('a')
             assert str(e_info.value) == "std should be type int or float"
 
-            with pytest.raises(AssertionError) as e_info:
+            with pytest.raises(IOError) as e_info:
                 tsd.smooth(1, size_factor='b')
             assert str(e_info.value) == "size_factor should be of type int"
 
-            with pytest.raises(AssertionError) as e_info:
+            with pytest.raises(IOError) as e_info:
                 tsd.smooth(1, norm=1)
             assert str(e_info.value) == "norm should be of type boolean"
 
-            with pytest.raises(AssertionError) as e_info:
+            with pytest.raises(IOError) as e_info:
                 tsd.smooth(1, time_units = 0)
             assert str(e_info.value) == "time_units should be of type str"
 
-            with pytest.raises(AssertionError) as e_info:
+            with pytest.raises(IOError) as e_info:
                 tsd.smooth(1, windowsize='a')
             assert str(e_info.value) == "windowsize should be type int or float"
@@ -557,16 +608,6 @@ def test__getitems__(self, tsd):
             a.time_support, tsd.time_support
         )
 
-    # def test_loc(self, tsd):
-    #     a = tsd.loc[0:10] # should be 11 elements similar to pandas Series
-    #     b = nap.Tsd(t=tsd.index[0:11], d=tsd.values[0:11])
-    #     assert isinstance(a, nap.Tsd)
-    #     np.testing.assert_array_almost_equal(a.index, b.index)
-    #     np.testing.assert_array_almost_equal(a.values, b.values)
-    #     pd.testing.assert_frame_equal(
-    #         a.time_support, b.time_support
-    #     )
-
     def test_count(self, tsd):
         count = tsd.count(1)
         assert len(count) == 99
@@ -576,6 +617,11 @@ def test_count(self, tsd):
         assert len(count) == 99
         np.testing.assert_array_almost_equal(count.index, np.arange(0.5, 99, 1))
 
+        count = tsd.count(bin_size=1, dtype=np.int16)
+        assert len(count) == 99
+        assert count.dtype == np.dtype(np.int16)
+
+    def test_count_time_units(self, tsd):
         for b, tu in zip([1, 1e3, 1e6],['s', 'ms', 'us']):
             count = tsd.count(b, time_units = tu)
@@ -697,28 +743,23 @@ def test_to_tsgroup(self, tsd):
             np.testing.assert_array_almost_equal(tsgroup[i].index, t[i])
 
     def test_save_npz(self, tsd):
-        import os
-
-        with pytest.raises(RuntimeError) as e:
+        with pytest.raises(TypeError) as e:
             tsd.save(dict)
-        assert str(e.value) == "Invalid type; please provide filename as string"
 
         with pytest.raises(RuntimeError) as e:
             tsd.save('./')
-        assert str(e.value) == "Invalid filename input. {} is directory.".format("./")
+        assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve())
 
         fake_path = './fake/path'
         with pytest.raises(RuntimeError) as e:
             tsd.save(fake_path+'/file.npz')
-        assert str(e.value) == "Path {} does not exist.".format(fake_path)
+        assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve())
 
         tsd.save("tsd.npz")
-        os.listdir('.')
-        assert "tsd.npz" in os.listdir(".")
+        assert "tsd.npz" in [f.name for f in Path('.').iterdir()]
 
         tsd.save("tsd2")
-        os.listdir('.')
-        assert "tsd2.npz" in os.listdir(".")
+        assert "tsd2.npz" in [f.name for f in Path('.').iterdir()]
 
         file = np.load("tsd.npz")
@@ -733,8 +774,8 @@ def test_save_npz(self, tsd):
         np.testing.assert_array_almost_equal(file['start'], tsd.time_support.start)
         np.testing.assert_array_almost_equal(file['end'], tsd.time_support.end)
 
-        os.remove("tsd.npz")
-        os.remove("tsd2.npz")
+        # Path("tsd.npz").unlink()
+        # Path("tsd2.npz").unlink()
 
     def test_interpolate(self, tsd):
@@ -757,10 +798,22 @@ def test_interpolate(self, tsd):
         tsd2 = tsd.interpolate(ts)
         np.testing.assert_array_almost_equal(tsd2.values, y)
 
-        with pytest.raises(AssertionError) as e:
+        with pytest.raises(IOError) as e:
             tsd.interpolate([0, 1, 2])
         assert str(e.value) == "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor"
 
+        with pytest.raises(IOError) as e:
+            tsd.interpolate(ts, left='a')
+        assert str(e.value) == "Argument left should be of type float or int"
+
+        with pytest.raises(IOError) as e:
+            tsd.interpolate(ts, right='a')
+        assert str(e.value) == "Argument right should be of type float or int"
+
+        with pytest.raises(IOError) as e:
+            tsd.interpolate(ts, ep=[1,2,3,4])
+        assert str(e.value) == "ep should be an object of type IntervalSet"
+
         # Right left
         ep = nap.IntervalSet(start=0, end=5)
         tsd = nap.Tsd(t=np.arange(1,4), d=np.arange(3), time_support=ep)
@@ -926,28 +979,23 @@ def test_bin_average_with_ep(self, tsdframe):
         np.testing.assert_array_almost_equal(meantsd.values, tmp.loc[np.arange(1,5)].values)
 
     def test_save_npz(self, tsdframe):
-        import os
-
-        with pytest.raises(RuntimeError) as e:
+        with pytest.raises(TypeError) as e:
             tsdframe.save(dict)
-        assert str(e.value) == "Invalid type; please provide filename as string"
 
         with pytest.raises(RuntimeError) as e:
             tsdframe.save('./')
-        assert str(e.value) == "Invalid filename input. {} is directory.".format("./")
+        assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve())
 
         fake_path = './fake/path'
         with pytest.raises(RuntimeError) as e:
             tsdframe.save(fake_path+'/file.npz')
-        assert str(e.value) == "Path {} does not exist.".format(fake_path)
+        assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve())
 
         tsdframe.save("tsdframe.npz")
-        os.listdir('.')
-        assert "tsdframe.npz" in os.listdir(".")
+        assert "tsdframe.npz" in [f.name for f in Path('.').iterdir()]
 
         tsdframe.save("tsdframe2")
-        os.listdir('.')
-        assert "tsdframe2.npz" in os.listdir(".")
+        assert "tsdframe2.npz" in [f.name for f in Path('.').iterdir()]
 
         file = np.load("tsdframe.npz")
@@ -964,8 +1012,8 @@ def test_save_npz(self, tsdframe):
         np.testing.assert_array_almost_equal(file['end'], tsdframe.time_support.end)
         np.testing.assert_array_almost_equal(file['columns'], tsdframe.columns)
 
-        os.remove("tsdframe.npz")
-        os.remove("tsdframe2.npz")
+        # Path("tsdframe.npz").unlink()
+        # Path("tsdframe2.npz").unlink()
 
     def test_interpolate(self, tsdframe):
@@ -989,7 +1037,7 @@ def test_interpolate(self, tsdframe):
         tsdframe2 = tsdframe.interpolate(ts)
         np.testing.assert_array_almost_equal(tsdframe2.values, data_stack)
 
-        with pytest.raises(AssertionError) as e:
+        with pytest.raises(IOError) as e:
             tsdframe.interpolate([0, 1, 2])
         assert str(e.value) == "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor"
@@ -1021,6 +1069,14 @@ def test_interpolate_with_ep(self, tsdframe):
         tsdframe2 = tsdframe.interpolate(ts, ep)
         assert len(tsdframe2) == 0
 
+    def test_convolve_keep_columns(self, tsdframe):
+        array = np.random.randn(10)
+        tsdframe = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), time_units="s", columns=['a', 'b', 'c'])
+        tsd2 = tsdframe.convolve(array)
+
+        assert isinstance(tsd2, nap.TsdFrame)
+        np.testing.assert_array_equal(tsd2.columns, tsdframe.columns)
+
 
 ####################################################
 # Test for ts
 ####################################################
@@ -1041,28 +1097,23 @@ def test_str_(self, ts):
         assert isinstance(ts.__str__(), str)
 
     def test_save_npz(self, ts):
-        import os
-
-        with pytest.raises(RuntimeError) as e:
+        with pytest.raises(TypeError) as e:
             ts.save(dict)
-        assert str(e.value) == "Invalid type; please provide filename as string"
 
         with pytest.raises(RuntimeError) as e:
             ts.save('./')
-        assert str(e.value) == "Invalid filename input. {} is directory.".format("./")
+        assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve())
 
         fake_path = './fake/path'
         with pytest.raises(RuntimeError) as e:
             ts.save(fake_path+'/file.npz')
-        assert str(e.value) == "Path {} does not exist.".format(fake_path)
+        assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve())
 
         ts.save("ts.npz")
-        os.listdir('.')
-        assert "ts.npz" in os.listdir(".")
+        assert "ts.npz" in [f.name for f in Path('.').iterdir()]
 
         ts.save("ts2")
-        os.listdir('.')
-        assert "ts2.npz" in os.listdir(".")
+        assert "ts2.npz" in [f.name for f in Path('.').iterdir()]
 
         file = np.load("ts.npz")
@@ -1075,8 +1126,8 @@ def test_save_npz(self, ts):
         np.testing.assert_array_almost_equal(file['start'], ts.time_support.start)
         np.testing.assert_array_almost_equal(file['end'], ts.time_support.end)
 
-        os.remove("ts.npz")
-        os.remove("ts2.npz")
+        # Path("ts.npz").unlink()
+        # Path("ts2.npz").unlink()
 
     def test_fillna(self, ts):
         with pytest.raises(AssertionError):
@@ -1111,6 +1162,11 @@ def test_count(self, ts):
         assert len(count) == 99
         np.testing.assert_array_almost_equal(count.index, np.arange(0.5, 99, 1))
 
+        count = ts.count(bin_size=1, dtype=np.int16)
+        assert len(count) == 99
+        assert count.dtype == np.dtype(np.int16)
+
+    def test_count_time_units(self, ts):
         for b, tu in zip([1, 1e3, 1e6],['s', 'ms', 'us']):
             count = ts.count(b, time_units = tu)
@@ -1145,7 +1201,6 @@ def test_count_with_ep_only(self, ts):
         assert len(count) == 1
         np.testing.assert_array_almost_equal(count.values, np.array([100]))
 
-
     def test_count_errors(self, ts):
         with pytest.raises(ValueError):
             ts.count(bin_size = {})
@@ -1156,6 +1211,24 @@ def test_count_errors(self, ts):
         with pytest.raises(ValueError):
             ts.count(time_units = {})
 
+    @pytest.mark.parametrize(
+        "dtype, expectation",
+        [
+            (None, does_not_raise()),
+            (float, does_not_raise()),
+            (int, does_not_raise()),
+            (np.int32, does_not_raise()),
+            (np.int64, does_not_raise()),
+            (np.float32, does_not_raise()),
+            (np.float64, does_not_raise()),
+            (1, pytest.raises(ValueError, match=f"1 is not a valid numpy dtype")),
+        ]
+    )
+    def test_count_dtype(self, dtype, expectation, ts):
+        with expectation:
+            count = ts.count(bin_size=0.1, dtype=dtype)
+            if dtype:
+                assert np.issubdtype(count.dtype, dtype)
 
 ####################################################
 # Test for tsdtensor
 ####################################################
@@ -1282,28 +1355,23 @@ def test_bin_average(self, tsdtensor):
         np.testing.assert_array_almost_equal(meantsd.values, tmp)
 
     def test_save_npz(self, tsdtensor):
-        import os
-
-        with pytest.raises(RuntimeError) as e:
+        with pytest.raises(TypeError) as e:
             tsdtensor.save(dict)
-        assert str(e.value) == "Invalid type; please provide filename as string"
 
         with pytest.raises(RuntimeError) as e:
             tsdtensor.save('./')
-        assert str(e.value) == "Invalid filename input. {} is directory.".format("./")
+        assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve())
 
         fake_path = './fake/path'
         with pytest.raises(RuntimeError) as e:
             tsdtensor.save(fake_path+'/file.npz')
-        assert str(e.value) == "Path {} does not exist.".format(fake_path)
+        assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve())
 
         tsdtensor.save("tsdtensor.npz")
-        os.listdir('.')
-        assert "tsdtensor.npz" in os.listdir(".")
+        assert "tsdtensor.npz" in [f.name for f in Path('.').iterdir()]
 
         tsdtensor.save("tsdtensor2")
-        os.listdir('.')
-        assert "tsdtensor2.npz" in os.listdir(".")
+        assert "tsdtensor2.npz" in [f.name for f in Path('.').iterdir()]
 
         file = np.load("tsdtensor.npz")
@@ -1318,8 +1386,8 @@ def test_save_npz(self, tsdtensor):
         np.testing.assert_array_almost_equal(file['start'], tsdtensor.time_support.start)
         np.testing.assert_array_almost_equal(file['end'], tsdtensor.time_support.end)
 
-        os.remove("tsdtensor.npz")
-        os.remove("tsdtensor2.npz")
+        # Path("tsdtensor.npz").unlink()
+        # Path("tsdtensor2.npz").unlink()
 
     def test_interpolate(self, tsdtensor):
@@ -1343,7 +1411,7 @@ def test_interpolate(self, tsdtensor):
         tsdtensor2 = tsdtensor.interpolate(ts)
         np.testing.assert_array_almost_equal(tsdtensor2.values, data_stack)
 
-        with pytest.raises(AssertionError) as e:
+        with pytest.raises(IOError) as e:
             tsdtensor.interpolate([0, 1, 2])
         assert str(e.value) == "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor"
@@ -1376,3 +1444,257 @@ def test_interpolate_with_ep(self, tsdtensor):
         tsdframe2 = tsdtensor.interpolate(ts, ep)
         assert len(tsdframe2) == 0
 
+@pytest.mark.parametrize("obj",
+    [
+        nap.Tsd(t=np.arange(10), d=np.random.rand(10), time_units="s"),
+        nap.TsdFrame(
+            t=np.arange(10), d=np.random.rand(10, 3), time_units="s", columns=["a","b","c"]
+        ),
+        nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 3, 2), time_units="s"),
+    ])
+def test_pickling(obj):
+    """Test that pickling works as expected."""
+    # pickle and unpickle the object
+    pickled_obj = pickle.dumps(obj)
+    unpickled_obj = pickle.loads(pickled_obj)
+
+    # Ensure time is the same
+    assert np.all(obj.t == unpickled_obj.t)
+
+    # Ensure data is the same
+    assert np.all(obj.d == unpickled_obj.d)
+
+    # Ensure time support is the same
+    assert np.all(obj.time_support == unpickled_obj.time_support)
+
+
+####################################################
+# Test for slicing
+####################################################
+
+
+@pytest.mark.parametrize(
+    "start, end, mode, n_points, expectation",
+    [
+        (1, 3, "closest_t", None, does_not_raise()),
+        (None, 3, "closest_t", None, pytest.raises(ValueError, match="'start' must be an int or a float")),
+        (2, "a", "closest_t", None, pytest.raises(ValueError, match="'end' must be an int or a float. Type provided instead!")),
+        (1, 3, "closest_t", "a", pytest.raises(TypeError, match="'n_points' must be of type int or None. Type provided instead!")),
+        (1, None, "closest_t", 1, pytest.raises(ValueError, match="'n_points' can be used only when 'end' is specified!")),
+        (1, 3, "banana", None, pytest.raises(ValueError, match="'mode' only accepts 'before_t', 'after_t', 'closest_t' or 'restrict'.")),
+        (3, 1, "closest_t", None, pytest.raises(ValueError, match="'start' should not precede 'end'")),
+        (1, 3, "restrict", 1, pytest.raises(ValueError, match="Fixing the number of time points is incompatible with 'restrict' mode.")),
+        (1., 3., "closest_t", None, does_not_raise()),
+        (1., None, "closest_t", None, does_not_raise()),
+    ]
+)
+def test_get_slice_raise_errors(start, end, mode, n_points, expectation):
+    ts = nap.Ts(t=np.array([1, 2, 3, 4]))
+    with expectation:
+        ts._get_slice(start, end, mode, n_points)
+
+
+@pytest.mark.parametrize(
+    "start, end, mode, expected_slice, expected_array",
+    [
+        (1, 3, "after_t", slice(0, 2), np.array([1, 2])),
+        (1, 3, "before_t", slice(0, 2), np.array([1, 2])),
+        (1, 3, "closest_t", slice(0, 2), np.array([1, 2])),
+        (1, 2.7, "after_t", slice(0, 2), np.array([1, 2])),
+        (1, 2.7, "before_t", slice(0, 1), np.array([1])),
+        (1, 2.7, "closest_t", slice(0, 2), np.array([1, 2])),
+        (1, 2.4, "after_t", slice(0, 2), np.array([1, 2])),
+        (1, 2.4, "before_t", slice(0, 1), np.array([1])),
+        (1, 2.4, "closest_t", slice(0, 1), np.array([1])),
+        (1.1, 3, "after_t", slice(1, 2), np.array([2])),
+        (1.1, 3, "before_t", slice(0, 2), np.array([1, 2])),
+        (1.1, 3, "closest_t", slice(0, 2), np.array([1, 2])),
+        (1.6, 3, "after_t", slice(1, 2), np.array([2])),
+        (1.6, 3, "before_t", slice(0, 2), np.array([1, 2])),
+        (1.6, 3, "closest_t", slice(1, 2), np.array([2])),
+        (1.6, 1.8, "before_t", slice(0, 0), np.array([])),
+        (1.6, 1.8, "after_t", slice(1, 1), np.array([])),
+        (1.6, 1.8, "closest_t", slice(1, 1), np.array([])),
+        (1.4, 1.6, "closest_t", slice(0, 1), np.array([1])),
+        (3, 3, "after_t", slice(2, 2), np.array([])),
+        (3, 3, "before_t", slice(2, 2), np.array([])),
+        (3, 3, "closest_t", slice(2, 2), np.array([])),
+        (0, 3, "after_t", slice(0, 2), np.array([1, 2])),
+        (0, 3, "before_t", slice(0, 2), np.array([1, 2])),
+        (0, 3, "closest_t", slice(0, 2), np.array([1, 2])),
+        (0, 4, "after_t", slice(0, 3), np.array([1, 2, 3])),
+        (0, 4, "before_t", slice(0, 3), np.array([1, 2, 3])),
+        (0, 4, "closest_t", slice(0, 3), np.array([1, 2, 3])),
+        (4, 4, "after_t", slice(3, 3), np.array([])),
+        (4, 4, "before_t", slice(3, 3), np.array([])),
+        (4, 4, "closest_t", slice(3, 3), np.array([])),
+        (4, 5, "after_t", slice(3, 4), np.array([4])),
+        (4, 5, "before_t", slice(3, 3), np.array([])),
+        (4, 5, "closest_t", slice(3, 3), np.array([])),
+        (0, 1, "after_t", slice(0, 0), np.array([])),
+        (0, 1, "before_t", slice(0, 1), np.array([1])),
+        (0, 1, "closest_t", slice(0, 0), np.array([])),
+        (0, None, "after_t", slice(0, 1), np.array([1])),
+        (0, None, "before_t", slice(0, 0), np.array([])),
+        (0, None, "closest_t", slice(0, 1), np.array([1])),
+        (1, None, "after_t", slice(0, 1), np.array([1])),
+        (1, None, "before_t", slice(0, 1), np.array([1])),
+        (1, None, "closest_t", slice(0, 1), np.array([1])),
+        (5, None, "after_t", slice(3, 3), np.array([])),
+        (5, None, "before_t", slice(3, 4), np.array([4])),
+        (5, None, "closest_t", slice(3, 4), np.array([4])),
+        (1, 3, "restrict", slice(0, 3), np.array([1, 2, 3])),
+        (1, 2.7, "restrict", slice(0, 2), np.array([1, 2])),
+        (1, 2.4, "restrict", slice(0, 2), np.array([1, 2])),
+        (1.1, 3, "restrict", slice(1, 3), np.array([2, 3])),
+        (1.6, 3, "restrict", slice(1, 3), np.array([2, 3])),
+        (1.6, 1.8, "restrict", slice(1, 1), np.array([])),
+        (1.4, 1.6, "restrict", slice(1, 1), np.array([])),
+        (3, 3, "restrict", slice(2, 3), np.array([3])),
+        (0, 3, "restrict", slice(0, 3), np.array([1, 2, 3])),
+        (0, 4, "restrict", slice(0, 4), np.array([1, 2, 3, 4])),
+        (4, 4, "restrict", slice(3, 4), np.array([4])),
+        (4, 5, "restrict", slice(3, 4), np.array([4])),
+        (0, 1, "restrict", slice(0, 1), np.array([1])),
+    ]
+)
+@pytest.mark.parametrize("ts",
+    [
+        nap.Ts(t=np.array([1, 2, 3, 4])),
+        nap.Tsd(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])),
+        nap.TsdFrame(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None]),
+        nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None])
+    ])
+def test_get_slice_value(start, end, mode, expected_slice, expected_array, ts):
+    out_slice = ts._get_slice(start, end=end, mode=mode)
+    out_array = ts.t[out_slice]
+    assert out_slice == expected_slice
+    assert np.all(out_array == expected_array)
+    if mode == "restrict":
+        iset = nap.IntervalSet(start, end)
+        out_restrict = ts.restrict(iset)
+        assert np.all(out_restrict.t == out_array)
+
+
+def test_get_slice_vs_get_random_val_start_end_value():
+    np.random.seed(123)
+    ts = nap.Ts(np.linspace(0.2, 0.8, 100))
+    se_vec = np.random.uniform(0, 1, size=(10000, 2))
+    starts = np.min(se_vec, axis=1)
+    ends = np.max(se_vec, axis=1)
+    for start, end in zip(starts, ends):
+        out_slice = ts.get_slice(start=start, end=end)
+        out_ts = ts[out_slice]
+        out_get = ts.get(start, end)
+        assert np.all(out_get.t == out_ts.t)
+
+
+def test_get_slice_vs_get_random_val_start_value():
+    np.random.seed(123)
+    ts = nap.Ts(np.linspace(0.2, 0.8, 100))
+    starts = np.random.uniform(0, 1, size=(10000, ))
+
+    for start in starts:
+        out_slice = ts.get_slice(start=start, end=None)
+        out_ts = ts[out_slice]
+        out_get = ts.get(start)
+        assert np.all(out_get.t == out_ts.t)
+
+
+@pytest.mark.parametrize(
+    "end, n_points, expectation",
+    [
+        (1, 3, does_not_raise()),
+        (None, 3, pytest.raises(ValueError, match="'n_points' can be used only when")),
+    ]
+)
+@pytest.mark.parametrize("time_unit", ["s", "ms", "us"])
+@pytest.mark.parametrize("mode", ["closest_t", "before_t", "after_t"])
+def test_get_slice_n_points(end, n_points, expectation, time_unit, mode):
+    ts = nap.Ts(t=np.array([1, 2, 3, 4]))
+    with expectation:
+        ts._get_slice(1, end, n_points=n_points, mode=mode)
+
+
+@pytest.mark.parametrize(
+    "start, end, n_points, mode, expected_slice, expected_array",
+    [
+        # smaller than n_points
+        (1, 2, 2, "after_t", slice(0, 1), np.array([1])),
+        (1, 2, 2, "before_t", slice(0, 1), np.array([1])),
+        (1, 2, 2, "closest_t", slice(0, 1), np.array([1])),
+        # larger than n_points
+        (1, 5, 2, "after_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 5, 2, "before_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 5, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])),
+        # larger than n_points with rounding down
+        (1, 5.2, 2, "after_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 5.2, 2, "before_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 5.2, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])),
+        # larger than n_points with rounding down
+        (1, 6.2, 2, "after_t", slice(0, 6, 3), np.array([1, 4])),
+        (1, 6.2, 2, "before_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 6.2, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])),
+        # larger than n_points with rounding up
+        (1, 5.6, 2, "after_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 5.6, 2, "before_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 5.6, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])),
+        # larger than n_points with rounding up
+        (1, 6.6, 2, "after_t", slice(0, 6, 3), np.array([1, 4])),
+        (1, 6.6, 2, "before_t", slice(0, 4, 2), np.array([1, 3])),
+        (1, 6.6, 2, "closest_t", slice(0, 6, 3), np.array([1, 4])),
+    ]
+)
+@pytest.mark.parametrize("ts",
+    [
+        nap.Ts(t=np.arange(1, 10)),
+        nap.Tsd(t=np.arange(1, 10), d=np.arange(1, 10)),
+        nap.TsdFrame(t=np.arange(1, 10), d=np.arange(1, 10)[:, None]),
+        nap.TsdTensor(t=np.arange(1, 10), d=np.arange(1, 10)[:, None, None])
+    ])
+def test_get_slice_value_step(start, end, n_points, mode, expected_slice, expected_array, ts):
+    out_slice = ts._get_slice(start, end=end, mode=mode, n_points=n_points)
+    out_array = ts.t[out_slice]
+    assert out_slice == expected_slice
+    assert np.all(out_array == expected_array)
+
+
+@pytest.mark.parametrize(
+    "start, end, expected_slice, expected_array",
+    [
+        (1, 3, slice(0, 3), np.array([1, 2, 3])),
+        (1, 2.7, slice(0, 2), np.array([1, 2])),
+        (1, 2.4, slice(0, 2), np.array([1, 2])),
+        (1.1, 3, slice(1, 3), np.array([2, 3])),
+        (1.6, 3, slice(1, 3), np.array([2, 3])),
+        (1.6, 1.8, slice(1, 1), np.array([])),
+        (1.4, 1.6, slice(1, 1), np.array([])),
+        (3, 3, slice(2, 3), np.array([3])),
+        (0, 3, slice(0, 3), np.array([1, 2, 3])),
+        (0, 4, slice(0, 4), np.array([1, 2, 3, 4])),
+        (4, 4, slice(3, 4), np.array([4])),
+        (4, 5, slice(3, 4), np.array([4])),
+        (0, 1, slice(0, 1), np.array([1])),
+        (0, None, slice(0, 1), np.array([1])),
+        (1, None, slice(0, 1), np.array([1])),
+        (4, None, slice(3, 4), np.array([4])),
+        (5, None, slice(3, 4), np.array([4])),
+        (-1, 0, slice(0, 0), np.array([])),
+        (5, 6, slice(4, 4), np.array([])),
+    ]
+)
+@pytest.mark.parametrize("ts",
+    [
+        nap.Ts(t=np.array([1, 2, 3, 4])),
+        nap.Tsd(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])),
+        nap.TsdFrame(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None]),
+        nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None])
+    ])
+def test_get_slice_public(start, end, expected_slice, expected_array, ts):
+    out_slice = ts.get_slice(start, end=end)
+    out_array = ts.t[out_slice]
+    assert out_slice == expected_slice
+    assert np.all(out_array == expected_array)
diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py
index fc221f60..31f54107 100644
--- a/tests/test_ts_group.py
+++ b/tests/test_ts_group.py
@@ -1,18 +1,19 @@
-# -*- coding: utf-8 -*-
-# @Author: gviejo
-# @Date: 2022-03-30 11:14:41
-# @Last Modified by: Guillaume Viejo
-# @Last Modified time: 2024-04-11 14:42:50
+
 """Tests of ts group for `pynapple` package."""
 
-import pynapple as nap
+import pickle
+import warnings
+from collections import UserDict
+from contextlib import nullcontext as does_not_raise
+
 import numpy as np
 import pandas as pd
 import pytest
-from collections import UserDict
-import warnings
-from contextlib import nullcontext as does_not_raise
+from pathlib import Path
+
+import pynapple as nap
+
 
 @pytest.fixture
 def group():
@@ -33,6 +34,16 @@ def ts_group():
     group = nap.TsGroup(data, meta=[10, 11])
     return group
 
+
+@pytest.fixture
+def ts_group_one_group():
+    # Placeholder setup for Ts and Tsd objects. Adjust as necessary.
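+    # A single-entry group: used by test_count_dtype below to check that
+    # count() keeps the requested dtype when the group has only one unit.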
+    ts1 = nap.Ts(t=np.arange(10))
+    data = {1: ts1}
+    group = nap.TsGroup(data, meta=[10])
+    return group
+
 
 class TestTsGroup1:
 
     def test_create_ts_group(self, group):
@@ -40,10 +51,19 @@ class TestTsGroup1:
         assert isinstance(tsgroup, UserDict)
         assert len(tsgroup) == 3
 
+    def test_create_ts_group_from_iter(self, group):
+        tsgroup = nap.TsGroup(group.values())
+        assert isinstance(tsgroup, UserDict)
+        assert len(tsgroup) == 3
+
+    def test_create_ts_group_from_invalid(self):
+        with pytest.raises(AttributeError):
+            tsgroup = nap.TsGroup(np.arange(0, 200))
+
     @pytest.mark.parametrize(
         "test_dict, expectation",
         [
-            ({"1": nap.Ts(np.arange(10)), "2":nap.Ts(np.arange(10))}, does_not_raise()),
+            ({"1": nap.Ts(np.arange(10)), "2": nap.Ts(np.arange(10))}, does_not_raise()),
             ({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}, does_not_raise()),
             ({"1": nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))},
              pytest.raises(ValueError, match="Two dictionary keys contain the same integer")),
@@ -71,7 +91,6 @@ def test_initialize_from_dict(self, test_dict, expectation):
     def test_metadata_len_match(self, tsgroup):
         assert len(tsgroup._metadata) == len(tsgroup)
 
-
     def test_create_ts_group_from_array(self):
         with warnings.catch_warnings(record=True) as w:
             nap.TsGroup({
@@ -123,7 +142,8 @@ def test_create_ts_group_with_metainfo(self, group):
         ar_info = np.ones(3) * 1
         tsgroup = nap.TsGroup(group, sr=sr_info, ar=ar_info)
         assert tsgroup._metadata.shape == (3, 3)
-        pd.testing.assert_series_equal(tsgroup._metadata["sr"], sr_info)
+        np.testing.assert_array_almost_equal(tsgroup._metadata["sr"].values, sr_info.values)
+        np.testing.assert_array_almost_equal(tsgroup._metadata["sr"].index.values, sr_info.index.values)
         np.testing.assert_array_almost_equal(tsgroup._metadata["ar"].values, ar_info)
 
     def test_add_metainfo(self, group):
@@ -290,6 +310,9 @@ def test_count(self, group):
         count = tsgroup.count()
         np.testing.assert_array_almost_equal(count.values, np.array([[101, 201, 501]]))
 
+        count = tsgroup.count(1.0, dtype=np.int16)
+        assert count.dtype == np.dtype(np.int16)
+
     def test_count_with_ep(self, group):
         ep = nap.IntervalSet(start=0, end=100)
         tsgroup = nap.TsGroup(group)
@@ -515,8 +538,6 @@ def test_to_tsd_runtime_errors(self, group):
 
     def test_save_npz(self, group):
-        import os
-
         group = {
             0: nap.Tsd(t=np.arange(0, 20), d = np.random.rand(20)),
             1: nap.Tsd(t=np.arange(0, 20, 0.5), d=np.random.rand(40)),
@@ -525,26 +546,23 @@ def test_save_npz(self, group):
         tsgroup = nap.TsGroup(group, meta = np.arange(len(group), dtype=np.int64), meta2 = np.array(['a', 'b', 'c']))
 
-        with pytest.raises(RuntimeError) as e:
+        with pytest.raises(TypeError) as e:
             tsgroup.save(dict)
-        assert str(e.value) == "Invalid type; please provide filename as string"
 
         with pytest.raises(RuntimeError) as e:
             tsgroup.save('./')
-        assert str(e.value) == "Invalid filename input. {} is directory.".format("./")
+        assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve())
 
         fake_path = './fake/path'
         with pytest.raises(RuntimeError) as e:
             tsgroup.save(fake_path+'/file.npz')
-        assert str(e.value) == "Path {} does not exist.".format(fake_path)
+        assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve())
 
         tsgroup.save("tsgroup.npz")
-        os.listdir('.')
-        assert "tsgroup.npz" in os.listdir(".")
+        assert "tsgroup.npz" in [f.name for f in Path('.').iterdir()]
 
         tsgroup.save("tsgroup2")
-        os.listdir('.')
-        assert "tsgroup2.npz" in os.listdir(".")
+        assert "tsgroup2.npz" in [f.name for f in Path('.').iterdir()]
 
         file = np.load("tsgroup.npz")
@@ -585,9 +603,9 @@ def test_save_npz(self, group):
         assert 'd' not in list(file.keys())
         np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
 
-        os.remove("tsgroup.npz")
-        os.remove("tsgroup2.npz")
-        os.remove("tsgroup3.npz")
+        Path("tsgroup.npz").unlink()
+        Path("tsgroup2.npz").unlink()
+        Path("tsgroup3.npz").unlink()
 
     @pytest.mark.parametrize(
         "keys, expectation",
@@ -854,4 +872,51 @@ def test_merge_time_support(self, ts_group, time_support, reset_time_support, ex
         np.testing.assert_array_almost_equal(
             ts_group.time_support.as_units("s").to_numpy(),
             merged.time_support.as_units("s").to_numpy()
-        )
\ No newline at end of file
+        )
+
+
+def test_pickling(ts_group):
+    """Test that pickling works as expected."""
+    # pickle and unpickle ts_group
+    pickled_obj = pickle.dumps(ts_group)
+    unpickled_obj = pickle.loads(pickled_obj)
+
+    # Ensure the type is the same
+    assert type(ts_group) is type(unpickled_obj), "Types are different"
+
+    # Ensure that TsGroup have same len
+    assert len(ts_group) == len(unpickled_obj)
+
+    # Ensure that metadata content is the same
+    assert np.all(unpickled_obj._metadata == ts_group._metadata)
+
+    # Ensure that metadata columns are the same
+    assert np.all(unpickled_obj._metadata.columns == ts_group._metadata.columns)
+
+    # Ensure that the Ts are the same
+    assert all([np.all(ts_group[key].t == unpickled_obj[key].t) for key in unpickled_obj.keys()])
+
+    # Ensure time support is the same
+    assert np.all(ts_group.time_support == unpickled_obj.time_support)
+
+
+@pytest.mark.parametrize(
+    "dtype, expectation",
+    [
+        (None, does_not_raise()),
+        (float, does_not_raise()),
+        (int, does_not_raise()),
+        (np.int32, does_not_raise()),
+        (np.int64, does_not_raise()),
+        (np.float32, does_not_raise()),
+        (np.float64, does_not_raise()),
+        (1, pytest.raises(ValueError, match=f"1 is not a valid numpy dtype")),
+    ]
+)
+def test_count_dtype(dtype, expectation, ts_group, ts_group_one_group):
+    with expectation:
+        count = ts_group.count(bin_size=0.1, dtype=dtype)
+        count_one = ts_group_one_group.count(bin_size=0.1, dtype=dtype)
+        if dtype:
+            assert np.issubdtype(count.dtype, dtype)
+            assert np.issubdtype(count_one.dtype, dtype)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..71c06027
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,20 @@
+"""Tests of utils for `pynapple` package."""
+
+import pynapple as nap
+import numpy as np
+import pandas as pd
+import pytest
+
+def test_get_backend():
+    assert nap.core.utils.get_backend() in ["numba", "jax"]
+
+def test_is_array_like():
+    assert nap.core.utils.is_array_like(np.ones(3))
+    assert nap.core.utils.is_array_like(np.array([]))
+    assert not nap.core.utils.is_array_like([1,2,3])
+    assert not nap.core.utils.is_array_like(1)
+    assert not nap.core.utils.is_array_like('a')
+    assert not nap.core.utils.is_array_like(True)
+    assert not nap.core.utils.is_array_like((1,2,3))
+    assert not nap.core.utils.is_array_like({0:1})
+    assert not nap.core.utils.is_array_like(np.array(0))
\ No newline at end of file