From a297c4fadc5957723bd8c75828e65f774f50f258 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Tue, 22 Nov 2022 01:29:13 +0200 Subject: [PATCH] [MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved spectral connectivity estimation (#104) --- doc/authors.inc | 1 + doc/conf.py | 3 +- doc/whats_new.rst | 18 +- environment.yml | 4 +- mne_connectivity/conftest.py | 2 +- mne_connectivity/spectral/epochs.py | 13 +- .../spectral/tests/test_spectral.py | 210 +++--- mne_connectivity/spectral/time.py | 631 ++++++++++++------ requirements.txt | 2 +- requirements_doc.txt | 2 +- tools/circleci_dependencies.sh | 2 +- 11 files changed, 568 insertions(+), 320 deletions(-) diff --git a/doc/authors.inc b/doc/authors.inc index a4caeb5c..096e5d55 100644 --- a/doc/authors.inc +++ b/doc/authors.inc @@ -7,3 +7,4 @@ .. _Szonja Weigl: https://github.com/weiglszonja .. _Kenji Marshall: https://github.com/kenjimarshall .. _Sezan Mert: https://github.com/SezanMert +.. _Santeri Ruuskanen: https://github.com/ruuskas diff --git a/doc/conf.py b/doc/conf.py index 7532bfb3..1fc9fe0a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -197,7 +197,8 @@ 'use_edit_page_button': False, 'navigation_with_keys': False, 'show_toc_level': 1, - 'navbar_end': ['version-switcher', 'navbar-icon-links'], + 'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'], + 'secondary_sidebar_items': ['page-toc'], } # Custom sidebar templates, maps document names to template names. html_sidebars = { diff --git a/doc/whats_new.rst b/doc/whats_new.rst index e360570c..2e8031fd 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -19,27 +19,35 @@ Here we list a changelog of MNE-connectivity. Version 0.5 (Unreleased) ------------------------ -... +This version has major changes in :func:`mne_connectivity.spectral_connectivity_time`. Several bugs are fixed, and the +function now computes static connectivity over time, as opposed to static connectivity over trials computed by :func:`mne_connectivity.spectral_connectivity_epochs`. Enhancements ~~~~~~~~~~~~ -- +- Add the ``PLI`` and ``wPLI`` methods in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). Bug ~~~ -- +- When using the ``multitaper`` mode in :func:`mne_connectivity.spectral_connectivity_time`, average CSD over tapers instead of the complex signal by `Santeri Ruuskanen`_ (:gh:`104`). +- Average over time when computing connectivity measures in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Fix support for multiple connectivity methods in calls to :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Fix bug with the ``indices`` parameter in :func:`mne_connectivity.spectral_connectivity_time`, the behavior is now as expected by `Santeri Ruuskanen`_ (:gh:`104`). +- Fix bug with parallel computation in :func:`mne_connectivity.spectral_connectivity_time`, add instructions for memory mapping in doc by `Santeri Ruuskanen`_ (:gh:`104`). API ~~~ -- +- Streamline the API of :func:`mne_connectivity.spectral_connectivity_time` with :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). Authors ~~~~~~~ -* +* `Santeri Ruuskanen`_ :doc:`Find out what was new in previous releases ` diff --git a/environment.yml b/environment.yml index 3fe4ed24..8fb68645 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: mne +name: mne-connectivity channels: - conda-forge dependencies: @@ -20,5 +20,5 @@ dependencies: - pyvista>=0.32 - pyvistaqt>=0.4 - pyqt!=5.15.3 -- mne +- mne>=1.0 - h5netcdf diff --git a/mne_connectivity/conftest.py b/mne_connectivity/conftest.py index 33487f5a..d1366cfb 100644 --- a/mne_connectivity/conftest.py +++ b/mne_connectivity/conftest.py @@ -151,6 +151,6 @@ def _check_skip_backend(name): if not has_imageio_ffmpeg(): pytest.skip("Test skipped, requires imageio-ffmpeg") if name == 'pyvistaqt' and not _check_qt_version(): - pytest.skip("Test skipped, requires PyQt5.") + pytest.skip("Test skipped, requires Python Qt bindings.") if name == 'pyvistaqt' and not has_pyvistaqt(): pytest.skip("Test skipped, requires pyvistaqt") diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 8c5963ff..2237d874 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -859,6 +859,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, See Also -------- + mne_connectivity.spectral_connectivity_time mne_connectivity.SpectralConnectivity mne_connectivity.SpectroTemporalConnectivity @@ -873,7 +874,9 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, connectivity structure. Within each Epoch, it is assumed that the spectral measure is stationary. The spectral measures implemented in this function are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values.** + one Epoch will result in errorful values and spectral measures computed + with few Epochs will be unreliable.** Please see + ``spectral_connectivity_time`` for time-resolved connectivity estimation. The spectral densities can be estimated using a multitaper method with digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier @@ -891,11 +894,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, indices = (np.array([0, 0, 0]), # row indices np.array([2, 3, 4])) # col indices - con_flat = spectral_connectivity(data, method='coh', - indices=indices, ...) + con = spectral_connectivity_epochs(data, method='coh', + indices=indices, ...) - In this case con_flat.shape = (3, n_freqs). The connectivity scores are - in the same order as defined indices. + In this case con.get_data().shape = (3, n_freqs). The connectivity scores + are in the same order as defined indices. **Supported Connectivity Measures** diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index b32b0f62..8b4c71a8 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -2,14 +2,8 @@ from numpy.testing import (assert_allclose, assert_array_almost_equal, assert_array_less) import pytest -import warnings - -import mne -from mne import (EpochsArray, SourceEstimate, create_info, - make_fixed_length_epochs) +from mne import (EpochsArray, SourceEstimate, create_info) from mne.filter import filter_data -from mne.utils import _resource_path -from mne_bids import BIDSPath, read_raw_bids from mne_connectivity import ( SpectralConnectivity, spectral_connectivity_epochs, @@ -478,7 +472,101 @@ def test_epochs_tmin_tmax(kind): assert len(w) == 1 # just one even though there were multiple epochs -@pytest.mark.parametrize('method', ['coh', 'plv']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize('data_option', ['sync', 'random']) +def test_spectral_connectivity_time_phaselocked(method, mode, data_option): + """Test time-resolved spectral connectivity with simulated phase-locked + data.""" + rng = np.random.default_rng(0) + n_epochs = 5 + n_channels = 3 + n_times = 1000 + sfreq = 250 + data = np.zeros((n_epochs, n_channels, n_times)) + if data_option == 'random': + # Data is random, there should be no consistent phase differences. + data = rng.random((n_epochs, n_channels, n_times)) + if data_option == 'sync': + # Data consists of phase-locked 10Hz sine waves with constant phase + # difference within each epoch. + wave_freq = 10 + epoch_length = n_times / sfreq + for i in range(n_epochs): + for c in range(n_channels): + phase = rng.random() * 10 + x = np.linspace(-wave_freq * epoch_length * np.pi + phase, + wave_freq * epoch_length * np.pi + phase, + n_times) + data[i, c] = np.squeeze(np.sin(x)) + # the frequency band should contain the frequency at which there is a + # hypothesized "connection" + freq_band_low_limit = (8.) + freq_band_high_limit = (13.) + cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) + con = spectral_connectivity_time(data, method=method, mode=mode, + sfreq=sfreq, fmin=freq_band_low_limit, + fmax=freq_band_high_limit, + cwt_freqs=cwt_freqs, n_jobs=1, + faverage=True, average=True, sm_times=0) + assert con.shape == (n_channels ** 2, len(con.freqs)) + con_matrix = con.get_data('dense')[..., 0] + if data_option == 'sync': + # signals are perfectly phase-locked, connectivity matrix should be + # a lower triangular matrix of ones + assert np.allclose(con_matrix, + np.tril(np.ones(con_matrix.shape), + k=-1), + atol=0.01) + if data_option == 'random': + # signals are random, all connectivity values should be small + # 0.5 is picked rather arbitrarily such that the obsolete wrong + # implementation fails + assert np.all(con_matrix) <= 0.5 + + +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) +def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): + """Test time-resolved spectral connectivity with int and float values for + cwt_freqs.""" + rng = np.random.default_rng(0) + n_epochs = 5 + n_channels = 3 + n_times = 1000 + sfreq = 250 + data = np.zeros((n_epochs, n_channels, n_times)) + + # Data consists of phase-locked 10Hz sine waves with constant phase + # difference within each epoch. + wave_freq = 10 + epoch_length = n_times / sfreq + for i in range(n_epochs): + for c in range(n_channels): + phase = rng.random() * 10 + x = np.linspace(-wave_freq * epoch_length * np.pi + phase, + wave_freq * epoch_length * np.pi + phase, + n_times) + data[i, c] = np.squeeze(np.sin(x)) + # the frequency band should contain the frequency at which there is a + # hypothesized "connection" + con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', + sfreq=sfreq, fmin=np.min(cwt_freqs), + fmax=np.max(cwt_freqs), + cwt_freqs=cwt_freqs, n_jobs=1, + faverage=True, average=True, sm_times=0) + assert con.shape == (n_channels ** 2, len(con.freqs)) + con_matrix = con.get_data('dense')[..., 0] + + # signals are perfectly phase-locked, connectivity matrix should be + # a lower triangular matrix of ones + assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), + atol=0.01) + + +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) def test_spectral_connectivity_time_resolved(method, mode): @@ -486,7 +574,7 @@ def test_spectral_connectivity_time_resolved(method, mode): sfreq = 50. n_signals = 3 n_epochs = 2 - n_times = 256 + n_times = 1000 trans_bandwidth = 2. tmin = 0. tmax = (n_times - 1) / sfreq @@ -502,22 +590,21 @@ def test_spectral_connectivity_time_resolved(method, mode): # define some frequencies for cwt freqs = np.arange(3, 20.5, 1) - n_freqs = len(freqs) # run connectivity estimation con = spectral_connectivity_time( - data, freqs=freqs, method=method, mode=mode) - assert con.shape == (n_epochs, n_signals * 2, n_freqs, n_times) + data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode, + n_cycles=5) + assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ - (n_epochs, n_signals, n_signals, n_freqs, n_times) - - # average over time - conn_data = con.get_data(output='dense').mean(axis=-1) - conn_data = conn_data.mean(axis=-1) + (n_epochs, n_signals, n_signals, len(con.freqs)) # test the simulated signal triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T + # average over frequencies + conn_data = con.get_data(output='dense').mean(axis=-1) + # the indices at which there is a correlation should be greater # then the rest of the components for epoch_idx in range(n_epochs): @@ -526,95 +613,6 @@ def test_spectral_connectivity_time_resolved(method, mode): for idx, jdx in triu_inds) -@pytest.mark.parametrize('method', ['coh', 'plv']) -@pytest.mark.parametrize( - 'mode', ['morlet', 'multitaper']) -def test_time_resolved_spectral_conn_regression(method, mode): - """Regression test against original implementation in Frites. - - To see how the test dataset was generated, see - ``benchmarks/single_epoch_conn.py``. - """ - test_file_path_str = str(_resource_path( - 'mne_connectivity.tests', - f'data/test_frite_dataset_{mode}_{method}.npy')) - test_conn = np.load(test_file_path_str) - - # paths to mne datasets - sample ECoG - bids_root = mne.datasets.epilepsy_ecog.data_path() - - # first define the BIDS path and load in the dataset - bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery', - task='ictal', datatype='ieeg', extension='.vhdr') - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - raw = read_raw_bids(bids_path=bids_path, verbose=False) - line_freq = raw.info['line_freq'] - - # Pick only the ECoG channels, removing the ECG channels - raw.pick_types(ecog=True) - - # drop bad channels - raw.drop_channels(raw.info['bads']) - - # only pick the first three channels to lower RAM usage - raw = raw.pick_channels(raw.ch_names[:3]) - - # Load the data - raw.load_data() - - # Then we remove line frequency interference - raw.notch_filter(line_freq) - - # crop data and then Epoch - raw_copy = raw.copy() - raw = raw.crop(tmin=0, tmax=4, include_tmax=False) - epochs = make_fixed_length_epochs(raw=raw, duration=2., overlap=1.) - - ###################################################################### - # Perform basic test to match simulation data using time-resolved spec - ###################################################################### - # compare data to original run using Frites - freqs = [30, 90] - - # mode was renamed in mne-connectivity - if mode == 'morlet': - mode = 'cwt_morlet' - conn = spectral_connectivity_time( - epochs, freqs=freqs, n_jobs=1, method=method, mode=mode) - - # frites only stores the upper triangular parts of the raveled array - row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1) - conn_data = conn.get_data(output='dense')[ - :, row_triu_inds, col_triu_inds, ...] - assert_array_almost_equal(conn_data, test_conn) - - ###################################################################### - # Give varying set of frequency bands and frequencies to perform cWT - ###################################################################### - raw = raw_copy.crop(tmin=0, tmax=10, include_tmax=False) - ch_names = epochs.ch_names - epochs = make_fixed_length_epochs(raw=raw, duration=5, overlap=0.) - - # sampling rate of my data - sfreq = raw.info['sfreq'] - - # frequency bands of interest - fois = np.array([[4, 8], [8, 12], [12, 16], [16, 32]]) - - # frequencies of Continuous Morlet Wavelet Transform - freqs = np.arange(4., 32., 1) - - # compute coherence - cohs = spectral_connectivity_time( - epochs, names=None, method=method, indices=None, - sfreq=sfreq, foi=fois, sm_times=0.5, sm_freqs=1, sm_kernel='hanning', - mode=mode, mt_bandwidth=None, freqs=freqs, n_cycles=5) - assert cohs.get_data(output='dense').shape == ( - len(epochs), len(ch_names), len(ch_names), len(fois), len(epochs.times) - ) - - def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index f1a53bce..1c16052a 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -1,137 +1,286 @@ # Authors: Adam Li +# Santeri Ruuskanen # # License: BSD (3-clause) import numpy as np import xarray as xr +from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper) -from mne.utils import logger +from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper, + dpss_windows) +from mne.utils import (logger, warn) -from ..base import (EpochSpectroTemporalConnectivity) +from ..base import (SpectralConnectivity, EpochSpectralConnectivity) +from .epochs import _compute_freqs, _compute_freq_mask from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, fill_doc @fill_doc -def spectral_connectivity_time(data, names=None, method='coh', indices=None, - sfreq=2 * np.pi, foi=None, sm_times=.5, +def spectral_connectivity_time(data, method='coh', average=False, + indices=None, sfreq=None, fmin=None, + fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', mode='cwt_morlet', mt_bandwidth=None, - freqs=None, n_cycles=7, decim=1, - block_size=None, n_jobs=1, - verbose=None): + cwt_freqs=None, n_cycles=7, decim=1, + n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. - This method computes single-Epoch time-resolved spectral connectivity. + This method computes time-resolved connectivity measures from epoched data. - The connectivity method(s) are specified using the "method" parameter. + The connectivity method(s) are specified using the ``method`` parameter. All methods are based on estimates of the cross- and power spectral densities (CSD/PSD) Sxy and Sxx, Syy. Parameters ---------- - data : Epochs + data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. - %(names)s method : str | list of str - Connectivity measure(s) to compute. These can be ``['coh', 'plv', - 'sxy']``. These are: + Connectivity measure(s) to compute. These can be + ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) * 'sxy' : Cross-spectrum - - By default, the coherence is used. - indices : tuple of array | None + * 'pli' : Phase-Lag Index + * 'wpli': Weighted Phase-Lag Index + average : bool + Average connectivity scores over epochs. If True, output will be + an instance of :class:`SpectralConnectivity`, otherwise + :class:`EpochSpectralConnectivity`. + indices : tuple of array_like | None Two arrays with indices of connections for which to compute connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. - If None, all connections are computed. + If `None`, all connections are computed. sfreq : float - The sampling frequency. - foi : array_like | None - Extract frequencies of interest. This parameters should be an array of - shapes (n_foi, 2) defining where each band of interest start and - finish. + The sampling frequency. Required if data is not + :class:`Epochs `. + fmin : float | tuple of float | None + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower + bounds. If `None`, the frequency corresponding to an epoch length of + 5 cycles is used. + fmax : float | tuple of float | None + The upper frequency of interest. Multiple bands are defined using + a tuple, e.g. ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper + bounds. If `None`, ``sfreq/2`` is used. + fskip : int + Omit every ``(fskip + 1)``-th frequency bin to decimate in frequency + domain. + faverage : bool + Average connectivity scores for each frequency band. If `True`, + the output ``freqs`` will be an array of the median frequencies of each + band. sm_times : float - Amount of time to consider for the temporal smoothing in seconds. By - default, 0.5 sec smoothing is used. + Amount of time to consider for the temporal smoothing in seconds. + If zero, no temporal smoothing is applied. sm_freqs : int Number of points for frequency smoothing. By default, 1 is used which is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} - Kernel type to use. Choose either 'square' or 'hanning' (default). - mode : str, optional - Spectrum estimation mode can be either: 'multitaper', or - 'cwt_morlet'. + Smoothing kernel type. Choose either 'square' or 'hanning'. + mode : str + Time-frequency decomposition method. Can be either: 'multitaper', or + 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. - freqs : array - Array of frequencies of interest for use in time-frequency - decomposition method (specified by ``mode``). - n_cycles : float | array of float - Number of cycles for use in time-frequency decomposition method - (specified by ``mode``). Fixed number or one per frequency. - decim : int | 1 + Product between the temporal window length (in seconds) and the full + frequency bandwidth (in Hz). This product can be seen as the surface + of the window on the time/frequency plane and controls the frequency + bandwidth (thus the frequency resolution) and the number of good + tapers. See :func:`mne.time_frequency.tfr_array_multitaper` + documentation. + cwt_freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only used in 'cwt_morlet' mode. Only the frequencies within + the range specified by ``fmin`` and ``fmax`` are used. Required if + ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. + n_cycles : float | array_like of float + Number of cycles in the wavelet, either a fixed number or one per + frequency. The number of cycles ``n_cycles`` and the frequencies of + interest ``cwt_freqs`` define the temporal window length. For details, + see :func:`mne.time_frequency.tfr_array_morlet` documentation. + decim : int To reduce memory usage, decimation factor after time-frequency - decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, - returns tfr[…, decim]. - block_size : int - How many connections to compute at once (higher numbers are faster - but require more memory). + decomposition. Returns ``tfr[…, ::decim]``. n_jobs : int - How many epochs to process in parallel. + Number of connections to compute in parallel. Memory mapping must be + activated. Please see the Notes section for details. %(verbose)s Returns ------- - con : array | instance of Connectivity - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of each connectivity dataset is either - (n_signals ** 2, n_freqs) mode: 'multitaper' or 'fourier' - (n_signals ** 2, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is None, or - (n_con, n_freqs) mode: 'multitaper' or 'fourier' - (n_con, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is specified and "n_con = len(indices[0])". + con : instance of Connectivity | list + Computed connectivity measure(s). An instance of + :class:`EpochSpectralConnectivity`, :class:`SpectralConnectivity` + or a list of instances corresponding to connectivity measures if + several connectivity measures are specified. + The shape of each connectivity dataset is + (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None` + and (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified + and ``n_nodes = len(indices[0])``. See Also -------- mne_connectivity.spectral_connectivity_epochs mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity + mne_connectivity.EpochSpectralConnectivity Notes ----- - This function was originally implemented in ``frites`` and was - ported over. + Please note that the interpretation of the measures in this function + depends on the data and underlying assumptions and does not necessarily + reflect a causal relationship between brain regions. + + The connectivity measures are computed over time within each epoch and + optionally averaged over epochs. High connectivity values indicate that + the phase coupling (interpreted as estimated connectivity) differences + between signals stay consistent over time. + + The spectral densities can be estimated using a multitaper method with + digital prolate spheroidal sequence (DPSS) windows, or a continuous wavelet + transform using Morlet wavelets. The spectral estimation mode is specified + using the ``mode`` parameter. + + When using the multitaper spectral estimation method, the + cross-spectral density is computed separately for each taper and aggregated + using a weighted average, where the weights correspond to the concentration + ratios between the DPSS windows. + + By default, the connectivity between all signals is computed (only + connections corresponding to the lower-triangular part of the + connectivity matrix). If one is only interested in the connectivity + between some signals, the ``indices`` parameter can be used. For example, + to compute the connectivity between the signal with index 0 and signals + 2, 3, 4 (a total of 3 connections), one can use the following:: + + indices = (np.array([0, 0, 0]), # row indices + np.array([2, 3, 4])) # col indices + + con = spectral_connectivity_time(data, method='coh', + indices=indices, ...) + + In this case ``con.get_data().shape = (3, n_freqs)``. The connectivity + scores are in the same order as defined indices. + + **Supported Connectivity Measures** + + The connectivity method(s) is specified using the ``method`` parameter. The + following methods are supported (note: ``E[]`` denotes average over + epochs). Multiple measures can be computed at once by using a list/tuple, + e.g., ``['coh', 'pli']`` to compute coherence and PLI. + + 'coh' : Coherence given by:: + + | E[Sxy] | + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given + by:: + + PLV = |E[Sxy/|Sxy|]| + + 'sxy' : Cross spectrum Sxy + + 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: + + PLI = |E[sign(Im(Sxy))]| + + 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` + given by:: + + |E[Im(Sxy)]| + WPLI = ------------------ + E[|Im(Sxy)|] + + Parallel computation can be activated by setting the ``n_jobs`` parameter. + Under the hood, this utilizes the ``joblib`` library. For effective + parallelization, you should activate memory mapping in MNE-Python by + setting ``MNE_MEMMAP_MIN_SIZE`` and ``MNE_CACHE_DIR``. Activating memory + mapping will make ``joblib`` store arrays greater than the minimum size on + disc, and forego direct RAM access for more efficient processing. + For example, in your code, run + + mne.set_config('MNE_MEMMAP_MIN_SIZE', '10M') + mne.set_config('MNE_CACHE_DIR', '/dev/shm') + + When ``MNE_MEMMAP_MIN_SIZE=None``, the underlying joblib implementation + results in pickling and unpickling the whole array each time a pair of + indices is accessed, which is slow, compared to memory mapping the array. + + This function is based on the ``frites.conn.conn_spec`` implementation in + Frites. .. versionadded:: 0.3 + + References + ---------- + .. footbibliography:: """ events = None event_id = None # extract data from Epochs object - names = data.ch_names - times = data.times # input times for Epochs input type - sfreq = data.info['sfreq'] - events = data.events - event_id = data.event_id - n_epochs, n_signals, n_times = data.get_data().shape - # Extract metadata from the Epochs data structure. - # Make Annotations persist through by adding them to the metadata. - metadata = data.metadata - if metadata is None: - annots_in_metadata = False + if isinstance(data, BaseEpochs): + names = data.ch_names + sfreq = data.info['sfreq'] + events = data.events + event_id = data.event_id + n_epochs, n_signals, n_times = data.get_data().shape + # Extract metadata from the Epochs data structure. + # Make Annotations persist through by adding them to the metadata. + metadata = data.metadata + if metadata is None: + annots_in_metadata = False + else: + annots_in_metadata = all( + name not in metadata.columns for name in [ + 'annot_onset', 'annot_duration', 'annot_description']) + if hasattr(data, 'annotations') and not annots_in_metadata: + data.add_annotations_to_metadata(overwrite=True) + metadata = data.metadata + data = data.get_data() else: - annots_in_metadata = all( - name not in metadata.columns for name in [ - 'annot_onset', 'annot_duration', 'annot_description']) - if hasattr(data, 'annotations') and not annots_in_metadata: - data.add_annotations_to_metadata(overwrite=True) - metadata = data.metadata - data = data.get_data() + data = np.asarray(data) + n_epochs, n_signals, n_times = data.shape + names = np.arange(0, n_signals) + metadata = None + if sfreq is None: + raise ValueError('Sampling frequency (sfreq) is required with ' + 'array input.') + + # check that method is a list + if isinstance(method, str): + method = [method] + + # check that fmin corresponds to at least 5 cycles + dur = float(n_times) / sfreq + five_cycle_freq = 5. / dur + if fmin is None: + # use the 5 cycle freq. as default + fmin = five_cycle_freq + logger.info(f'Fmin was not specified. Using fmin={fmin:.2f}, which ' + 'corresponds to at least five cycles.') + else: + if np.any(fmin < five_cycle_freq): + warn('fmin=%0.3f Hz corresponds to %0.3f < 5 cycles ' + 'based on the epoch length %0.3f sec, need at least %0.3f ' + 'sec epochs or fmin=%0.3f. Spectrum estimate will be ' + 'unreliable.' % (np.min(fmin), dur * np.min(fmin), dur, + 5. / np.min(fmin), five_cycle_freq)) + if fmax is None: + fmax = sfreq / 2 + logger.info(f'Fmax was not specified. Using fmax={fmax:.2f}, which ' + f'corresponds to Nyquist.') + + fmin = np.array((fmin,), dtype=float).ravel() + fmax = np.array((fmax,), dtype=float).ravel() + if len(fmin) != len(fmax): + raise ValueError('fmin and fmax must have the same length') + if np.any(fmin > fmax): + raise ValueError('fmax must be larger than fmin') # convert kernel width in time to samples if isinstance(sm_times, (int, float)): @@ -143,7 +292,6 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, # temporal decimation if isinstance(decim, int): - times = times[::decim] sm_times = int(np.round(sm_times / decim)) sm_times = max(sm_times, 1) @@ -151,135 +299,155 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, kernel = _create_kernel(sm_times, sm_freqs, kernel=sm_kernel) # get indices of pairs of (group) regions - roi = names # ch_names if indices is None: - # roi_gp and roi_idx - roi_gp, _ = roi, np.arange(len(roi)).reshape(-1, 1) - - # get pairs for directed / undirected conn - source_idx, target_idx = np.triu_indices(len(roi_gp), k=0) + indices_use = np.tril_indices(n_signals, k=-1) else: indices_use = check_indices(indices) - source_idx = [x[0] for x in indices_use] - target_idx = [x[1] for x in indices_use] - roi_gp, _ = roi, np.arange(len(roi)).reshape(-1, 1) + source_idx = indices_use[0] + target_idx = indices_use[1] n_pairs = len(source_idx) - # frequency checking - if freqs is not None: + # check cwt_freqs + if cwt_freqs is not None: # check for single frequency - if isinstance(freqs, (int, float)): - freqs = [freqs] + if isinstance(cwt_freqs, (int, float)): + cwt_freqs = [cwt_freqs] # array conversion - freqs = np.asarray(freqs) + cwt_freqs = np.asarray(cwt_freqs) # check order for multiple frequencies - if len(freqs) >= 2: - delta_f = np.diff(freqs) + if len(cwt_freqs) >= 2: + delta_f = np.diff(cwt_freqs) increase = np.all(delta_f > 0) assert increase, "Frequencies should be in increasing order" - # frequency mean - if foi is None: - foi_idx = foi_s = foi_e = None - f_vec = freqs - else: - _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), - coords=(freqs,)) - foi_s = _f.sel(freqs=foi[:, 0], method='nearest').data - foi_e = _f.sel(freqs=foi[:, 1], method='nearest').data - foi_idx = np.c_[foi_s, foi_e] - f_vec = freqs[foi_idx].mean(1) - - # build block size indices - if isinstance(block_size, int) and (block_size > 1): - blocks = np.array_split(np.arange(n_epochs), block_size) - else: - blocks = [np.arange(n_epochs)] + # compute frequencies to analyze based on number of samples, + # sampling rate, specified wavelet frequencies and mode + freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) + + # compute the mask based on specified min/max and decimation factor + freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) - n_freqs = len(f_vec) + # the frequency points where we compute connectivity + freqs = freqs[freq_mask] - # compute coherence on blocks of trials - conn = np.zeros((n_epochs, n_pairs, n_freqs, len(times))) + # compute central frequencies + _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), + coords=(freqs,)) + foi_s = _f.sel(freqs=fmin, method='nearest').data + foi_e = _f.sel(freqs=fmax, method='nearest').data + foi_idx = np.c_[foi_s, foi_e] + f_vec = freqs[foi_idx].mean(1) + + if faverage: + n_freqs = len(fmin) + out_freqs = f_vec + else: + n_freqs = len(freqs) + out_freqs = freqs + + conn = dict() + for m in method: + conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) logger.info('Connectivity computation...') # parameters to pass to the connectivity function call_params = dict( method=method, kernel=kernel, foi_idx=foi_idx, source_idx=source_idx, target_idx=target_idx, - mode=mode, sfreq=sfreq, freqs=freqs, n_cycles=n_cycles, - mt_bandwidth=mt_bandwidth, + mode=mode, sfreq=sfreq, freqs=freqs, faverage=faverage, + n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, decim=decim, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, verbose=verbose) - for epoch_idx in blocks: - # compute time-resolved spectral connectivity + for epoch_idx in np.arange(n_epochs): + epoch_idx = [epoch_idx] conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) + for m in method: + conn[m][epoch_idx, ...] = np.stack(conn_tr[m], + axis=1).squeeze(axis=-1) - # merge results - conn[epoch_idx, ...] = np.stack(conn_tr, axis=1) + if indices is None: + conn_flat = conn + conn = dict() + for m in method: + this_conn = np.zeros((n_epochs, n_signals, n_signals) + + conn_flat[m].shape[2:], + dtype=conn_flat[m].dtype) + this_conn[:, source_idx, target_idx] = conn_flat[m][:, ...] + this_conn = this_conn.reshape((n_epochs, n_signals ** 2,) + + conn_flat[m].shape[2:]) + conn[m] = this_conn # create a Connectivity container - indices = 'symmetric' - conn = EpochSpectroTemporalConnectivity( - conn, freqs=f_vec, times=times, - n_nodes=n_signals, names=names, indices=indices, method=method, - spec_method=mode, events=events, event_id=event_id, metadata=metadata) + if average: + out = [SpectralConnectivity( + conn[m].mean(axis=0), freqs=out_freqs, n_nodes=n_signals, + names=names, indices=indices, method=method, spec_method=mode, + events=events, event_id=event_id, metadata=metadata) + for m in method] + else: + out = [EpochSpectralConnectivity( + conn[m], freqs=out_freqs, n_nodes=n_signals, names=names, + indices=indices, method=method, spec_method=mode, events=events, + event_id=event_id, metadata=metadata) for m in method] - return conn + logger.info('[Connectivity computation done]') + + # return the object instead of list of length one + if len(out) == 1: + return out[0] + else: + return out def _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, - mode, sfreq, freqs, n_cycles, mt_bandwidth=None, - decim=1, kw_cwt={}, kw_mt={}, n_jobs=1, - verbose=False): - """EStimate time-resolved connectivity for one epoch. + mode, sfreq, freqs, faverage, n_cycles, + mt_bandwidth, decim, kw_cwt, kw_mt, + n_jobs, verbose): + """Estimate time-resolved connectivity for one epoch. - See spectral_connectivity_epoch.""" + See spectral_connectivity_epochs.""" n_pairs = len(source_idx) - # first compute time-frequency decomposition - collapse = None if mode == 'cwt_morlet': out = tfr_array_morlet( data, sfreq, freqs, n_cycles=n_cycles, output='complex', decim=decim, n_jobs=n_jobs, **kw_cwt) + out = np.expand_dims(out, axis=2) # same dims with multitaper + weights = None elif mode == 'multitaper': - # In case multiple values are provided for mt_bandwidth - # the MT decomposition is done separatedly for each - # Frequency center - if isinstance(mt_bandwidth, (list, tuple, np.ndarray)): - # Arrays freqs, n_cycles, mt_bandwidth should have the same size - assert len(freqs) == len(n_cycles) == len(mt_bandwidth) - out = [] - for f_c, n_c, mt in zip(freqs, n_cycles, mt_bandwidth): - out += [tfr_array_multitaper( - data, sfreq, [f_c], n_cycles=float(n_c), time_bandwidth=mt, - output='complex', decim=decim, n_jobs=n_jobs, **kw_mt)] - out = np.stack(out, axis=3).squeeze() - elif isinstance(mt_bandwidth, (type(None), int, float)): - out = tfr_array_multitaper( - data, sfreq, freqs, n_cycles=n_cycles, - time_bandwidth=mt_bandwidth, output='complex', decim=decim, - n_jobs=n_jobs, **kw_mt) - collapse = True - if out.ndim == 5: # newest MNE-Python - collapse = -3 - - # get the supported connectivity function - conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs}[method] - - # computes conn across trials - # TODO: This is wrong -- it averages in the complex domain (over tapers). - # What it *should* do is compute the conn for each taper, then average - # (see below). - if collapse is not None: - out = np.mean(out, axis=collapse) - this_conn = conn_func(out, kernel, foi_idx, source_idx, target_idx, - n_jobs=n_jobs, verbose=verbose, total=n_pairs) - # This is where it should go, but the regression test fails... - # if collapse is not None: - # this_conn = [c.mean(axis=collapse) for c in this_conn] + out = tfr_array_multitaper( + data, sfreq, freqs, n_cycles=n_cycles, + time_bandwidth=mt_bandwidth, output='complex', decim=decim, + n_jobs=n_jobs, **kw_mt) + if isinstance(n_cycles, (int, float)): + n_cycles = [n_cycles] * len(freqs) + mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 + n_tapers = int(np.floor(mt_bandwidth - 1)) + weights = np.zeros((n_tapers, len(freqs), out.shape[-1])) + for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): + window_length = np.arange(0., n_c / float(f), 1.0 / sfreq).shape[0] + half_nbw = mt_bandwidth / 2. + n_tapers = int(np.floor(mt_bandwidth - 1)) + _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, + sym=False) + weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) + # weights have shape (n_tapers, n_freqs, n_times) + else: + raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") + + # compute for each connectivity method + this_conn = {} + conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, + 'wpli': _wpli} + for m in method: + c_func = conn_func[m] + this_conn[m] = c_func(out, kernel, foi_idx, source_idx, + target_idx, n_jobs=n_jobs, + verbose=verbose, total=n_pairs, + faverage=faverage, weights=weights) + return this_conn @@ -289,24 +457,35 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### ############################################################################### -def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total): - """Pairwise coherence.""" - # auto spectra (faster that w * w.conj()) - s_auto = w.real ** 2 + w.imag ** 2 +def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage, weights): + """Pairwise coherence. - # smooth the auto spectra - s_auto = _smooth_spectra(s_auto, kernel) + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" + + if weights is not None: + psd = weights * w + psd = psd * np.conj(psd) + psd = psd.real.sum(axis=2) + psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) + else: + psd = w.real ** 2 + w.imag ** 2 + psd = np.squeeze(psd, axis=2) + + # smooth the psd + psd = _smooth_spectra(psd, kernel) - # define the pairwise coherence def pairwise_coh(w_x, w_y): - # computes the coherence - s_xy = w[:, w_y] * np.conj(w[:, w_x]) + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) s_xy = _smooth_spectra(s_xy, kernel) - s_xx = s_auto[:, w_x] - s_yy = s_auto[:, w_y] - out = np.abs(s_xy) ** 2 / (s_xx * s_yy) + s_xx = psd[:, w_x] + s_yy = psd[:, w_y] + out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ + np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, keepdims=True)) # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray): + if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) else: return out @@ -315,24 +494,24 @@ def pairwise_coh(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_coh, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial coherence return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) -def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total): - """Pairwise phase-locking value.""" - # define the pairwise plv +def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage, weights): + """Pairwise phase-locking value. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" def pairwise_plv(w_x, w_y): - # computes the plv - s_xy = w[:, w_y] * np.conj(w[:, w_x]) - # complex exponential of phase differences + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) exp_dphi = s_xy / np.abs(s_xy) - # smooth e^(-i*\delta\phi) exp_dphi = _smooth_spectra(exp_dphi, kernel) - # computes plv - out = np.abs(exp_dphi) + # mean over time + exp_dphi_mean = exp_dphi.mean(axis=-1, keepdims=True) + out = np.abs(exp_dphi_mean) # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray): + if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) else: return out @@ -341,18 +520,65 @@ def pairwise_plv(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_plv, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial coherence return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) -def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total): +def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage, weights): + """Pairwise phase-lag index. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" + def pairwise_pli(w_x, w_y): + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) + s_xy = _smooth_spectra(s_xy, kernel) + out = np.abs(np.mean(np.sign(np.imag(s_xy)), + axis=-1, keepdims=True)) + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + return _foi_average(out, foi_idx) + else: + return out + + # define the function to compute in parallel + parallel, p_fun, n_jobs = parallel_func( + pairwise_pli, n_jobs=n_jobs, verbose=verbose, total=total) + + return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + + +def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage, weights): + """Pairwise weighted phase-lag index. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" + def pairwise_wpli(w_x, w_y): + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) + s_xy = _smooth_spectra(s_xy, kernel) + con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) + out = con_num / con_den + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + return _foi_average(out, foi_idx) + else: + return out + + # define the function to compute in parallel + parallel, p_fun, n_jobs = parallel_func( + pairwise_wpli, n_jobs=n_jobs, verbose=verbose, total=total) + + return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + + +def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage, weights): """Pairwise cross-spectra.""" - # define the pairwise cross-spectra def pairwise_cs(w_x, w_y): - # computes the cross-spectra - out = w[:, w_x] * np.conj(w[:, w_y]) + out = _compute_csd(w[:, w_y], w[:, w_x], weights) out = _smooth_spectra(out, kernel) - if foi_idx is not None: + if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) else: return out @@ -361,10 +587,20 @@ def pairwise_cs(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_cs, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial coherence return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _compute_csd(x, y, weights): + """Compute cross spectral density of signals x and y.""" + if weights is not None: + s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3) + else: + s_xy = x * np.conj(y) + s_xy = np.squeeze(s_xy, axis=-3) + return s_xy + + def _foi_average(conn, foi_idx): """Average inside frequency bands. @@ -392,5 +628,6 @@ def _foi_average(conn, foi_idx): # compute average conn_f = np.zeros(sh, dtype=conn.dtype) for n_f, (f_s, f_e) in enumerate(foi_idx): + f_e += 1 if f_s == f_e else f_e conn_f[..., n_f, :] = conn[..., f_s:f_e, :].mean(-2) return conn_f diff --git a/requirements.txt b/requirements.txt index 0f5821b0..067ea5c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ h5netcdf tqdm matplotlib qtpy -PySide6!=6.3.0 +PySide6!=6.3.0,!=6.4.0,!=6.4.0.1 sip pyvista>=0.30 pyvistaqt>=0.4 diff --git a/requirements_doc.txt b/requirements_doc.txt index 49330528..51d4fa53 100644 --- a/requirements_doc.txt +++ b/requirements_doc.txt @@ -6,7 +6,7 @@ sphinx-copybutton numpydoc nibabel nilearn -pydata-sphinx-theme +https://github.com/pydata/pydata-sphinx-theme/archive/cef3e724e15852fc2a84bee256c457c9497834b8.zip typing-extensions sphinx-autodoc-typehints sphinxcontrib-bibtex diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index e656f5af..78586360 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -2,7 +2,7 @@ echo "Installing setuptools and sphinx" python -m pip install --progress-bar off --upgrade "pip!=20.3.0" setuptools wheel -python -m pip install --upgrade --progress-bar off --pre sphinx +python -m pip install --upgrade --progress-bar off sphinx echo "Installing doc build dependencies" python -m pip uninstall -y pydata-sphinx-theme