From 647dd7124dd7a8d2856c039638ee2cb205b74c01 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Sep 2024 21:48:00 +0000 Subject: [PATCH 1/8] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.4 → v0.6.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.4...v0.6.5) - [github.com/astral-sh/ruff-pre-commit: v0.6.4 → v0.6.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.4...v0.6.5) --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d50d77a6..83592a5b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne_connectivity - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff name: ruff lint mne_connectivity @@ -10,7 +10,7 @@ repos: # Ruff tutorials and examples - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff name: ruff lint tutorials and examples From 378a18671aa9d8bbd67b28fc3e2c9865f1f16541 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 16 Sep 2024 20:47:12 -0400 Subject: [PATCH 2/8] FIX: TransformerMixin --- mne_connectivity/decoding/decomposition.py | 3 ++- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 30a78c26..1c6a74bd 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -5,10 +5,11 @@ # License: BSD (3-clause) +from sklearn.base import TransformerMixin + import numpy as np from mne import Info from mne._fiff.pick import pick_info -from mne.decoding.mixin import TransformerMixin from mne.defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from mne.evoked import EvokedArray from mne.fixes import BaseEstimator diff --git a/pyproject.toml b/pyproject.toml index ab8b6ef8..090e973f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ 'scipy >= 1.4.0', 'tqdm', 'xarray >= 2023.11.0', + 'scikit-learn >= 1.2' ] description = 'mne-connectivity: A module for connectivity data analysis with MNE.' dynamic = ["version"] From 24b80d5476158e8921d832a565291255577bb61c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 00:47:57 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne_connectivity/decoding/decomposition.py | 3 +-- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 1c6a74bd..9584bfcc 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -5,8 +5,6 @@ # License: BSD (3-clause) -from sklearn.base import TransformerMixin - import numpy as np from mne import Info from mne._fiff.pick import pick_info @@ -16,6 +14,7 @@ from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper from mne.utils import _check_option, _validate_type from mne.viz.utils import plt_show +from sklearn.base import TransformerMixin from ..spectral.epochs_multivariate import ( _CaCohEst, diff --git a/pyproject.toml b/pyproject.toml index 090e973f..93d3ecfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,10 @@ dependencies = [ 'netCDF4 >= 1.6.5', 'numpy >= 1.21', 'pandas >= 1.3.2', + 'scikit-learn >= 1.2', 'scipy >= 1.4.0', 'tqdm', 'xarray >= 2023.11.0', - 'scikit-learn >= 1.2' ] description = 'mne-connectivity: A module for connectivity data analysis with MNE.' dynamic = ["version"] From 685b04a53901607a818d0d9a94b8a123953f1fd1 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 16 Sep 2024 20:51:36 -0400 Subject: [PATCH 4/8] FIX: More --- mne_connectivity/decoding/decomposition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 9584bfcc..e0019e07 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -10,11 +10,10 @@ from mne._fiff.pick import pick_info from mne.defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from mne.evoked import EvokedArray -from mne.fixes import BaseEstimator from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper from mne.utils import _check_option, _validate_type from mne.viz.utils import plt_show -from sklearn.base import TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin from ..spectral.epochs_multivariate import ( _CaCohEst, From 446a0ead328ccabb043a568028fc4126d7b3d938 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 16 Sep 2024 20:58:26 -0400 Subject: [PATCH 5/8] FIX: More --- mne_connectivity/decoding/decomposition.py | 1 + mne_connectivity/decoding/tests/test_decomposition.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index e0019e07..cff12cdc 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -221,6 +221,7 @@ def __init__( # n_jobs and verbose will be checked downstream # Store inputs + self.method = method self.info = info self._conn_estimator_class = _conn_estimator_class self._indices = _indices # uses getter/setter for public parameter diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index ba46b160..dfd61d89 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -174,9 +174,9 @@ def test_spectral_decomposition(method, mode): epochs_transformed_2 = decomp_class_2.transform( X=epochs[: n_epochs // 2].get_data() ) - assert_allclose(epochs_transformed, epochs_transformed_2, atol=1e-9) - assert_allclose(decomp_class.filters_, decomp_class_2.filters_, atol=1e-9) - assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_, atol=1e-9) + assert_allclose(epochs_transformed, epochs_transformed_2, atol=1e-8) + assert_allclose(decomp_class.filters_, decomp_class_2.filters_, atol=1e-8) + assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_, atol=1e-8) # TEST FITTING ON ONE PIECE OF DATA AND TRANSFORMING ANOTHER con_mv_class_unseen_data = spectral_connectivity_epochs( From bba335aca0e4301f323941c41f42369fa01cbfef Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 16 Sep 2024 21:17:50 -0400 Subject: [PATCH 6/8] FIX: Fix --- doc/conf.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index 301bb9d2..9c9a39d2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -360,3 +360,30 @@ suppress_warnings = [ "config.cache", # our rebuild is okay ] + + +def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): + """Fix sklearn docstrings because they use autolink and we do not.""" + if ( + name.startswith("mne.decoding.") or name.startswith("mne.preprocessing.Xdawn") + ) and name.endswith( + ( + ".get_metadata_routing", + ".fit", + ".fit_transform", + ".set_output", + ".transform", + ) + ): + if ":Parameters:" in lines: + loc = lines.index(":Parameters:") + else: + loc = lines.index(":Returns:") + lines.insert(loc, "") + lines.insert(loc, ".. default-role:: autolink") + lines.insert(loc, "") + + +def setup(app): + """Set up the Sphinx app.""" + app.connect("autodoc-process-docstring", fix_sklearn_inherited_docstrings) From 15be695aa8de0cd1b325936bca3eee7d5faf8f39 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 16 Sep 2024 21:18:22 -0400 Subject: [PATCH 7/8] FIX: Fix --- doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index 9c9a39d2..669e395f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -365,7 +365,7 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): """Fix sklearn docstrings because they use autolink and we do not.""" if ( - name.startswith("mne.decoding.") or name.startswith("mne.preprocessing.Xdawn") + name.startswith("mne_connectivity.decoding.") ) and name.endswith( ( ".get_metadata_routing", From 49b074a72f6d4a292b278229c7f815fde30f6473 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 17 Sep 2024 12:06:06 +0200 Subject: [PATCH 8/8] Add sklearn refs --- doc/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index 669e395f..4f59ab4f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -149,6 +149,9 @@ "Axes3D": "mpl_toolkits.mplot3d.axes3d.Axes3D", "PolarAxes": "matplotlib.projections.polar.PolarAxes", "ColorbarBase": "matplotlib.colorbar.ColorbarBase", + # sklearn + "MetadataRequest": "sklearn.utils.metadata_routing.MetadataRequest", + "estimator": "sklearn.base.BaseEstimator", # joblib "joblib.Parallel": "joblib.Parallel", # nibabel