diff --git a/.circleci/config.yml b/.circleci/config.yml index 644fd8b31b7..26b9f600e3c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -218,6 +218,9 @@ jobs: - restore_cache: keys: - data-cache-phantom-kit + - restore_cache: + keys: + - data-cache-ds004388 - run: name: Get data # This limit could be increased, but this is helpful for finding slow ones @@ -252,7 +255,7 @@ jobs: name: Check sphinx log for warnings (which are treated as errors) when: always command: | - ! grep "^.* WARNING: .*$" sphinx_log.txt + ! grep "^.*\(WARNING\|ERROR\): " sphinx_log.txt - run: name: Show profiling output when: always @@ -393,6 +396,10 @@ jobs: key: data-cache-phantom-kit paths: - ~/mne_data/MNE-phantom-KIT-data # (1 G) + - save_cache: + key: data-cache-ds004388 + paths: + - ~/mne_data/ds004388 # (1.8 G) linkcheck: diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index d8a99200783..18543b854d0 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -21,4 +21,4 @@ jobs: - run: pip install --upgrade towncrier pygithub gitpython numpy - run: python ./.github/actions/rename_towncrier/rename_towncrier.py - run: python ./tools/dev/ensure_headers.py - - uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c + - uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8975e72784b..298908cdc65 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -93,37 +93,36 @@ jobs: with: qt: true pyvista: false + wm: false # Python (if pip) - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} if: startswith(matrix.kind, 'pip') # Python (if conda) - - name: Remove numba and dipy - run: | # TODO: Remove when numba 0.59 and dipy 1.8 land on conda-forge - sed -i '/numba/d' environment.yml - sed -i '/dipy/d' environment.yml - sed -i 's/- mne$/- mne-base/' environment.yml - if: matrix.os == 'ubuntu-latest' && startswith(matrix.kind, 'conda') && matrix.python == '3.12' + - name: Fixes for conda + run: | + # For some reason on Linux we get crashes + if [[ "$RUNNER_OS" == "Linux" ]]; then + sed -i "/numba/d" environment.yml + elif [[ "$RUNNER_OS" == "macOS" ]]; then + sed -i "" "s/ - PySide6 .*/ - PySide6 <6.8/g" environment.yml + fi + if: matrix.kind == 'conda' || matrix.kind == 'mamba' - uses: mamba-org/setup-micromamba@v2 with: environment-file: ${{ env.CONDA_ENV }} environment-name: mne create-args: >- python=${{ env.PYTHON_VERSION }} - mamba - nomkl if: ${{ !startswith(matrix.kind, 'pip') }} - # Make sure we have the right Python - - run: python -c "import platform; assert platform.machine() == 'arm64', platform.machine()" - if: matrix.os == 'macos-14' - - run: ./tools/github_actions_dependencies.sh + - run: bash ./tools/github_actions_dependencies.sh # Minimal commands on Linux (macOS stalls) - - run: ./tools/get_minimal_commands.sh + - run: bash ./tools/get_minimal_commands.sh if: startswith(matrix.os, 'ubuntu') && matrix.kind != 'minimal' && matrix.kind != 'old' - - run: ./tools/github_actions_infos.sh + - run: bash ./tools/github_actions_infos.sh # Check Qt - - run: ./tools/check_qt_import.sh $MNE_QT_BACKEND + - run: bash ./tools/check_qt_import.sh $MNE_QT_BACKEND if: env.MNE_QT_BACKEND != '' - name: Run tests with no testing data run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/ @@ -133,8 +132,8 @@ jobs: with: key: ${{ 
env.TESTING_VERSION }} path: ~/mne_data - - run: ./tools/github_actions_download.sh - - run: ./tools/github_actions_test.sh + - run: bash ./tools/github_actions_download.sh + - run: bash ./tools/github_actions_test.sh # for some reason on macOS we need to run "bash X" in order for a failed test run to show up - uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd327428ecf..cb769988655 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.9.1 hooks: - id: ruff name: ruff lint mne @@ -70,19 +70,19 @@ repos: name: Copy dependency changes from pyproject.toml to environment.yml language: python entry: ./tools/hooks/update_environment_file.py - files: pyproject.toml + files: '^(pyproject.toml|tools/hooks/update_environment_file.py)$' - repo: local hooks: - id: dependency-sync name: Copy core dependencies from pyproject.toml to README.rst language: python entry: ./tools/hooks/sync_dependencies.py - files: pyproject.toml + files: '^(pyproject.toml|tools/hooks/sync_dependencies.py)$' additional_dependencies: ["mne==1.9.0"] # zizmor - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v0.10.0 + rev: v1.1.1 hooks: - id: zizmor diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3ca4177174f..7149edac50b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -88,6 +88,7 @@ stages: variables: DISPLAY: ':99' OPENBLAS_NUM_THREADS: '1' + MNE_TEST_ALLOW_SKIP: '^.*(PySide6 causes segfaults).*$' steps: - bash: | set -e @@ -111,7 +112,7 @@ stages: - bash: | set -e python -m pip install --progress-bar off --upgrade pip - python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1" + python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1" pandas neo pymatreader antio defusedxml python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] displayName: 'Install dependencies with pip' @@ -132,7 +133,7 @@ stages: displayName: 'Cache testing data' - script: python -c "import mne; mne.datasets.testing.data_path(verbose=True)" displayName: 'Get test data' - - script: pytest --error-for-skips -m "ultraslowtest or pgtest" --tb=short --cov=mne --cov-report=xml --cov-report=html -vv mne + - script: pytest -m "ultraslowtest or pgtest" --tb=short --cov=mne --cov-report=xml -vv mne displayName: 'slow and mne-qt-browser tests' # Coverage - bash: bash <(curl -s https://codecov.io/bash) @@ -144,11 +145,9 @@ stages: testRunTitle: 'Publish test results for $(Agent.JobName)' failTaskOnFailedTests: true condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 + - task: PublishCodeCoverageResults@2 inputs: - codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' - job: Qt pool: @@ -156,7 +155,8 @@ stages: variables: DISPLAY: ':99' OPENBLAS_NUM_THREADS: '1' - TEST_OPTIONS: "--tb=short --cov=mne --cov-report=xml --cov-report=html --cov-append -vv mne/viz/_brain mne/viz/backends 
mne/viz/tests/test_evoked.py mne/gui mne/report" + TEST_OPTIONS: "--tb=short --cov=mne --cov-report=xml --cov-append -vv mne/viz/_brain mne/viz/backends mne/viz/tests/test_evoked.py mne/gui mne/report" + MNE_TEST_ALLOW_SKIP: '^.*(PySide6 causes segfaults).*$' steps: - bash: ./tools/setup_xvfb.sh displayName: 'Install Ubuntu dependencies' @@ -192,6 +192,7 @@ stages: set -eo pipefail python -m pip install PyQt6 LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" + displayName: 'Check Qt import' - bash: | set -eo pipefail mne sys_info -pd @@ -226,11 +227,9 @@ stages: testRunTitle: 'Publish test results for $(Agent.JobName)' failTaskOnFailedTests: true condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 + - task: PublishCodeCoverageResults@2 inputs: - codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' - job: Windows pool: @@ -244,7 +243,7 @@ stages: PYTHONIOENCODING: 'utf-8' AZURE_CI_WINDOWS: 'true' PYTHON_ARCH: 'x64' - timeoutInMinutes: 75 + timeoutInMinutes: 80 strategy: maxParallel: 4 matrix: @@ -285,7 +284,7 @@ stages: displayName: 'Cache testing data' - script: python -c "import mne; mne.datasets.testing.data_path(verbose=True)" displayName: 'Get test data' - - script: pytest -m "not (slowtest or pgtest)" --tb=short --cov=mne --cov-report=xml --cov-report=html -vv mne + - script: pytest -m "not (slowtest or pgtest)" --tb=short --cov=mne --cov-report=xml -vv mne displayName: 'Run tests' - bash: bash <(curl -s https://codecov.io/bash) displayName: 'Codecov' @@ -296,8 +295,6 @@ stages: testRunTitle: 'Publish test results for $(Agent.JobName) $(TEST_MODE) $(PYTHON_VERSION)' failTaskOnFailedTests: true condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 + - task: PublishCodeCoverageResults@2 inputs: - codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' diff --git a/doc/api/datasets.rst b/doc/api/datasets.rst index 2b2c92c8654..87730fbd717 100644 --- a/doc/api/datasets.rst +++ b/doc/api/datasets.rst @@ -18,6 +18,7 @@ Datasets brainstorm.bst_auditory.data_path brainstorm.bst_resting.data_path brainstorm.bst_raw.data_path + default_path eegbci.load_data eegbci.standardize fetch_aparc_sub_parcellation diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index 86ad3aca910..9fe3f995cc4 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -116,6 +116,7 @@ Projections: read_ica_eeglab read_fine_calibration write_fine_calibration + apply_pca_obs :py:mod:`mne.preprocessing.nirs`: diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index 8923920bdba..a9ab2c34268 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -31,6 +31,8 @@ Functions that operate on mne-python objects: .. 
autosummary:: :toctree: ../generated/ + combine_spectrum + combine_tfr csd_tfr csd_fourier csd_multitaper diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst new file mode 100644 index 00000000000..3f32dbd23e5 --- /dev/null +++ b/doc/changes/devel/12656.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (:attr:`raw.first_time`) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file diff --git a/doc/changes/devel/12828.bugfix.rst b/doc/changes/devel/12828.bugfix.rst new file mode 100644 index 00000000000..707385ac698 --- /dev/null +++ b/doc/changes/devel/12828.bugfix.rst @@ -0,0 +1 @@ +Fixed behavior of :func:`mne.viz.plot_source_estimates` where the ``title`` was not displayed properly, by :newcontrib:`Shristi Baral`. diff --git a/doc/changes/devel/12910.newfeature.rst b/doc/changes/devel/12910.newfeature.rst new file mode 100644 index 00000000000..95605c11017 --- /dev/null +++ b/doc/changes/devel/12910.newfeature.rst @@ -0,0 +1 @@ +Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13037.newfeature.rst b/doc/changes/devel/13037.newfeature.rst new file mode 100644 index 00000000000..3b28e2294ab --- /dev/null +++ b/doc/changes/devel/13037.newfeature.rst @@ -0,0 +1 @@ +Add PCA-OBS preprocessing for the removal of heart-artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Steinn Hauser Magnusson`. diff --git a/doc/changes/devel/13048.bugfix.rst b/doc/changes/devel/13048.bugfix.rst new file mode 100644 index 00000000000..8f0fe46f3c7 --- /dev/null +++ b/doc/changes/devel/13048.bugfix.rst @@ -0,0 +1 @@ +Fix input boxes for the max value not showing when plotting fieldlines with :func:`~mne.viz.plot_evoked_field` when ``show_density=False``, by `Marijn van Vliet`_. diff --git a/doc/changes/devel/13054.newfeature.rst b/doc/changes/devel/13054.newfeature.rst new file mode 100644 index 00000000000..3c89290e7fe --- /dev/null +++ b/doc/changes/devel/13054.newfeature.rst @@ -0,0 +1 @@ +Added :func:`mne.time_frequency.combine_tfr` to allow combining TFRs across tapers, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13056.bugfix.rst b/doc/changes/devel/13056.bugfix.rst new file mode 100644 index 00000000000..2a7919de289 --- /dev/null +++ b/doc/changes/devel/13056.bugfix.rst @@ -0,0 +1 @@ +Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_. diff --git a/doc/changes/devel/13058.newfeature.rst b/doc/changes/devel/13058.newfeature.rst new file mode 100644 index 00000000000..bbd01fa4552 --- /dev/null +++ b/doc/changes/devel/13058.newfeature.rst @@ -0,0 +1 @@ +Add the function :func:`mne.time_frequency.combine_spectrum` for combining data across :class:`mne.time_frequency.Spectrum` objects, and allow :func:`mne.grand_average` to operate on :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_.
\ No newline at end of file diff --git a/doc/changes/devel/13062.bugfix.rst b/doc/changes/devel/13062.bugfix.rst new file mode 100644 index 00000000000..9e01fc4c835 --- /dev/null +++ b/doc/changes/devel/13062.bugfix.rst @@ -0,0 +1 @@ +Fix computation of time intervals in :func:`mne.preprocessing.compute_fine_calibration` by `Eric Larson`_. diff --git a/doc/changes/devel/13063.bugfix.rst b/doc/changes/devel/13063.bugfix.rst new file mode 100644 index 00000000000..76eba2032a1 --- /dev/null +++ b/doc/changes/devel/13063.bugfix.rst @@ -0,0 +1 @@ +Fix bug in the colorbars created by :func:`mne.viz.plot_evoked_topomap` by `Santeri Ruuskanen`_. \ No newline at end of file diff --git a/doc/changes/devel/13067.bugfix.rst b/doc/changes/devel/13067.bugfix.rst new file mode 100644 index 00000000000..237df7623d5 --- /dev/null +++ b/doc/changes/devel/13067.bugfix.rst @@ -0,0 +1 @@ +Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13069.bugfix.rst b/doc/changes/devel/13069.bugfix.rst new file mode 100644 index 00000000000..7c23221c8df --- /dev/null +++ b/doc/changes/devel/13069.bugfix.rst @@ -0,0 +1 @@ +Fix bug caused by an unnecessary assertion when loading mixed frequency EDFs without preloading in :func:`mne.io.read_raw_edf`, by `Simon Kern`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 3dfc742b3b3..eb444c5e594 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -73,6 +73,7 @@ .. _Eberhard Eich: https://github.com/ebeich .. _Eduard Ort: https://github.com/eort .. _Emily Stephen: https://github.com/emilyps14 +.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey .. _Enrico Varano: https://github.com/enricovara/ .. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt .. _Eric Larson: https://larsoner.com @@ -273,6 +274,7 @@ .. _Senwen Deng: https://snwn.de .. _Seyed Yahya Shirazi: https://neuromechanist.github.io .. _Sheraz Khan: https://github.com/SherazKhan +.. _Shristi Baral: https://github.com/shristibaral .. _Silvia Cotroneo: https://github.com/sfc-neuro .. _Simeon Wong: https://github.com/dtxe .. _Simon Kern: https://skjerns.de @@ -283,6 +285,7 @@ .. _Stanislas Chambon: https://github.com/Slasnista .. _Stefan Appelhoff: https://stefanappelhoff.com .. _Stefan Repplinger: https://github.com/stfnrpplngr +.. _Steinn Hauser Magnusson: https://github.com/steinnhauser .. _Steven Bethard: https://github.com/bethard .. _Steven Bierer: https://github.com/neurolaunch .. _Steven Gutstein: https://github.com/smgutstein diff --git a/doc/changes/v1.7.rst b/doc/changes/v1.7.rst index dfd3129a18d..6b118612541 100644 --- a/doc/changes/v1.7.rst +++ b/doc/changes/v1.7.rst @@ -75,12 +75,12 @@ Bugfixes - Fix validation of ``ch_type`` in :func:`mne.preprocessing.annotate_muscle_zscore`, by `Mathieu Scheltienne`_. (`#12444 `__) - Fix errant redundant use of ``BIDSPath.split`` when writing split raw and epochs data, by `Eric Larson`_. (`#12451 `__) - Disable config parser interpolation when reading BrainVision files, which allows using the percent sign as a regular character in channel units, by `Clemens Brunner`_. (`#12456 `__) -- - Fix the default color of :meth:`mne.viz.Brain.add_text` to properly contrast with the figure background color, by `Marijn van Vliet`_.
(`#12470 `__) -- - Changed default ECoG and sEEG electrode sizes in brain plots to better reflect real world sizes, by `Liberty Hamilton`_ (`#12474 `__) +- Fix the default color of :meth:`mne.viz.Brain.add_text` to properly contrast with the figure background color, by `Marijn van Vliet`_. (`#12470 `__) +- Changed default ECoG and sEEG electrode sizes in brain plots to better reflect real world sizes, by `Liberty Hamilton`_ (`#12474 `__) - Fixed bugs with handling of rank in :class:`mne.decoding.CSP`, by `Eric Larson`_. (`#12476 `__) -- - Fix reading segmented recordings with :func:`mne.io.read_raw_eyelink` by `Dominik Welke`_. (`#12481 `__) +- Fix reading segmented recordings with :func:`mne.io.read_raw_eyelink` by `Dominik Welke`_. (`#12481 `__) - Improve compatibility with other Qt-based GUIs by handling theme icons better, by `Eric Larson`_. (`#12483 `__) -- - Fix problem caused by onsets with NaN values using :func:`mne.io.read_raw_eeglab` by `Jacob Woessner`_ (`#12484 `__) +- Fix problem caused by onsets with NaN values using :func:`mne.io.read_raw_eeglab` by `Jacob Woessner`_ (`#12484 `__) - Fix cleaning of channel names for non vectorview or CTF dataset including whitespaces or dash in their channel names, by `Mathieu Scheltienne`_. (`#12489 `__) - Fix bug with :meth:`mne.preprocessing.ICA.plot_sources` for ``evoked`` data where the legend contained too many entries, by `Eric Larson`_. (`#12498 `__) diff --git a/doc/changes/v1.9.rst b/doc/changes/v1.9.rst index 17a3a2ba1fe..0c6f7c1fddc 100644 --- a/doc/changes/v1.9.rst +++ b/doc/changes/v1.9.rst @@ -1,12 +1,12 @@ .. _changes_1_9_0: -1.9.0 (2024-12-18) -================== +Version 1.9.0 (2024-12-18) +========================== Dependencies ------------ -- - Minimum supported dependencies were updated in accordance with SPEC0_, most notably Python 3.10+ is now required. (`#12798 `__) +- Minimum supported dependencies were updated in accordance with SPEC0_, most notably Python 3.10+ is now required. (`#12798 `__) - Importing from ``mne.decoding`` now explicitly requires ``scikit-learn`` to be installed, by `Eric Larson`_. (`#12834 `__) - Compatibility improved for Python 3.13, by `Eric Larson`_. (`#13021 `__) @@ -63,7 +63,7 @@ New features - Add option to :func:`mne.preprocessing.fix_stim_artifact` to use baseline average to flatten TMS pulse artifact by `Fahimeh Mamashli`_ and `Padma Sundaram`_ and `Mohammad Daneshzand`_. (`#6915 `__) - Add support for `dict` type argument ``ref_channels`` to :func:`mne.set_eeg_reference`, to allow flexible re-referencing (e.g. ``raw.set_eeg_reference(ref_channels={'A1': ['A2', 'A3']})`` will set the new A1 data to be ``A1 - mean(A2, A3)``), by `Alex Lepauvre`_ and `Qian Chu`_ and `Daniel McCloy`_. (`#12366 `__) - Add reader for ANT Neuro files in the ``*.cnt`` format with :func:`~mne.io.read_raw_ant`, by `Mathieu Scheltienne`_, `Eric Larson`_ and `Proloy Das`_. (`#12792 `__) -- - Add support for a :class:`mne.transforms.Transform` in the argument ``trans`` of the coregistration GUI called with :func:`mne.gui.coregistration`, by `Mathieu Scheltienne`_. (`#12801 `__) +- Add support for a :class:`mne.transforms.Transform` in the argument ``trans`` of the coregistration GUI called with :func:`mne.gui.coregistration`, by `Mathieu Scheltienne`_. (`#12801 `__) - :meth:`~mne.io.Raw` and :meth:`~mne.Epochs.save` now return the path to the saved file(s), by `Victor Ferat`_. 
(`#12811 `__) - :func:`mne.channels.read_custom_montage` may now read a newer version of the ``.elc`` ASA Electrode file format, by `Stefan Appelhoff`_. (`#12830 `__) - Added the ``title`` argument to :func:`mne.viz.create_3d_figure`, and diff --git a/doc/conf.py b/doc/conf.py index 7dd6ec90d4f..f1b771571d6 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -355,6 +355,7 @@ "n_frequencies", "n_tests", "n_samples", + "n_peaks", "n_permutations", "nchan", "n_points", @@ -666,6 +667,10 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): r"https://scholar.google.com/scholar\?cites=12188330066413208874&as_ylo=2014", r"https://scholar.google.com/scholar\?cites=1521584321377182930&as_ylo=2013", "https://www.research.chop.edu/imaging", + "http://prdownloads.sourceforge.net/optipng/optipng-0.7.8-win64.zip?download", + "https://sourceforge.net/projects/aespa/files/", + "https://sourceforge.net/projects/ezwinports/files/", + "https://www.mathworks.com/products/compiler/matlab-runtime.html", # 500 server error "https://openwetware.org/wiki/Beauchamp:FreeSurfer", # 503 Server error @@ -688,6 +693,7 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): # SSL problems sometimes "http://ilabs.washington.edu", "https://psychophysiology.cpmc.columbia.edu", + "https://erc.easme-web.eu", ] linkcheck_anchors = False # saves a bit of time linkcheck_timeout = 15 # some can be quite slow @@ -1284,7 +1290,7 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): rst_prolog += f""" .. |{icon}| raw:: html - + """ rst_prolog += """ diff --git a/doc/development/contributing.rst b/doc/development/contributing.rst index 07c28f55d3d..011fd3c11f4 100644 --- a/doc/development/contributing.rst +++ b/doc/development/contributing.rst @@ -1114,6 +1114,6 @@ it can serve as a useful example of what to expect from the PR review process. .. optipng .. _optipng: http://optipng.sourceforge.net/ -.. _optipng for Windows: http://prdownloads.sourceforge.net/optipng/optipng-0.7.7-win32.zip?download +.. _optipng for Windows: http://prdownloads.sourceforge.net/optipng/optipng-0.7.8-win64.zip?download .. include:: ../links.inc diff --git a/doc/install/installers.rst b/doc/install/installers.rst index 5b7eeba5203..533c0207963 100644 --- a/doc/install/installers.rst +++ b/doc/install/installers.rst @@ -86,7 +86,7 @@ Platform-specific installers .. We have to use a button-link here because button-ref doesn't properly nested parse the inline code - .. button-link:: ./ides.html + .. button-link:: ides.html :ref-type: ref :color: success :shadow: diff --git a/doc/references.bib b/doc/references.bib index a129d2f46a2..e2578ed18f2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016 year = {2016} } +@article{NiazyEtAl2005, + author = {Niazy, R. K. and Beckmann, C. F. and Iannetti, G. D. and Brady, J. M. and Smith, S. M.}, + title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets}, + journal = {NeuroImage}, + year = {2005}, + volume = {28}, + pages = {720-737}, + doi = {10.1016/j.neuroimage.2005.06.067} +} + @article{NicholsHolmes2002, author = {Nichols, Thomas E.
and Holmes, Andrew P.}, doi = {10.1002/hbm.1058}, diff --git a/doc/sphinxext/credit_tools.py b/doc/sphinxext/credit_tools.py index 708dcf00ce8..e22bd0b5530 100644 --- a/doc/sphinxext/credit_tools.py +++ b/doc/sphinxext/credit_tools.py @@ -169,7 +169,7 @@ def generate_credit_rst(app=None, *, verbose=False): if author["e"] is not None: if author["e"] not in name_map: unknown_emails.add( - f'{author["e"].ljust(29)} ' + f"{author['e'].ljust(29)} " "https://github.com/mne-tools/mne-python/pull/" f"{commit}/files" ) @@ -178,9 +178,9 @@ def generate_credit_rst(app=None, *, verbose=False): else: name = author["n"] if name in manual_renames: - assert _good_name( - manual_renames[name] - ), f"Bad manual rename: {name}" + assert _good_name(manual_renames[name]), ( + f"Bad manual rename: {name}" + ) name = manual_renames[name] if " " in name: first, last = name.rsplit(" ", maxsplit=1) diff --git a/doc/sphinxext/related_software.py b/doc/sphinxext/related_software.py index ac1b741b9af..ab159b0fcb4 100644 --- a/doc/sphinxext/related_software.py +++ b/doc/sphinxext/related_software.py @@ -163,9 +163,9 @@ def _get_packages() -> dict[str, str]: assert not dups, f"Duplicates in MANUAL_PACKAGES and PYPI_PACKAGES: {sorted(dups)}" # And the installer and PyPI-only should be disjoint: dups = set(PYPI_PACKAGES) & set(packages) - assert ( - not dups - ), f"Duplicates in PYPI_PACKAGES and installer packages: {sorted(dups)}" + assert not dups, ( + f"Duplicates in PYPI_PACKAGES and installer packages: {sorted(dups)}" + ) for name in PYPI_PACKAGES | set(MANUAL_PACKAGES): if name not in packages: packages.append(name) diff --git a/doc/sphinxext/unit_role.py b/doc/sphinxext/unit_role.py index b52665e8321..bf31ddf76c4 100644 --- a/doc/sphinxext/unit_role.py +++ b/doc/sphinxext/unit_role.py @@ -10,8 +10,7 @@ def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # def pass_error_to_sphinx(rawtext, text, lineno, inliner): msg = inliner.reporter.error( - "The :unit: role requires a space-separated number and unit; " - f"got {text}", + f"The :unit: role requires a space-separated number and unit; got {text}", line=lineno, ) prb = inliner.problematic(rawtext, rawtext, msg) diff --git a/environment.yml b/environment.yml index e8e7c28d6cb..e0c458159c4 100644 --- a/environment.yml +++ b/environment.yml @@ -23,11 +23,13 @@ dependencies: - joblib - jupyter - lazy_loader >=0.3 + - mamba - matplotlib >=3.7 - mffpy >=0.5.7 - mne-qt-browser - nibabel - nilearn + - nomkl - numba - numpy >=1.25,<3 - openmeeg >=2.5.5 @@ -57,6 +59,6 @@ dependencies: - trame - trame-vtk - trame-vuetify - - vtk >=9.2 + - vtk =9.3.1=qt_* - wfdb - xlrd diff --git a/examples/inverse/vector_mne_solution.py b/examples/inverse/vector_mne_solution.py index ca953cd2f24..f6ae788c145 100644 --- a/examples/inverse/vector_mne_solution.py +++ b/examples/inverse/vector_mne_solution.py @@ -79,7 +79,7 @@ # inverse was computed with loose=0.2 print( "Absolute cosine similarity between source normals and directions: " - f'{np.abs(np.sum(directions * inv["source_nn"][2::3], axis=-1)).mean()}' + f"{np.abs(np.sum(directions * inv['source_nn'][2::3], axis=-1)).mean()}" ) brain_max = stc_max.plot( initial_time=peak_time, diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py new file mode 100755 index 00000000000..a6c6bb3c2ba --- /dev/null +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -0,0 +1,196 @@ +""" +.. 
_ex-pcaobs: + +===================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact +===================================================================================== + +This script shows an example of how to use an adaptation of PCA-OBS +:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove +the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it +has been adapted to remove the delay between the detected R-peak and the +ballistocardiographic artefact such that the algorithm can be applied to +remove the cardiac artefact in EEG (electroencephalography) and ESG +(electrospinography) data. We will illustrate how it works by applying the +algorithm to ESG data, where the effect of removal is most pronounced. + +See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1 +for more details on the dataset and application for ESG data. + +""" + +# Authors: Emma Bailey , +# Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import glob + +import numpy as np + +# %% +# Download sample subject data from OpenNeuro if you haven't already. +# This will download simultaneous EEG and ESG data from a single run of a +# single participant after median nerve stimulation of the left wrist. +import openneuro +from matplotlib import pyplot as plt + +import mne +from mne import Epochs, events_from_annotations +from mne.io import read_raw_eeglab +from mne.preprocessing import find_ecg_events, fix_stim_artifact + +# add the path where you want the OpenNeuro data downloaded. Each run is ~2GB of data +ds = "ds004388" +target_dir = mne.datasets.default_path() / ds +run_name = "sub-001/eeg/*median_run-03_eeg*.set" +if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) +block_files = glob.glob(str(target_dir / run_name)) +assert len(block_files) == 1 + +# %% +# Define the esg channels (arranged in two patches over the neck and lower back). + +esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "L4", + "S6", + "S23", +] + +# Interpolation window in seconds for ESG data to remove stimulation artefact +tstart_esg = -7e-3 +tmax_esg = 7e-3 + +# Define timing of heartbeat epochs in seconds relative to R-peaks +iv_baseline = [-400e-3, -300e-3] +iv_epoch = [-400e-3, 600e-3] + +# %% +# Next, we perform minimal preprocessing including removing the +# stimulation artefact, downsampling and filtering. 
+ +raw = read_raw_eeglab(block_files[0], verbose="error") +raw.set_channel_types(dict(ECG="ecg")) +# Isolate the ESG channels (include the ECG channel for R-peak detection) +raw.pick(esg_chans + ["ECG"]) +# Trim duration and downsample (from 10kHz) to improve example speed +raw.crop(0, 60).load_data().resample(2000) + +# Find trigger timings to remove the stimulation artefact +events, event_dict = events_from_annotations(raw) +trigger_name = "Median - Stimulation" + +fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, +) + +# %% +# Find ECG events and add to the raw structure as event annotations. + +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") +ecg_event_samples = np.asarray( + [[ecg_event[0] for ecg_event in ecg_events]] +) # Samples only + +qrs_event_time = [ + x / raw.info["sfreq"] for x in ecg_event_samples.reshape(-1) +] # Divide by sampling rate to make times +duration = np.repeat(0.0, len(ecg_event_samples)) +description = ["qrs"] * len(ecg_event_samples) + +raw.annotations.append( + qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) +) + +# %% +# Create evoked response around the detected R-peaks +# before and after cardiac artefact correction. + +events, event_ids = events_from_annotations(raw) +event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} +epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_before = epochs.average() + +# Apply function - modifies the data in place. Optionally high-pass filter +# the data before applying PCA-OBS to remove low frequency drifts +raw = mne.preprocessing.apply_pca_obs( + raw, picks=esg_chans, n_jobs=5, qrs_times=raw.times[ecg_event_samples.reshape(-1)] +) + +epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_after = epochs.average() + +# %% +# Compare evoked responses to assess completeness of artefact removal. + +fig, axes = plt.subplots(1, 1, layout="constrained") +data_before = evoked_before.get_data(units=dict(eeg="uV")).T +data_after = evoked_after.get_data(units=dict(eeg="uV")).T +hs = list() +hs.append(axes.plot(epochs.times, data_before, color="k")[0]) +hs.append(axes.plot(epochs.times, data_after, color="green", label="after")[0]) +axes.set(ylim=[-500, 1000], ylabel="Amplitude (µV)", xlabel="Time (s)") +axes.set(title="ECG artefact removal using PCA-OBS") +axes.legend(hs, ["before", "after"]) +plt.show() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index 83d1916c6f9..53b7a60dbba 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -5,8 +5,8 @@ Plotting topographic maps of evoked data ======================================== -Load evoked data and plot topomaps for selected time points using multiple -additional options. +Load evoked data and plot topomaps for selected time points using +multiple additional options. 
""" # Authors: Christian Brodbeck # Tal Linzen diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index ed05ae3ba11..4bcb4bc8c04 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -85,7 +85,7 @@ print("Covariance estimates sorted from best to worst") for c in noise_covs: - print(f'{c["method"]} : {c["loglik"]}') + print(f"{c['method']} : {c['loglik']}") # %% # Show the evoked data: diff --git a/mne/_fiff/_digitization.py b/mne/_fiff/_digitization.py index e55fd5d2dae..eb8b6bc396a 100644 --- a/mne/_fiff/_digitization.py +++ b/mne/_fiff/_digitization.py @@ -328,8 +328,7 @@ def _get_data_as_dict_from_dig(dig, exclude_ref_channel=True): dig_coord_frames = set([FIFF.FIFFV_COORD_HEAD]) if len(dig_coord_frames) != 1: raise RuntimeError( - "Only single coordinate frame in dig is supported, " - f"got {dig_coord_frames}" + f"Only single coordinate frame in dig is supported, got {dig_coord_frames}" ) dig_ch_pos_location = np.array(dig_ch_pos_location) dig_ch_pos_location.shape = (-1, 3) # empty will be (0, 3) diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 629d9a4b0ce..51612824a6a 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -455,7 +455,7 @@ def _check_set(ch, projs, ch_type): for proj in projs: if ch["ch_name"] in proj["data"]["col_names"]: raise RuntimeError( - f'Cannot change channel type for channel {ch["ch_name"]} in ' + f"Cannot change channel type for channel {ch['ch_name']} in " f'projector "{proj["desc"]}"' ) ch["kind"] = new_kind @@ -1867,7 +1867,7 @@ def _check_consistency(self, prepend_error=""): ): raise RuntimeError( f'{prepend_error}info["meas_date"] must be a datetime object in UTC' - f' or None, got {repr(self["meas_date"])!r}' + f" or None, got {repr(self['meas_date'])!r}" ) chs = [ch["ch_name"] for ch in self["chs"]] @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): hi["meas_date"] = _ensure_meas_date_none_or_dt( tuple(int(t) for t in tag.data), ) + if "meas_date" not in hi: + hi["meas_date"] = None info["helium_info"] = hi del hi @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"]) if hi.get("orig_file_guid") is not None: write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"]) - write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) + if hi["meas_date"] is not None: + write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) end_block(fid, FIFF.FIFFB_HELIUM) del hi @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): _write_proc_history(fid, info) -@fill_doc -def write_info(fname, info, data_type=None, reset_range=True): +@verbose +def write_info( + fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None +): """Write measurement info in fif file. Parameters @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True): raw data. reset_range : bool If True, info['chs'][k]['range'] will be set to unity. 
+ %(overwrite)s + %(verbose)s """ - with start_and_end_file(fname) as fid: + with start_and_end_file(fname, overwrite=overwrite) as fid: start_block(fid, FIFF.FIFFB_MEAS) write_meas_info(fid, info, data_type, reset_range) end_block(fid, FIFF.FIFFB_MEAS) @@ -3673,8 +3680,7 @@ def _write_ch_infos(fid, chs, reset_range, ch_names_mapping): # only write new-style channel information if necessary if len(ch_names_mapping): logger.info( - " Writing channel names to FIF truncated to 15 characters " - "with remapping" + " Writing channel names to FIF truncated to 15 characters with remapping" ) for ch in chs: start_block(fid, FIFF.FIFFB_CH_INFO) diff --git a/mne/_fiff/proj.py b/mne/_fiff/proj.py index 0376826138a..d6ec108e34d 100644 --- a/mne/_fiff/proj.py +++ b/mne/_fiff/proj.py @@ -76,7 +76,7 @@ def __repr__(self): # noqa: D105 s += f", active : {self['active']}" s += f", n_channels : {len(self['data']['col_names'])}" if self["explained_var"] is not None: - s += f', exp. var : {self["explained_var"] * 100:0.2f}%' + s += f", exp. var : {self['explained_var'] * 100:0.2f}%" return f"" # speed up info copy by taking advantage of mutability @@ -324,8 +324,7 @@ def apply_proj(self, verbose=None): if all(p["active"] for p in self.info["projs"]): logger.info( - "Projections have already been applied. " - "Setting proj attribute to True." + "Projections have already been applied. Setting proj attribute to True." ) return self @@ -663,9 +662,9 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): for proj in projs: misc = "active" if proj["active"] else " idle" logger.info( - f' {proj["desc"]} ' - f'({proj["data"]["nrow"]} x ' - f'{len(proj["data"]["col_names"])}) {misc}' + f" {proj['desc']} " + f"({proj['data']['nrow']} x " + f"{len(proj['data']['col_names'])}) {misc}" ) return projs @@ -795,8 +794,7 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False if not p["active"] or include_active: if len(p["data"]["col_names"]) != len(np.unique(p["data"]["col_names"])): raise ValueError( - f"Channel name list in projection item {k}" - " contains duplicate items" + f"Channel name list in projection item {k} contains duplicate items" ) # Get the two selection vectors to pick correct elements from @@ -832,7 +830,7 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False ) ): warn( - f'Projection vector {repr(p["desc"])} has been ' + f"Projection vector {repr(p['desc'])} has been " f"reduced to {100 * psize:0.2f}% of its " "original magnitude by subselecting " f"{len(vecsel)}/{orig_n} of the original " diff --git a/mne/_fiff/reference.py b/mne/_fiff/reference.py index e70bf5e36c1..b4c050c096d 100644 --- a/mne/_fiff/reference.py +++ b/mne/_fiff/reference.py @@ -102,7 +102,7 @@ def _check_before_dict_reference(inst, ref_dict): raise TypeError( f"{elem_name.capitalize()}s in the ref_channels dict must be strings. " f"Your dict has {elem_name}s of type " - f'{", ".join(map(lambda x: x.__name__, bad_elem))}.' + f"{', '.join(map(lambda x: x.__name__, bad_elem))}." ) # Check that keys are valid channels and values are lists-of-valid-channels @@ -113,8 +113,8 @@ def _check_before_dict_reference(inst, ref_dict): for elem_name, elem in dict(key=keys, value=values).items(): if bad_elem := elem - ch_set: raise ValueError( - f'ref_channels dict contains invalid {elem_name}(s) ' - f'({", ".join(bad_elem)}) ' + f"ref_channels dict contains invalid {elem_name}(s) " + f"({', '.join(bad_elem)}) " "that are not names of channels in the instance." 
) # Check that values are not bad channels diff --git a/mne/_fiff/tag.py b/mne/_fiff/tag.py index abc7d32036b..3fd36454d58 100644 --- a/mne/_fiff/tag.py +++ b/mne/_fiff/tag.py @@ -70,8 +70,7 @@ def _frombuffer_rows(fid, tag_size, dtype=None, shape=None, rlims=None): have_shape = tag_size // item_size if want_shape != have_shape: raise ValueError( - f"Wrong shape specified, requested {want_shape} but got " - f"{have_shape}" + f"Wrong shape specified, requested {want_shape} but got {have_shape}" ) if not len(rlims) == 2: raise ValueError("rlims must have two elements") diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 3e3c150573f..a38ecaade50 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path): gantry_angle = info["gantry_angle"] meas_id = info["meas_id"] - write_info(temp_file, info) + with pytest.raises(FileExistsError, match="Destination file exists"): + write_info(temp_file, info) + write_info(temp_file, info, overwrite=True) info = read_info(temp_file) assert info["proc_history"][0]["creator"] == creator assert info["hpi_meas"][0]["creator"] == creator @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path): info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) fname = tmp_path / "test.fif" with pytest.raises(RuntimeError, match="must be between "): - write_info(fname, info) + write_info(fname, info, overwrite=True) @testing.requires_testing_data @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path): for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"): info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type]) info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03] - write_info(fname, info) + write_info(fname, info, overwrite=True) info2 = read_info(fname) assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD @@ -585,7 +587,7 @@ def test_check_consistency(): info2["subject_info"] = {"height": "bad"} -def _test_anonymize_info(base_info): +def _test_anonymize_info(base_info, tmp_path): """Test that sensitive information can be anonymized.""" pytest.raises(TypeError, anonymize_info, "foo") assert isinstance(base_info, Info) @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt): # exp 4 tests is a supplied daysback delta_t_3 = timedelta(days=223 + 364 * 500) + def _check_equiv(got, want, err_msg): + __tracebackhide__ = True + fname_temp = tmp_path / "test.fif" + assert_object_equal(got, want, err_msg=err_msg) + write_info(fname_temp, got, reset_range=False, overwrite=True) + got = read_info(fname_temp) + # this gets changed on write but that's expected + with got._unlock(): + got["file_id"] = want["file_id"] + assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)") + new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info, err_msg="anon mismatch") + _check_equiv(new_info, exp_info, err_msg="anon mismatch") new_info = anonymize_info(base_info.copy(), keep_his=True) - assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch") + _check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch") new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch") + _check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch") with pytest.raises(RuntimeError, match="anonymize_info generated"): anonymize_info(base_info.copy(), daysback=delta_t_3.days) @@ -726,7 +739,7 @@ def 
_adjust_back(e_i, dt): new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) else: new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal( + _check_equiv( new_info, exp_info_3, err_msg="meas_date=None daysback mismatch", @@ -734,7 +747,7 @@ def _adjust_back(e_i, dt): with _record_warnings(): # meas_date is None new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch") + _check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch") @pytest.mark.parametrize( @@ -777,8 +790,8 @@ def _complete_info(info): height=2.0, ) info["helium_info"] = dict( - he_level_raw=12.34, - helium_level=45.67, + he_level_raw=np.float32(12.34), + helium_level=np.float32(45.67), meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc), orig_file_guid="e", ) @@ -796,14 +809,13 @@ def _complete_info(info): machid=np.ones(2, int), secs=d[0], usecs=d[1], - date=d, ), experimenter="j", max_info=dict( - max_st=[], - sss_ctc=[], - sss_cal=[], - sss_info=dict(head_pos=None, in_order=8), + max_st=dict(), + sss_ctc=dict(), + sss_cal=dict(), + sss_info=dict(in_order=8), ), date=d, ), @@ -830,8 +842,8 @@ def test_anonymize(tmp_path): # test mne.anonymize_info() events = read_events(event_name) epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None) - _test_anonymize_info(raw.info) - _test_anonymize_info(epochs.info) + _test_anonymize_info(raw.info, tmp_path) + _test_anonymize_info(epochs.info, tmp_path) # test instance methods & I/O roundtrip for inst, keep_his in zip((raw, epochs), (True, False)): @@ -1106,7 +1118,7 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname): meas_info, "_read_extended_ch_info", _read_extended_ch_info ) short_proj_names = [ - f"{name[:13 - bool(len(ref_names))]}-{ni}" + f"{name[: 13 - bool(len(ref_names))]}-{ni}" for ni, name in enumerate(long_proj_names) ] assert raw_read.info["projs"][0]["data"]["col_names"] == short_proj_names diff --git a/mne/_fiff/tests/test_pick.py b/mne/_fiff/tests/test_pick.py index 90830e1d5e5..5d1b24247ab 100644 --- a/mne/_fiff/tests/test_pick.py +++ b/mne/_fiff/tests/test_pick.py @@ -136,7 +136,7 @@ def _channel_type_old(info, idx): else: return t - raise ValueError(f'Unknown channel type for {ch["ch_name"]}') + raise ValueError(f"Unknown channel type for {ch['ch_name']}") def _assert_channel_types(info): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 1fc32f0163e..8486ca13121 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -13,7 +13,7 @@ import numpy as np from scipy.sparse import csc_array, csr_array -from ..utils import _file_like, _validate_type, logger +from ..utils import _check_fname, _file_like, _validate_type, logger from ..utils.numerics import _date_to_julian from .constants import FIFF @@ -277,7 +277,7 @@ def end_block(fid, kind): write_int(fid, FIFF.FIFF_BLOCK_END, kind) -def start_file(fname, id_=None): +def start_file(fname, id_=None, *, overwrite=True): """Open a fif file for writing and writes the compulsory header tags. 
Parameters @@ -294,6 +294,7 @@ def start_file(fname, id_=None): fid = fname fid.seek(0) else: + fname = _check_fname(fname, overwrite=overwrite) fname = str(fname) if op.splitext(fname)[1].lower() == ".gz": logger.debug("Writing using gzip") @@ -311,9 +312,9 @@ def start_file(fname, id_=None): @contextmanager -def start_and_end_file(fname, id_=None): +def start_and_end_file(fname, id_=None, *, overwrite=True): """Start and (if successfully written) close the file.""" - with start_file(fname, id_=id_) as fid: + with start_file(fname, id_=id_, overwrite=overwrite) as fid: yield fid end_file(fid) # we only hit this line if the yield does not err diff --git a/mne/beamformer/_compute_beamformer.py b/mne/beamformer/_compute_beamformer.py index bb947cdd757..16bedc2c317 100644 --- a/mne/beamformer/_compute_beamformer.py +++ b/mne/beamformer/_compute_beamformer.py @@ -507,13 +507,13 @@ def __repr__(self): # noqa: D105 n_channels, ) if self["pick_ori"] is not None: - out += f', {self["pick_ori"]} ori' + out += f", {self['pick_ori']} ori" if self["weight_norm"] is not None: - out += f', {self["weight_norm"]} norm' + out += f", {self['weight_norm']} norm" if self.get("inversion") is not None: - out += f', {self["inversion"]} inversion' + out += f", {self['inversion']} inversion" if "rank" in self: - out += f', rank {self["rank"]}' + out += f", rank {self['rank']}" out += ">" return out @@ -531,7 +531,7 @@ def save(self, fname, overwrite=False, verbose=None): """ _, write_hdf5 = _import_h5io_funcs() - ending = f'-{self["kind"].lower()}.h5' + ending = f"-{self['kind'].lower()}.h5" check_fname(fname, self["kind"], (ending,)) csd_orig = None try: diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 957dbaf5284..9ae5473e190 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -380,7 +380,7 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): rank = 17 if proj else 20 assert "LCMV" in repr(filters) assert "unknown subject" not in repr(filters) - assert f'{fwd["nsource"]} vert' in repr(filters) + assert f"{fwd['nsource']} vert" in repr(filters) assert "20 ch" in repr(filters) assert f"rank {rank}" in repr(filters) diff --git a/mne/bem.py b/mne/bem.py index d361272fd49..22aa02d2a0d 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -91,7 +91,7 @@ class ConductorModel(dict): def __repr__(self): # noqa: D105 if self["is_sphere"]: - center = ", ".join(f"{x * 1000.:.1f}" for x in self["r0"]) + center = ", ".join(f"{x * 1000.0:.1f}" for x in self["r0"]) rad = self.radius if rad is None: # no radius / MEG only extra = f"Sphere (no layers): r0=[{center}] mm" @@ -538,7 +538,7 @@ def _assert_complete_surface(surf, incomplete="raise"): prop = tot_angle / (2 * np.pi) if np.abs(prop - 1.0) > 1e-5: msg = ( - f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' + f"Surface {_bem_surf_name[surf['id']]} is not complete (sum of " f"solid angles yielded {prop}, should be 1.)" ) _on_missing(incomplete, msg, name="incomplete", error_klass=RuntimeError) @@ -571,7 +571,7 @@ def _check_surface_size(surf): sizes = surf["rr"].max(axis=0) - surf["rr"].min(axis=0) if (sizes < 0.05).any(): raise RuntimeError( - f'Dimensions of the surface {_bem_surf_name[surf["id"]]} seem too ' + f"Dimensions of the surface {_bem_surf_name[surf['id']]} seem too " f"small ({1000 * sizes.min():9.5f}). 
Maybe the unit of measure" " is meters instead of mm" ) @@ -599,8 +599,7 @@ def _surfaces_to_bem( # surfs can be strings (filenames) or surface dicts if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == len(sigmas)): raise ValueError( - "surfs, ids, and sigmas must all have the same " - "number of elements (1 or 3)" + "surfs, ids, and sigmas must all have the same number of elements (1 or 3)" ) for si, surf in enumerate(surfs): if isinstance(surf, str | Path | os.PathLike): @@ -1260,8 +1259,7 @@ def make_watershed_bem( if op.isdir(ws_dir): if not overwrite: raise RuntimeError( - f"{ws_dir} already exists. Use the --overwrite option" - " to recreate it." + f"{ws_dir} already exists. Use the --overwrite option to recreate it." ) else: shutil.rmtree(ws_dir) @@ -2460,7 +2458,7 @@ def check_seghead(surf_path=subj_path / "surf"): logger.info(f"{ii}. Creating {level} tessellation...") logger.info( f"{ii}.1 Decimating the dense tessellation " - f'({len(surf["tris"])} -> {n_tri} triangles)...' + f"({len(surf['tris'])} -> {n_tri} triangles)..." ) points, tris = decimate_surface( points=surf["rr"], triangles=surf["tris"], n_triangles=n_tri diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8fbff33c13e..bf9e58f2819 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -661,17 +661,21 @@ def _pick_projs(self): return self def add_channels(self, add_list, force_update_info=False): - """Append new channels to the instance. + """Append new channels from other MNE objects to the instance. Parameters ---------- add_list : list - A list of objects to append to self. Must contain all the same - type as the current object. + A list of MNE objects to append to the current instance. + The channels contained in the other instances are appended to the + channels of the current instance. Therefore, all other instances + must be of the same type as the current object. + See notes on how to add data coming from an array. force_update_info : bool If True, force the info for objects to be appended to match the - values in ``self``. This should generally only be used when adding - stim channels for which important metadata won't be overwritten. + values of the current instance. This should generally only be + used when adding stim channels for which important metadata won't + be overwritten. .. versionadded:: 0.12 @@ -688,6 +692,12 @@ def add_channels(self, add_list, force_update_info=False): ----- If ``self`` is a Raw instance that has been preloaded into a :obj:`numpy.memmap` instance, the memmap will be resized. + + This function expects an MNE object to be appended (e.g. :class:`~mne.io.Raw`, + :class:`~mne.Epochs`, :class:`~mne.Evoked`). If you simply want to add a + channel based on values of an np.ndarray, you need to create a + :class:`~mne.io.RawArray`. + See `_ """ # avoid circular imports from ..epochs import BaseEpochs @@ -1372,7 +1382,7 @@ def read_ch_adjacency(fname, picks=None): raise ValueError( f"No built-in channel adjacency matrix found with name: " f"{ch_adj_name}. Valid names are: " - f'{", ".join(get_builtin_ch_adjacencies())}' + f"{', '.join(get_builtin_ch_adjacencies())}" ) ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES if a.name == ch_adj_name][0] diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 87550d66807..15cef38dec7 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -405,7 +405,7 @@ def save(self, fname, *, overwrite=False, verbose=None): Parameters ---------- fname : path-like - The filename to use. 
Should end in .fif or .fif.gz. + The filename to use. Should end in ``-dig.fif`` or ``-dig.fif.gz``. %(overwrite)s %(verbose)s @@ -1287,7 +1287,7 @@ def _backcompat_value(pos, ref_pos): f"Not setting position{_pl(extra)} of {len(extra)} {types} " f"channel{_pl(extra)} found in montage:\n{names}\n" "Consider setting the channel types to be of " - f'{docdict["montage_types"]} ' + f"{docdict['montage_types']} " "using inst.set_channel_types before calling inst.set_montage, " "or omit these channels when creating your montage." ) diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index f51b551a1c8..bb886c51a96 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -404,8 +404,7 @@ def test_adjacency_matches_ft(tmp_path): if hash_mne.hexdigest() != hash_ft.hexdigest(): raise ValueError( - f"Hash mismatch between built-in and FieldTrip neighbors " - f"for {fname}" + f"Hash mismatch between built-in and FieldTrip neighbors for {fname}" ) diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 8add1398409..d9306b5e1bd 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -420,12 +420,7 @@ def test_documented(): ), pytest.param( partial(read_dig_hpts, unit="m"), - ( - "eeg Fp1 -95.0 -3. -3.\n" - "eeg AF7 -1 -1 -3\n" - "eeg A3 -2 -2 2\n" - "eeg A 0 0 0" - ), + ("eeg Fp1 -95.0 -3. -3.\neeg AF7 -1 -1 -3\neeg A3 -2 -2 2\neeg A 0 0 0"), make_dig_montage( ch_pos={ "A": [0.0, 0.0, 0.0], diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py index 5b7d020b98d..894ede7fa1a 100644 --- a/mne/commands/mne_make_scalp_surfaces.py +++ b/mne/commands/mne_make_scalp_surfaces.py @@ -49,8 +49,7 @@ def run(): "--force", dest="force", action="store_true", - help="Force creation of the surface even if it has " - "some topological defects.", + help="Force creation of the surface even if it has some topological defects.", ) parser.add_option( "-t", diff --git a/mne/commands/mne_setup_source_space.py b/mne/commands/mne_setup_source_space.py index e536a59f90b..273e833b31c 100644 --- a/mne/commands/mne_setup_source_space.py +++ b/mne/commands/mne_setup_source_space.py @@ -62,8 +62,7 @@ def run(): parser.add_option( "--ico", dest="ico", - help="use the recursively subdivided icosahedron " - "to create the source space.", + help="use the recursively subdivided icosahedron to create the source space.", default=None, type="int", ) diff --git a/mne/conftest.py b/mne/conftest.py index fc3bc3b7a53..8a4586067b3 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -6,6 +6,7 @@ import inspect import os import os.path as op +import re import shutil import sys import warnings @@ -79,7 +80,7 @@ collect_ignore = ["export/_brainvision.py", "export/_eeglab.py", "export/_edf.py"] -def pytest_configure(config): +def pytest_configure(config: pytest.Config): """Configure pytest options.""" # Markers for marker in ( @@ -183,6 +184,11 @@ def pytest_configure(config): ignore:The (non_)?interactive_bk attribute was deprecated.*: # SWIG (via OpenMEEG) ignore:.*builtin type swigvarlink has no.*:DeprecationWarning + # eeglabio + ignore:numpy\.core\.records is deprecated.*:DeprecationWarning + ignore:Starting field name with a underscore.*: + # joblib + ignore:process .* is multi-threaded, use of fork/exec.*:DeprecationWarning """ # noqa: E501 for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() @@ -646,6 +652,11 @@ def 
_check_skip_backend(name): pytest.skip("Test skipped, requires Qt.") else: assert name == "notebook", name + pytest.importorskip("jupyter") + pytest.importorskip("ipympl") + pytest.importorskip("trame") + pytest.importorskip("trame_vtk") + pytest.importorskip("trame_vuetify") if not _notebook_vtk_works(): pytest.skip("Test skipped, requires working notebook vtk") @@ -1174,10 +1185,55 @@ def qt_windows_closed(request): @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_runtest_makereport(item, call): - """Stash the status of each item.""" + """Stash the status of each item and turn unexpected skips into errors.""" outcome = yield - rep = outcome.get_result() + rep: pytest.TestReport = outcome.get_result() item.stash.setdefault(_phase_report_key, {})[rep.when] = rep + _modify_report_skips(rep) + return rep + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_make_collect_report(collector: pytest.Collector): + """Turn unexpected skips during collection (e.g., module-level) into errors.""" + outcome = yield + rep: pytest.CollectReport = outcome.get_result() + _modify_report_skips(rep) + return rep + + +# Default means "allow all skips". Can use something like "$." to mean +# "never match", i.e., "treat all skips as errors" +_valid_skips_re = re.compile(os.getenv("MNE_TEST_ALLOW_SKIP", ".*")) + + +# To turn unexpected skips into errors, we need to look both at the collection phase +# (for decorated tests) and the call phase (for things like `importorskip` +# within the test body). code adapted from pytest-error-for-skips +def _modify_report_skips(report: pytest.TestReport | pytest.CollectReport): + if not report.skipped: + return + if isinstance(report.longrepr, tuple): + file, lineno, reason = report.longrepr + else: + file, lineno, reason = "", 1, str(report.longrepr) + if _valid_skips_re.match(reason): + return + assert isinstance(report, pytest.TestReport | pytest.CollectReport), type(report) + if file.endswith("doctest.py"): # _python/doctest.py + return + # xfail tests aren't true "skips" but show up as skipped in reports + if getattr(report, "keywords", {}).get("xfail", False): + return + # the above only catches marks, so we need to actually parse the report to catch + # an xfail based on the traceback + if " pytest.xfail( " in reason: + return + if reason.startswith("Skipped: "): + reason = reason[9:] + report.longrepr = f"{file}:{lineno}: UNEXPECTED SKIP: {reason}" + # Make it show up as an error in the report + report.outcome = "error" if isinstance(report, pytest.TestReport) else "failed" @pytest.fixture(scope="function") diff --git a/mne/coreg.py b/mne/coreg.py index f28c6142c96..c7549ee028a 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -876,8 +876,7 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) if (subject_from is None) != (scale is None): raise TypeError( - "Need to provide either both subject_from and scale " - "parameters, or neither." + "Need to provide either both subject_from and scale parameters, or neither." ) if subject_from is None: @@ -1402,8 +1401,7 @@ def _read_surface(filename, *, on_defects): complete_surface_info(bem, copy=False) except Exception: raise ValueError( - f"Error loading surface from {filename} (see " - "Terminal for details)." + f"Error loading surface from {filename} (see Terminal for details)." 
) return bem @@ -2145,8 +2143,7 @@ def omit_head_shape_points(self, distance): mask = self._orig_hsp_point_distance <= distance n_excluded = np.sum(~mask) logger.info( - "Coregistration: Excluding %i head shape points with " - "distance >= %.3f m.", + "Coregistration: Excluding %i head shape points with distance >= %.3f m.", n_excluded, distance, ) diff --git a/mne/cov.py b/mne/cov.py index 8b86119c1d1..94239472fa2 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1293,7 +1293,7 @@ def _compute_covariance_auto( data_ = data.copy() name = method_.__name__ if callable(method_) else method_ logger.info( - f'Estimating {cov_kind + (" " if cov_kind else "")}' + f"Estimating {cov_kind + (' ' if cov_kind else '')}" f"covariance using {name.upper()}" ) mp = method_params[method_] @@ -1712,7 +1712,7 @@ def _get_ch_whitener(A, pca, ch_type, rank): logger.info( f" Setting small {ch_type} eigenvalues to zero " - f'({"using" if pca else "without"} PCA)' + f"({'using' if pca else 'without'} PCA)" ) if pca: # No PCA case. # This line will reduce the actual number of variables in data @@ -2400,7 +2400,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data = tag.data diag = True logger.info( - " %d x %d diagonal covariance (kind = " "%d) found.", + " %d x %d diagonal covariance (kind = %d) found.", dim, dim, cov_kind, @@ -2416,7 +2416,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data.flat[:: dim + 1] /= 2.0 diag = False logger.info( - " %d x %d full covariance (kind = %d) " "found.", + " %d x %d full covariance (kind = %d) found.", dim, dim, cov_kind, @@ -2425,7 +2425,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): diag = False data = tag.data logger.info( - " %d x %d sparse covariance (kind = %d)" " found.", + " %d x %d sparse covariance (kind = %d) found.", dim, dim, cov_kind, diff --git a/mne/datasets/__init__.pyi b/mne/datasets/__init__.pyi index 44cee84fe7f..2f69a1027e5 100644 --- a/mne/datasets/__init__.pyi +++ b/mne/datasets/__init__.pyi @@ -6,6 +6,7 @@ __all__ = [ "epilepsy_ecog", "erp_core", "eyelink", + "default_path", "fetch_aparc_sub_parcellation", "fetch_dataset", "fetch_fsaverage", @@ -70,6 +71,7 @@ from ._infant import fetch_infant_template from ._phantom.base import fetch_phantom from .utils import ( _download_all_example_data, + default_path, fetch_aparc_sub_parcellation, fetch_hcp_mmp_parcellation, has_dataset, diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 1e38606f908..8f44459ad97 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -143,8 +143,7 @@ def fetch_dataset( if auth is not None: if len(auth) != 2: raise RuntimeError( - "auth should be a 2-tuple consisting " - "of a username and password/token." + "auth should be a 2-tuple consisting of a username and password/token." 
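The auth check above belongs to mne.datasets.fetch_dataset; a hedged usage sketch (every value in dataset_params below is a placeholder, not a real archive):

    from mne.datasets import fetch_dataset

    # Placeholder archive description for illustration only.
    dataset_params = dict(
        dataset_name="example",
        archive_name="example.tar.gz",
        url="https://example.com/example.tar.gz",  # placeholder URL
        hash="md5:00000000000000000000000000000000",
        folder_name="MNE-example-data",
        config_key="MNE_DATASETS_EXAMPLE_PATH",
    )
    # auth must be a 2-tuple of (username, password-or-token), per the check above.
    path = fetch_dataset(dataset_params, auth=("alice", "s3cr3t-token"))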
) # processor to uncompress files diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ccd4babacd9..75eff184cd1 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -92,8 +92,8 @@ phantom_kit="0.2", ucl_opm_auditory="0.2", ) -TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' -MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' +TESTING_VERSIONED = f"mne-testing-data-{RELEASES['testing']}" +MISC_VERSIONED = f"mne-misc-data-{RELEASES['misc']}" # To update any other dataset besides `testing` or `misc`, upload the new # version of the data archive itself (e.g., to https://osf.io or wherever) and @@ -118,7 +118,7 @@ hash="md5:d94fe9f3abe949a507eaeb865fb84a3f", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" - f'tar.gz/{RELEASES["testing"]}' + f"tar.gz/{RELEASES['testing']}" ), # In case we ever have to resort to osf.io again... # archive_name='mne-testing-data.tar.gz', @@ -131,8 +131,7 @@ archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', hash="md5:e343d3a00cb49f8a2f719d14f4758afe", url=( - "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" - f'{RELEASES["misc"]}' + f"https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/{RELEASES['misc']}" ), folder_name="MNE-misc-data", config_key="MNE_DATASETS_MISC_PATH", diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index b97d0611591..7fbcca3a2d7 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import inspect import os import os.path as op @@ -114,12 +115,16 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): data = data.set_index(("Subject - age - sex", "Nr")) data.index.name = "subject" data.columns.names = [None, None] + kwargs = dict() + # TODO VERSION can be removed once we require Pandas 2.1 + if "future_stack" in inspect.getfullargspec(pd.DataFrame.stack).args: + kwargs["future_stack"] = True data = ( data.set_index( [("Subject - age - sex", "Age"), ("Subject - age - sex", "M1/F2")], append=True, ) - .stack(level=0) + .stack(level=0, **kwargs) .reset_index() ) diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index c14282ed202..b5ea1764946 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -122,10 +122,7 @@ def fetch_data( ) _on_missing(on_missing, msg) if 13 in subjects and 2 in recording: - msg = ( - "Requested recording 2 for subject 13, but it is not available " - "in corpus." - ) + msg = "Requested recording 2 for subject 13, but it is not available in corpus." _on_missing(on_missing, msg) fnames = [] diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 452e42cffc7..93aabc0841a 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import glob import importlib import inspect import logging @@ -92,6 +93,22 @@ def _dataset_version(path, name): return version +@verbose +def default_path(*, verbose=None): + """Get the default MNE_DATA path. + + Parameters + ---------- + %(verbose)s + + Returns + ------- + data_path : instance of Path + Path to the default MNE_DATA directory. + """ + return _get_path(None, None, None) + + def _get_path(path, key, name): """Get a dataset path.""" # 1. Input @@ -113,7 +130,8 @@ def _get_path(path, key, name): return path # 4. 
~/mne_data (but use a fake home during testing so we don't # unnecessarily create ~/mne_data) - logger.info(f"Using default location ~/mne_data for {name}...") + extra = f" for {name}" if name else "" + logger.info(f"Using default location ~/mne_data{extra}...") path = Path(os.getenv("_MNE_FAKE_HOME_DIR", "~")).expanduser() / "mne_data" if not path.is_dir(): logger.info(f"Creating {path}") @@ -319,6 +337,8 @@ def _download_all_example_data(verbose=True): # # verbose=True by default so we get nice status messages. # Consider adding datasets from here to CircleCI for PR-auto-build + import openneuro + paths = dict() for kind in ( "sample testing misc spm_face somato hf_sef multimodal " @@ -375,6 +395,14 @@ def _download_all_example_data(verbose=True): limo.load_data(subject=1, update_path=True) logger.info("[done limo]") + # for ESG + ds = "ds004388" + target_dir = default_path() / ds + run_name = "sub-001/eeg/*median_run-03_eeg*.set" + if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) + @verbose def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 85ed102b514..a291416bb17 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -85,7 +85,11 @@ def __sklearn_tags__(self): """Get sklearn tags.""" from sklearn.utils import get_tags # added in 1.6 - return get_tags(self.model) + # fit method below does not allow sparse data via check_data, we could + # eventually make it smarter if we had to + tags = get_tags(self.model) + tags.input_tags.sparse = False + return tags def __getattr__(self, attr): """Wrap to model for some attributes.""" diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 198feeb6532..8f4d2472803 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -474,5 +474,5 @@ def test_non_full_rank_data(): ssd = SSD(info, filt_params_signal, filt_params_noise) if sys.platform == "darwin": - pytest.skip("Unknown linalg bug (Accelerate?)") + pytest.xfail("Unknown linalg bug (Accelerate?)") ssd.fit(X) diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 8eb2dcc5510..e475cd22161 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -238,7 +238,7 @@ def inverse_transform(self, epochs_data): return out -class Vectorizer(TransformerMixin): +class Vectorizer(TransformerMixin, BaseEstimator): """Transform n-dimensional array into 2D array of n_samples by n_features. This class reshapes an n-dimensional array into an n_samples * n_features @@ -343,7 +343,7 @@ def inverse_transform(self, X): @fill_doc -class PSDEstimator(TransformerMixin): +class PSDEstimator(TransformerMixin, BaseEstimator): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -452,7 +452,7 @@ def transform(self, epochs_data): @fill_doc -class FilterEstimator(TransformerMixin): +class FilterEstimator(TransformerMixin, BaseEstimator): """Estimator to filter RtEpochs. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -743,7 +743,7 @@ def _apply_method(self, X, method): @fill_doc -class TemporalFilter(TransformerMixin): +class TemporalFilter(TransformerMixin, BaseEstimator): """Estimator to filter data array along the last dimension. 
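Taken together, the new mne.datasets.default_path() helper and the guarded OpenNeuro fetch above amount to this pattern; a sketch assuming the openneuro-py client is installed:

    import glob

    import openneuro

    import mne

    ds = "ds004388"
    run_name = "sub-001/eeg/*median_run-03_eeg*.set"
    # Resolve ~/mne_data (or MNE_DATA) without tying the lookup to one dataset.
    target_dir = mne.datasets.default_path() / ds
    # Download only when the run is not already cached locally.
    if not glob.glob(str(target_dir / run_name)):
        target_dir.mkdir(exist_ok=True)
        openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4])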
Applies a zero-phase low-pass, high-pass, band-pass, or band-stop diff --git a/mne/epochs.py b/mne/epochs.py index 04b1a288bfe..679643ab969 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1671,8 +1671,7 @@ def _get_data( # we start out with an empty array, allocate only if necessary data = np.empty((0, len(self.info["ch_names"]), len(self.times))) msg = ( - f"for {n_events} events and {len(self._raw_times)} " - "original time points" + f"for {n_events} events and {len(self._raw_times)} original time points" ) if self._decim > 1: msg += " (prior to decimation)" @@ -2301,8 +2300,7 @@ def save( logger.info(f"Splitting into {n_parts} parts") if n_parts > 100: # This must be an error raise ValueError( - f"Split size {split_size} would result in writing " - f"{n_parts} files" + f"Split size {split_size} would result in writing {n_parts} files" ) if len(self.drop_log) > 100000: @@ -3143,7 +3141,7 @@ def _ensure_list(x): raise ValueError( f"The event names in keep_first and keep_last must " f"be mutually exclusive. Specified in both: " - f'{", ".join(sorted(keep_first_and_last))}' + f"{', '.join(sorted(keep_first_and_last))}" ) del keep_first_and_last @@ -3163,7 +3161,7 @@ def _diff_input_strings_vs_event_id(input_strings, input_name, event_id): if event_name_diff: raise ValueError( f"Present in {input_name}, but missing from event_id: " - f'{", ".join(event_name_diff)}' + f"{', '.join(event_name_diff)}" ) _diff_input_strings_vs_event_id( @@ -3556,8 +3554,7 @@ def __init__( if not isinstance(raw, BaseRaw): raise ValueError( - "The first argument to `Epochs` must be an " - "instance of mne.io.BaseRaw" + "The first argument to `Epochs` must be an instance of mne.io.BaseRaw" ) info = deepcopy(raw.info) annotations = raw.annotations.copy() @@ -4441,8 +4438,7 @@ def _get_epoch_from_raw(self, idx, verbose=None): else: # read the correct subset of the data raise RuntimeError( - "Correct epoch could not be found, please " - "contact mne-python developers" + "Correct epoch could not be found, please contact mne-python developers" ) # the following is equivalent to this, but faster: # diff --git a/mne/event.py b/mne/event.py index 723615ea56a..a19270db1e6 100644 --- a/mne/event.py +++ b/mne/event.py @@ -1649,7 +1649,7 @@ def match_event_names(event_names, keys, *, on_missing="raise"): _on_missing( on_missing=on_missing, msg=f'Event name "{key}" could not be found. 
The following events ' - f'are present in the data: {", ".join(event_names)}', + f"are present in the data: {', '.join(event_names)}", error_klass=KeyError, ) diff --git a/mne/evoked.py b/mne/evoked.py index 5fb09db9d1b..c04f83531e3 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -962,7 +962,7 @@ def __neg__(self): if out.comment is not None and " + " in out.comment: out.comment = f"({out.comment})" # multiple conditions in evoked - out.comment = f'- {out.comment or "unknown"}' + out.comment = f"- {out.comment or 'unknown'}" return out def get_peak( @@ -1053,8 +1053,7 @@ def get_peak( raise ValueError('Channel type must be "grad" for merge_grads') elif mode == "neg": raise ValueError( - "Negative mode (mode=neg) does not make " - "sense with merge_grads=True" + "Negative mode (mode=neg) does not make sense with merge_grads=True" ) meg = eeg = misc = seeg = dbs = ecog = fnirs = False @@ -1650,12 +1649,12 @@ def combine_evoked(all_evoked, weights): if e.comment is not None and " + " in e.comment: # multiple conditions this_comment = f"({e.comment})" else: - this_comment = f'{e.comment or "unknown"}' + this_comment = f"{e.comment or 'unknown'}" # assemble everything if idx == 0: comment += f"{sign}{weight}{multiplier}{this_comment}" else: - comment += f' {sign or "+"} {weight}{multiplier}{this_comment}' + comment += f" {sign or '+'} {weight}{multiplier}{this_comment}" # special-case: combine_evoked([e1, -e2], [1, -1]) evoked.comment = comment.replace(" - - ", " + ") return evoked @@ -1872,8 +1871,7 @@ def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): if len(chs) != nchan: raise ValueError( - "Number of channels and number of " - "channel definitions are different" + "Number of channels and number of channel definitions are different" ) ch_names_mapping = _read_extended_ch_info(chs, my_evoked, fid) diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index ba64ba010ce..6503c540f41 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -107,6 +107,13 @@ def _export_mne_raw(*, raw, fname, events=None, overwrite=False): def _mne_annots2pybv_events(raw): """Convert mne Annotations to pybv events.""" + # check that raw.annotations.orig_time is the same as raw.info["meas_date"] + # so that onsets are relative to the first sample + # (after further correction for first_time) + if raw.annotations and raw.info["meas_date"] != raw.annotations.orig_time: + raise ValueError( + "Annotations must have the same orig_time as raw.info['meas_date']" + ) events = [] for annot in raw.annotations: # handle onset and duration: seconds to sample, relative to diff --git a/mne/export/_edf.py b/mne/export/_edf.py index ef870692014..e50b05f7056 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -7,6 +7,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_edfio_installed, warn _check_edfio_installed() @@ -204,7 +205,9 @@ def _export_raw(fname, raw, physical_range, add_ch_type): for desc, onset, duration, ch_names in zip( raw.annotations.description, - raw.annotations.onset, + # subtract raw.first_time because EDF marks events starting from the first + # available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), raw.annotations.duration, raw.annotations.ch_names, ): diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 3c8f896164a..459207f0616 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -4,6 +4,7 @@ import numpy as np +from 
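The new orig_time guard means BrainVision export only accepts annotations anchored to raw.info["meas_date"]; one way to satisfy it before exporting (a sketch using a synthetic Raw for illustration):

    from datetime import datetime, timezone

    import numpy as np

    import mne

    info = mne.create_info(["EEG 001"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.zeros((1, 1000)), info)
    raw.set_meas_date(datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc))
    # set_annotations() anchors orig_time to meas_date when orig_time is None,
    # so annotations attached this way pass the consistency check on export.
    raw.set_annotations(
        mne.Annotations(onset=[1.0], duration=[0.5], description=["stim"])
    )
    assert raw.annotations.orig_time == raw.info["meas_date"]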
..annotations import _sync_onset from ..utils import _check_eeglabio_installed _check_eeglabio_installed() @@ -24,11 +25,16 @@ def _export_raw(fname, raw): ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [ - raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration, - ] + if raw.annotations: + annotations = [ + raw.annotations.description, + # subtract raw.first_time because EEGLAB marks events starting from + # the first available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), + raw.annotations.duration, + ] + else: + annotations = None eeglabio.raw.export_set( fname, data=raw.get_data(picks=ch_names), diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index 3792ea4a6a5..185afb5f558 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -53,7 +53,7 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose= info = evoked[0].info if np.round(info["sfreq"]) != info["sfreq"]: raise ValueError( - f'Sampling frequency must be a whole number. sfreq: {info["sfreq"]}' + f"Sampling frequency must be a whole number. sfreq: {info['sfreq']}" ) sampling_rate = int(info["sfreq"]) diff --git a/mne/export/_export.py b/mne/export/_export.py index 490bf986895..4b93fda917e 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -25,6 +25,14 @@ def export_raw( %(export_warning)s + .. warning:: + When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the + same as ``raw.annotations.orig_time``. This guarantees that the annotations are + in the same reference frame as the samples. When + :attr:`Raw.first_time ` is not zero (e.g., after + cropping), the onsets are automatically corrected so that onsets are always + relative to the first sample. + Parameters ---------- %(fname_export_params)s @@ -216,7 +224,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): supported_str = ", ".join(supported) raise ValueError( - f"Format '{fmt}' is not supported. " - f"Supported formats are {supported_str}." + f"Format '{fmt}' is not supported. Supported formats are {supported_str}." ) return fmt diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 706a83476e4..6f712923c7d 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -122,6 +122,49 @@ def test_export_raw_eeglab(tmp_path): raw.export(temp_fname, overwrite=True) +@pytest.mark.parametrize("tmin", (0, 1, 5, 10)) +def test_export_raw_eeglab_annotations(tmp_path, tmin): + """Test annotations in the exported EEGLAB file. + + All annotations should be preserved and onset corrected. 
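The _sync_onset(raw, raw.annotations.onset, inverse=False) calls above reduce to subtracting raw.first_time, since EDF and EEGLAB both count time from the first available sample; numerically (values below are illustrative):

    import numpy as np

    first_time = 5.0  # e.g., a raw cropped 5 s into the recording
    onsets = np.array([5.01, 6.05, 7.90])  # annotation onsets on the session clock
    # Exported onsets are relative to the first sample of the cropped data.
    exported_onsets = onsets - first_time
    print(exported_onsets)  # -> [0.01 1.05 2.9 ]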
+ """ + pytest.importorskip("eeglabio") + raw = read_raw_fif(fname_raw, preload=True) + raw.apply_proj() + annotations = Annotations( + onset=[0.01, 0.05, 0.90, 1.05], + duration=[0, 1, 0, 0], + description=["test1", "test2", "test3", "test4"], + ch_names=[["MEG 0113"], ["MEG 0113", "MEG 0132"], [], ["MEG 0143"]], + ) + raw.set_annotations(annotations) + raw.crop(tmin) + + # export + temp_fname = tmp_path / "test.set" + raw.export(temp_fname) + + # read in the file + with pytest.warns(RuntimeWarning, match="is above the 99th percentile"): + raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m") + assert raw_read.first_time == 0 # export resets first_time + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range get exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, + raw_read.annotations.onset, + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + + def _create_raw_for_edf_tests(stim_channel_index=None): rng = np.random.RandomState(12345) ch_types = [ @@ -145,7 +188,7 @@ def _create_raw_for_edf_tests(stim_channel_index=None): edfio_mark = pytest.mark.skipif( - not _check_edfio_installed(strict=False), reason="unsafe use of private module" + not _check_edfio_installed(strict=False), reason="requires edfio" ) @@ -154,6 +197,7 @@ def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) raw.info.set_meas_date("2023-09-04 14:53:09.000") + raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"])) # include subject info and measurement date raw.info["subject_info"] = dict( @@ -235,7 +279,7 @@ def test_edf_padding(tmp_path, pad_width): RuntimeWarning, match=( "EDF format requires equal-length data blocks.*" - f"{pad_width/1000:.3g} seconds of edge values were appended.*" + f"{pad_width / 1000:.3g} seconds of edge values were appended.*" ), ): raw.export(temp_fname) @@ -258,8 +302,12 @@ def test_edf_padding(tmp_path, pad_width): @edfio_mark() -def test_export_edf_annotations(tmp_path): - """Test that exporting EDF preserves annotations.""" +@pytest.mark.parametrize("tmin", (0, 0.005, 0.03, 1)) +def test_export_edf_annotations(tmp_path, tmin): + """Test annotations in the exported EDF file. + + All annotations should be preserved and onset corrected.
+ """ + raw = _create_raw_for_edf_tests() annotations = Annotations( onset=[0.01, 0.05, 0.90, 1.05], duration=[0, 1, 0, 0], description=["test1", "test2", "test3", "test4"], ch_names=[["0"], ["0", "1"], [], ["1"]], ) raw.set_annotations(annotations) + raw.crop(tmin) + assert raw.first_time == tmin + + if raw.n_times % raw.info["sfreq"] == 0: + expectation = nullcontext() + else: + expectation = pytest.warns( + RuntimeWarning, match="EDF format requires equal-length data blocks" + ) # export temp_fname = tmp_path / "test.edf" - raw.export(temp_fname) + with expectation: + raw.export(temp_fname) # read in the file raw_read = read_raw_edf(temp_fname, preload=True) - assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) - assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, raw_read.annotations.description) - assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) + assert raw_read.first_time == 0 # export resets first_time + bad_annot = raw_read.annotations.description == "BAD_ACQ_SKIP" + if bad_annot.any(): + raw_read.annotations.delete(bad_annot) + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range get exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, raw_read.annotations.onset + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + assert_array_equal( + raw.annotations.ch_names[valid_annot], raw_read.annotations.ch_names + ) @@ -476,7 +551,7 @@ def test_export_epochs_eeglab(tmp_path, preload): with ctx(): epochs.export(temp_fname) epochs.drop_channels([ch for ch in ["epoc", "STI 014"] if ch in epochs.ch_names]) - epochs_read = read_epochs_eeglab(temp_fname) + epochs_read = read_epochs_eeglab(temp_fname, verbose="error") # head radius assert epochs.ch_names == epochs_read.ch_names cart_coords = np.array([d["loc"][:3] for d in epochs.info["chs"]]) # just xyz cart_coords_read = np.array([d["loc"][:3] for d in epochs_read.info["chs"]]) @@ -580,7 +655,7 @@ def test_export_to_mff_incompatible_sfreq(): """Test non-whole number sampling frequency throws ValueError.""" pytest.importorskip("mffpy", "0.5.7") evoked = read_evokeds(fname_evoked) - with pytest.raises(ValueError, match=f'sfreq: {evoked[0].info["sfreq"]}'): + with pytest.raises(ValueError, match=f"sfreq: {evoked[0].info['sfreq']}"): export_evokeds("output.mff", evoked) diff --git a/mne/filter.py b/mne/filter.py index ee5b34cd657..a7d7c883e2f 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -411,8 +411,7 @@ def _prep_for_filtering(x, copy, picks=None): picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: raise ValueError( - "picks argument is not supported for data with more" - " than three dimensions" + "picks argument is not supported for data with more than three dimensions" ) assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above @@ -2873,7 +2872,7 @@ def design_mne_c_filter( h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq)) logger.info( - "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d " "hpw : %d lpw : %d", + "filter : %7.3f ... %6.1f Hz bins : %d ...
%d of %d hpw : %d lpw : %d", l_freq, h_freq, l_start, diff --git a/mne/fixes.py b/mne/fixes.py index 2aed20492ec..070d4125d18 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -720,3 +720,16 @@ def minimum_phase(h, method="homomorphic", n_fft=None, *, half=True): n_out = (n_half + len(h) % 2) if half else len(h) return h_minimum[:n_out] + + +# SciPy 1.15 deprecates sph_harm for sph_harm_y and using it will trigger a +# DeprecationWarning. This is a backport of the new function for older SciPy versions. +def sph_harm_y(n, m, theta, phi, *, diff_n=0): + """Wrap scipy.special.sph_harm for sph_harm_y.""" + # Can be removed once we no longer support scipy < 1.15.0 + from scipy import special + + if "sph_harm_y" in special.__dict__: + return special.sph_harm_y(n, m, theta, phi, diff_n=diff_n) + else: + return special.sph_harm(m, n, phi, theta) diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index b505b5e45df..e98a147b560 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -96,7 +96,7 @@ def _pinv_trunc(x, miss): varexp /= varexp[-1] n = np.where(varexp >= (1.0 - miss))[0][0] + 1 logger.info( - " Truncating at %d/%d components to omit less than %g " "(%0.2g)", + " Truncating at %d/%d components to omit less than %g (%0.2g)", n, len(s), miss, @@ -111,8 +111,7 @@ def _pinv_tikhonov(x, reg): # _reg_pinv requires square Hermitian, which we have here inv, _, n = _reg_pinv(x, reg=reg, rank=None) logger.info( - f" Truncating at {n}/{len(x)} components and regularizing " - f"with α={reg:0.1e}" + f" Truncating at {n}/{len(x)} components and regularizing with α={reg:0.1e}" ) return inv, n diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 64aadf69fec..6c77f47e312 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -160,8 +160,7 @@ def _create_meg_coil(coilset, ch, acc, do_es): break else: raise RuntimeError( - "Desired coil definition not found " - f"(type = {ch['coil_type']} acc = {acc})" + f"Desired coil definition not found (type = {ch['coil_type']} acc = {acc})" ) # Apply a coordinate transformation if so desired @@ -295,8 +294,8 @@ def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, verbose=None) else: if bem["surfs"][0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise RuntimeError( - f'BEM is in {_coord_frame_name(bem["surfs"][0]["coord_frame"])} ' - 'coordinates, should be in MRI' + f"BEM is in {_coord_frame_name(bem['surfs'][0]['coord_frame'])} " + "coordinates, should be in MRI" ) if neeg > 0 and len(bem["surfs"]) == 1: raise RuntimeError( @@ -335,7 +334,7 @@ def _prep_meg_channels( del picks # Get channel info and names for MEG channels - logger.info(f'Read {len(info_meg["chs"])} MEG channels from info') + logger.info(f"Read {len(info_meg['chs'])} MEG channels from info") # Get MEG compensation channels compensator = post_picks = None @@ -352,7 +351,7 @@ def _prep_meg_channels( 'channels. 
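The sph_harm_y shim above keeps one call signature across SciPy versions, so call sites stay version-independent; a small usage sketch:

    import numpy as np

    from mne.fixes import sph_harm_y

    # Degree n=2, order m=1, polar angle theta, azimuth phi (the new SciPy
    # convention); on SciPy < 1.15 the wrapper swaps arguments for sph_harm.
    print(sph_harm_y(2, 1, np.pi / 3, np.pi / 4))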
Consider using "ignore_ref=True" in ' "calculation" ) - logger.info(f'{len(info["comps"])} compensation data sets in info') + logger.info(f"{len(info['comps'])} compensation data sets in info") # Compose a compensation data set if necessary # adapted from mne_make_ctf_comp() from mne_ctf_comp.c logger.info("Setting up compensation data...") diff --git a/mne/forward/forward.py b/mne/forward/forward.py index e3e5c08d2f8..f1c2c2d11d7 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -512,7 +512,7 @@ def _merge_fwds(fwds, *, verbose=None): a[k]["row_names"] = a[k]["row_names"] + b[k]["row_names"] a["nchan"] = a["nchan"] + b["nchan"] if len(fwds) > 1: - logger.info(f' Forward solutions combined: {", ".join(combined)}') + logger.info(f" Forward solutions combined: {', '.join(combined)}") return fwd @@ -677,8 +677,7 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=True, verbos # Make sure forward solution is in either the MRI or HEAD coordinate frame if fwd["coord_frame"] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): raise ValueError( - "Only forward solutions computed in MRI or head " - "coordinates are acceptable" + "Only forward solutions computed in MRI or head coordinates are acceptable" ) # Transform each source space to the HEAD or MRI coordinate frame, @@ -1205,8 +1204,7 @@ def _triage_loose(src, loose, fixed="auto"): if fixed is True: if not all(v == 0.0 for v in loose.values()): raise ValueError( - 'When using fixed=True, loose must be 0. or "auto", ' - f"got {orig_loose}" + f'When using fixed=True, loose must be 0. or "auto", got {orig_loose}' ) elif fixed is False: if any(v == 0.0 for v in loose.values()): @@ -1666,8 +1664,7 @@ def apply_forward( for ch_name in fwd["sol"]["row_names"]: if ch_name not in info["ch_names"]: raise ValueError( - f"Channel {ch_name} of forward operator not present in " - "evoked_template." + f"Channel {ch_name} of forward operator not present in evoked_template." 
) # project the source estimate to the sensor space diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index 37ec6e041b5..a357c5779c9 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -482,7 +482,7 @@ def test_make_forward_solution_openmeeg(n_layers): eeg_atol=100, meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, + meg_rdm_tol=0.11, eeg_rdm_tol=0.2, ) diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 98e3fbfc0b3..b365a2eed5a 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -1611,8 +1611,7 @@ def _configure_dock(self): func=self._set_subjects_dir, is_directory=True, icon=True, - tooltip="Load the path to the directory containing the " - "FreeSurfer subjects", + tooltip="Load the path to the directory containing the FreeSurfer subjects", layout=subjects_dir_layout, ) self._renderer._layout_add_widget( @@ -1741,8 +1740,7 @@ def _configure_dock(self): self._widgets["omit"] = self._renderer._dock_add_button( name="Omit", callback=self._omit_hsp, - tooltip="Exclude the head shape points that are far away from " - "the MRI head", + tooltip="Exclude the head shape points that are far away from the MRI head", layout=omit_hsp_layout_2, ) self._widgets["reset_omit"] = self._renderer._dock_add_button( diff --git a/mne/html_templates/_templates.py b/mne/html_templates/_templates.py index 9427f2d6a25..1f68303a51e 100644 --- a/mne/html_templates/_templates.py +++ b/mne/html_templates/_templates.py @@ -66,7 +66,7 @@ def _format_time_range(inst) -> str: def _format_projs(info) -> list[str]: """Format projectors.""" - projs = [f'{p["desc"]} ({"on" if p["active"] else "off"})' for p in info["projs"]] + projs = [f"{p['desc']} ({'on' if p['active'] else 'off'})" for p in info["projs"]] return projs diff --git a/mne/io/array/__init__.py b/mne/io/array/__init__.py index aea21ef42ce..ad53f7c817f 100644 --- a/mne/io/array/__init__.py +++ b/mne/io/array/__init__.py @@ -4,4 +4,4 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from .array import RawArray +from ._array import RawArray diff --git a/mne/io/array/array.py b/mne/io/array/_array.py similarity index 100% rename from mne/io/array/array.py rename to mne/io/array/_array.py diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index 039108eb915..610f32ba5da 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -35,9 +35,9 @@ def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0): angle = np.rad2deg(_angle_between_quats(quat_est, quat)) dist = np.linalg.norm(trans - trans_est) - assert ( - dist <= dist_tol - ), f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + assert dist <= dist_tol, ( + f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + ) assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation" diff --git a/mne/io/base.py b/mne/io/base.py index 4f5f2436bd7..280330367f7 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1013,8 +1013,7 @@ def get_data( if n_rejected > 0: if reject_by_annotation == "omit": msg = ( - "Omitting {} of {} ({:.2%}) samples, retaining {}" - " ({:.2%}) samples." + "Omitting {} of {} ({:.2%}) samples, retaining {} ({:.2%}) samples." 
) logger.info( msg.format( @@ -2157,7 +2156,7 @@ def append(self, raws, preload=None): for edge_samp in edge_samps: onset = _sync_onset(self, edge_samp / self.info["sfreq"], True) logger.debug( - f"Marking edge at {edge_samp} samples " f"(maps to {onset:0.3f} sec)" + f"Marking edge at {edge_samp} samples (maps to {onset:0.3f} sec)" ) self.annotations.append(onset, 0.0, "BAD boundary") self.annotations.append(onset, 0.0, "EDGE boundary") diff --git a/mne/io/ctf/ctf.py b/mne/io/ctf/ctf.py index 44a4e39adf6..971ac51c2f6 100644 --- a/mne/io/ctf/ctf.py +++ b/mne/io/ctf/ctf.py @@ -267,7 +267,7 @@ def _get_sample_info(fname, res4, system_clock): fid.seek(offset, 0) this_data = np.fromfile(fid, ">i4", res4["nsamp"]) if len(this_data) != res4["nsamp"]: - raise RuntimeError(f"Cannot read data for trial {t+1}.") + raise RuntimeError(f"Cannot read data for trial {t + 1}.") end = np.where(this_data == 0)[0] if len(end) > 0: n_samp = samp_offset + end[0] diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py index 1b96d8bd88f..685a20792d3 100644 --- a/mne/io/ctf/info.py +++ b/mne/io/ctf/info.py @@ -50,8 +50,7 @@ def _pick_isotrak_and_hpi_coils(res4, coils, t): if p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: if t is None or t["t_ctf_dev_dev"] is None: raise RuntimeError( - "No coordinate transformation " - "available for HPI coil locations" + "No coordinate transformation available for HPI coil locations" ) d = dict( kind=kind, diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py index 4a5dd846655..448ea90baba 100644 --- a/mne/io/ctf/tests/test_ctf.py +++ b/mne/io/ctf/tests/test_ctf.py @@ -243,9 +243,9 @@ def test_read_ctf(tmp_path): # Make sure all digitization points are in the MNE head coord frame for p in raw.info["dig"]: - assert ( - p["coord_frame"] == FIFF.FIFFV_COORD_HEAD - ), "dig points must be in FIFF.FIFFV_COORD_HEAD" + assert p["coord_frame"] == FIFF.FIFFV_COORD_HEAD, ( + "dig points must be in FIFF.FIFFV_COORD_HEAD" + ) if fname.endswith("catch-alp-good-f.ds"): # omit points from .pos file with raw.info._unlock(): diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index bb79c46f24a..09ac24f753e 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -436,21 +436,24 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, ones[orig_idx, smp_read : smp_read + len(one_i)] = one_i n_smp_read[orig_idx] += len(one_i) + # resample channels with lower sample frequency # skip if no data was requested, ie. 
only annotations were read - if sum(n_smp_read) > 0: + if any(n_smp_read) > 0: # expected number of samples, equals maximum sfreq smp_exp = data.shape[-1] - assert max(n_smp_read) == smp_exp # resample data after loading all chunks to prevent edge artifacts resampled = False + for i, smp_read in enumerate(n_smp_read): # nothing read, nothing to resample if smp_read == 0: continue # upsample if n_samples is lower than from highest sfreq if smp_read != smp_exp: - assert (ones[i, smp_read:] == 0).all() # sanity check + # sanity check that we read exactly how much we expected + assert (ones[i, smp_read:] == 0).all() + ones[i, :] = resample( ones[i, :smp_read].astype(np.float64), smp_exp, @@ -628,7 +631,7 @@ def _get_info( if len(chs_without_types): msg = ( "Could not determine channel type of the following channels, " - f'they will be set as EEG:\n{", ".join(chs_without_types)}' + f"they will be set as EEG:\n{', '.join(chs_without_types)}" ) logger.info(msg) @@ -712,8 +715,8 @@ def _get_info( if info["highpass"] > info["lowpass"]: warn( - f'Highpass cutoff frequency {info["highpass"]} is greater ' - f'than lowpass cutoff frequency {info["lowpass"]}, ' + f"Highpass cutoff frequency {info['highpass']} is greater " + f"than lowpass cutoff frequency {info['lowpass']}, " "setting values to 0 and Nyquist." ) info["highpass"] = 0.0 diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index b4f0ab33fa5..ce671ca7e81 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -259,6 +259,24 @@ def test_edf_different_sfreqs(stim_channel): assert_allclose(times1, times2) +@testing.requires_testing_data +@pytest.mark.parametrize("stim_channel", (None, False, "auto")) +def test_edf_different_sfreqs_nopreload(stim_channel): + """Test loading smaller sfreq channels without preloading.""" + # load without preloading, then load a channel that has smaller sfreq + # as other channels, produced an error, see mne-python/issues/12897 + + for i in range(1, 13): + raw = read_raw_edf(input_fname=edf_reduced, verbose="error", preload=False) + + # this should work for channels of all sfreq, even if larger sfreqs + # are present in the file + x1 = raw.get_data(picks=[f"A{i}"], return_times=False) + # load next ch, this is sometimes with a higher sometimes a lower sfreq + x2 = raw.get_data([f"A{i + 1}"], return_times=False) + assert x1.shape == x2.shape + + def test_edf_data_broken(tmp_path): """Test edf files.""" raw = _test_raw_reader( diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index b2f08020e15..c3a10fb72cd 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -106,7 +106,7 @@ def _read_mff_header(filepath): if bad: raise RuntimeError( "EGI epoch first/last samps could not be parsed:\n" - f'{list(epochs["first_samps"])}\n{list(epochs["last_samps"])}' + f"{list(epochs['first_samps'])}\n{list(epochs['last_samps'])}" ) summaryinfo.update(epochs) # index which samples in raw are actually readable from disk (i.e., not diff --git a/mne/io/fieldtrip/fieldtrip.py b/mne/io/fieldtrip/fieldtrip.py index 5d94d3e0a80..c8521722003 100644 --- a/mne/io/fieldtrip/fieldtrip.py +++ b/mne/io/fieldtrip/fieldtrip.py @@ -7,7 +7,7 @@ from ...epochs import EpochsArray from ...evoked import EvokedArray from ...utils import _check_fname, _import_pymatreader_funcs -from ..array.array import RawArray +from ..array._array import RawArray from .utils import ( _create_event_metadata, _create_events, diff --git a/mne/io/fil/tests/test_fil.py b/mne/io/fil/tests/test_fil.py index 
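The regression exercised by test_edf_different_sfreqs_nopreload comes down to lazily reading a low-rate channel; a minimal reproduction sketch (the file path is a placeholder):

    import mne

    # Placeholder path; any EDF file with mixed per-channel sampling rates works.
    raw = mne.io.read_raw_edf("mixed_sfreq.edf", preload=False)
    # Lower-rate channels are upsampled on read to the highest sfreq, so the
    # shapes match even when only a low-rate channel is picked without preload.
    x = raw.get_data(picks=[raw.ch_names[0]], return_times=False)
    print(x.shape)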
06d3d924319..df15dd13353 100644 --- a/mne/io/fil/tests/test_fil.py +++ b/mne/io/fil/tests/test_fil.py @@ -87,9 +87,9 @@ def _fil_megmag(raw_test, raw_mat): mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) - assert len(mat_inds) == len( - test_inds - ), "Number of magnetometer channels in RAW does not match .mat file!" + assert len(mat_inds) == len(test_inds), ( + "Number of magnetometer channels in RAW does not match .mat file!" + ) a = raw_test._data[test_inds, :] b = raw_mat["trial"][mat_inds, :] * 1e-15 # fT to T @@ -106,9 +106,9 @@ def _fil_stim(raw_test, raw_mat): mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) - assert len(mat_inds) == len( - test_inds - ), "Number of stim channels in RAW does not match .mat file!" + assert len(mat_inds) == len(test_inds), ( + "Number of stim channels in RAW does not match .mat file!" + ) a = raw_test._data[test_inds, :] b = raw_mat["trial"][mat_inds, :] # fT to T @@ -122,9 +122,9 @@ def _fil_sensorpos(raw_test, raw_mat): grad_list = raw_mat["coil_label"] grad_inds = _match_str(test_list, grad_list) - assert len(grad_inds) == len( - test_inds - ), "Number of channels with position data in RAW does not match .mat file!" + assert len(grad_inds) == len(test_inds), ( + "Number of channels with position data in RAW does not match .mat file!" + ) mat_pos = raw_mat["coil_pos"][grad_inds, :] mat_ori = raw_mat["coil_ori"][grad_inds, :] diff --git a/mne/io/neuralynx/tests/test_neuralynx.py b/mne/io/neuralynx/tests/test_neuralynx.py index ea5cdbccdfb..18578ef4ab7 100644 --- a/mne/io/neuralynx/tests/test_neuralynx.py +++ b/mne/io/neuralynx/tests/test_neuralynx.py @@ -143,9 +143,9 @@ def test_neuralynx(): assert raw.info["meas_date"] == meas_date_utc, "meas_date not set correctly" # test that channel selection worked - assert ( - raw.ch_names == expected_chan_names - ), "labels in raw.ch_names don't match expected channel names" + assert raw.ch_names == expected_chan_names, ( + "labels in raw.ch_names don't match expected channel names" + ) mne_y = raw.get_data() # in V @@ -216,9 +216,9 @@ def test_neuralynx_gaps(): n_expected_gaps = 3 n_expected_missing_samples = 130 assert len(raw.annotations) == n_expected_gaps, "Wrong number of gaps detected" - assert ( - (mne_y[0, :] == 0).sum() == n_expected_missing_samples - ), "Number of true and inferred missing samples differ" + assert (mne_y[0, :] == 0).sum() == n_expected_missing_samples, ( + "Number of true and inferred missing samples differ" + ) # read in .mat files containing original gaps matchans = ["LAHC1_3_gaps.mat", "LAHC2_3_gaps.mat"] diff --git a/mne/io/nirx/nirx.py b/mne/io/nirx/nirx.py index 53a812e7a21..5d9b79b57cc 100644 --- a/mne/io/nirx/nirx.py +++ b/mne/io/nirx/nirx.py @@ -210,7 +210,7 @@ def __init__(self, fname, saturated, *, preload=False, encoding=None, verbose=No ): warn( "Only import of data from NIRScout devices have been " - f'thoroughly tested. You are using a {hdr["GeneralInfo"]["Device"]}' + f"thoroughly tested. You are using a {hdr['GeneralInfo']['Device']}" " device." 
) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index b559ce07068..8f773533ae4 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -533,7 +533,7 @@ def _test_raw_crop(reader, t_prop, kwargs): n_samp = 50 # crop to this number of samples (per instance) crop_t = n_samp / raw_1.info["sfreq"] t_start = t_prop * crop_t # also crop to some fraction into the first inst - extra = f' t_start={t_start}, preload={kwargs.get("preload", False)}' + extra = f" t_start={t_start}, preload={kwargs.get('preload', False)}" stop = (n_samp - 1) / raw_1.info["sfreq"] raw_1.crop(0, stop) assert len(raw_1.times) == 50 diff --git a/mne/label.py b/mne/label.py index f68144106c3..02bf9dc09c0 100644 --- a/mne/label.py +++ b/mne/label.py @@ -264,8 +264,7 @@ def __init__( if not (len(vertices) == len(values) == len(pos)): raise ValueError( - "vertices, values and pos need to have same " - "length (number of vertices)" + "vertices, values and pos need to have same length (number of vertices)" ) # name @@ -416,7 +415,7 @@ def __sub__(self, other): else: keep = np.arange(len(self.vertices)) - name = f'{self.name or "unnamed"} - {other.name or "unnamed"}' + name = f"{self.name or 'unnamed'} - {other.name or 'unnamed'}" return Label( self.vertices[keep], self.pos[keep], @@ -976,8 +975,7 @@ def _get_label_src(label, src): src = _ensure_src(src) if src.kind != "surface": raise RuntimeError( - "Cannot operate on SourceSpaces that are not " - f"surface type, got {src.kind}" + f"Cannot operate on SourceSpaces that are not surface type, got {src.kind}" ) if label.hemi == "lh": hemi_src = src[0] @@ -1585,8 +1583,7 @@ def stc_to_label( vertno = np.where(src[hemi_idx]["inuse"])[0] if not len(np.setdiff1d(this_vertno, vertno)) == 0: raise RuntimeError( - "stc contains vertices not present " - "in source space, did you morph?" + "stc contains vertices not present in source space, did you morph?" 
) tmp = np.zeros((len(vertno), this_data.shape[1])) this_vertno_idx = np.searchsorted(vertno, this_vertno) @@ -2151,8 +2148,7 @@ def _read_annot(fname): cands = _read_annot_cands(dir_name) if len(cands) == 0: raise OSError( - f"No such file {fname}, no candidate parcellations " - "found in directory" + f"No such file {fname}, no candidate parcellations found in directory" ) else: raise OSError( diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index e5129a4822f..7c789503ac1 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -673,7 +673,7 @@ def prepare_inverse_operator( inv["eigen_leads"]["data"] = sqrt(scale) * inv["eigen_leads"]["data"] logger.info( - " Scaled noise and source covariance from nave = %d to" " nave = %d", + " Scaled noise and source covariance from nave = %d to nave = %d", inv["nave"], nave, ) @@ -2011,7 +2011,7 @@ def make_inverse_operator( logger.info( f" scaling factor to adjust the trace = {trace_GRGT:g} " f"(nchan = {eigen_fields.shape[0]} " - f'nzero = {(noise_cov["eig"] <= 0).sum()})' + f"nzero = {(noise_cov['eig'] <= 0).sum()})" ) # MNE-ify everything for output diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index aa3f8294027..5b5c941a9ac 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -130,8 +130,7 @@ def _compare(a, b): for k, v in a.items(): if k not in b and k not in skip_types: raise ValueError( - "First one had one second one didn't:\n" - f"{k} not in {b.keys()}" + f"First one had one second one didn't:\n{k} not in {b.keys()}" ) if k not in skip_types: last_keys.pop() diff --git a/mne/morph.py b/mne/morph.py index 9c475bff1e9..a8278731f3c 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -200,8 +200,7 @@ def compute_source_morph( if kind not in "surface" and xhemi: raise ValueError( - "Inter-hemispheric morphing can only be used " - "with surface source estimates." + "Inter-hemispheric morphing can only be used with surface source estimates." ) if sparse and kind != "surface": raise ValueError("Only surface source estimates can compute a sparse morph.") @@ -1301,8 +1300,7 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=None, verbose=No if isinstance(grade, list): if not len(grade) == 2: raise ValueError( - "grade as a list must have two elements " - "(arrays of output vertices)" + "grade as a list must have two elements (arrays of output vertices)" ) vertices = grade else: @@ -1385,8 +1383,7 @@ def _surf_upsampling_mat(idx_from, e, smooth): smooth = _ensure_int(smooth, "smoothing steps") if smooth <= 0: # == 0 is handled in a shortcut above raise ValueError( - "The number of smoothing operations has to be at least 0, got " - f"{smooth}" + f"The number of smoothing operations has to be at least 0, got {smooth}" ) smooth = smooth - 1 # idx will gradually expand from idx_from -> np.arange(n_tot) diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 54f1c825c13..c54685dba34 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "realign_raw", "regress_artifact", "write_fine_calibration", + "apply_pca_obs", ] from . 
import eyetracking, ieeg, nirs from ._annotate_amplitude import annotate_amplitude @@ -56,6 +57,7 @@ from ._fine_cal import ( write_fine_calibration, ) from ._lof import find_bad_channels_lof +from ._pca_obs import apply_pca_obs from ._peak_finder import peak_finder from ._regress import EOGRegression, read_eog_regression, regress_artifact from .artifact_detection import ( diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py index 41d20539ce0..06041cd7f8e 100644 --- a/mne/preprocessing/_fine_cal.py +++ b/mne/preprocessing/_fine_cal.py @@ -156,11 +156,12 @@ def compute_fine_calibration( # 1. Rotate surface normals using magnetometer information (if present) # cals = np.ones(len(info["ch_names"])) - time_idxs = raw.time_as_index(np.arange(0.0, raw.times[-1], t_window)) - if len(time_idxs) <= 1: - time_idxs = np.array([0, len(raw.times)], int) - else: - time_idxs[-1] = len(raw.times) + end = len(raw.times) + 1 + time_idxs = np.arange(0, end, int(round(t_window * raw.info["sfreq"]))) + if len(time_idxs) == 1: + time_idxs = np.concatenate([time_idxs, [end]]) + if time_idxs[-1] != end: + time_idxs[-1] = end count = 0 locs = np.array([ch["loc"] for ch in info["chs"]]) zs = locs[mag_picks, -3:].copy() @@ -388,9 +389,11 @@ def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit each_err = _data_err(data, S_tot, cals, axis=-1)[picks_mag] n_bad = (each_err > err_limit).sum() if n_bad: + bad_max = np.argmax(each_err) reason.append( f"{n_bad} residual{_pl(n_bad)} > {err_limit:0.1f}% " - f"(max: {each_err.max():0.2f}%)" + f"(max: {each_err[bad_max]:0.2f}% @ " + f"{info['ch_names'][picks_mag[bad_max]]})" ) reason = ", ".join(reason) if reason: @@ -398,7 +401,7 @@ def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit good = not bool(reason) assert np.allclose(np.linalg.norm(zs, axis=1), 1.0) logger.info(f" Fit mismatch {first_err:0.2f}→{last_err:0.2f}%") - logger.info(f' Data segment {"" if good else "un"}usable{reason}') + logger.info(f" Data segment {'' if good else 'un'}usable{reason}") # Reformat zs and cals to be the n_mags (including bads) assert zs.shape == (len(data), 3) assert cals.shape == (len(data), 1) diff --git a/mne/preprocessing/_pca_obs.py b/mne/preprocessing/_pca_obs.py new file mode 100755 index 00000000000..be226a73889 --- /dev/null +++ b/mne/preprocessing/_pca_obs.py @@ -0,0 +1,333 @@ +"""Principal Component Analysis Optimal Basis Sets (PCA-OBS).""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import math + +import numpy as np +from scipy.interpolate import PchipInterpolator as pchip +from scipy.signal import detrend + +from ..io.fiff.raw import Raw +from ..utils import _PCA, _validate_type, logger, verbose + + +@verbose +def apply_pca_obs( + raw: Raw, + picks: list[str], + *, + qrs_times: np.ndarray, + n_components: int = 4, + n_jobs: int | None = None, + copy: bool = True, + verbose: bool | str | int | None = None, +) -> Raw: + """ + Apply the PCA-OBS algorithm to picks of a Raw object. + + Uses the optimal basis set (OBS) algorithm from :footcite:`NiazyEtAl2005`. + + Parameters + ---------- + raw : instance of Raw + The raw data to process. + %(picks_all_data_noref)s + qrs_times : ndarray, shape (n_peaks,) + Array of times in the Raw data of detected R-peaks in the ECG channel. + n_components : int + Number of PCA components to use to form the OBS (default 4). + %(n_jobs)s + copy : bool + If False, modify the Raw instance in-place.
+ If True (default), copy the raw instance before processing. + %(verbose)s + + Returns + ------- + raw : instance of Raw + The modified raw instance. + + Notes + ----- + .. versionadded:: 1.10 + + References + ---------- + .. footbibliography:: + """ + # sanity checks + _validate_type(qrs_times, np.ndarray, "qrs_times") + if len(qrs_times.shape) > 1: + raise ValueError("qrs_times must be a 1d array") + if qrs_times.dtype not in [int, float]: + raise ValueError("qrs_times must be an array of either integers or floats") + if np.any(qrs_times < 0): + raise ValueError("qrs_times must be non-negative") + if np.any(qrs_times >= raw.times[-1]): + logger.warning("some out-of-bounds qrs_times will be ignored.") + + if copy: + raw = raw.copy() + + raw.apply_function( + _pca_obs, + picks=picks, + n_jobs=n_jobs, + # args sent to _pca_obs, convert times to indices + qrs=raw.time_as_index(qrs_times), + n_components=n_components, + ) + + return raw + + +def _pca_obs( + data: np.ndarray, + qrs: np.ndarray, + n_components: int, +) -> np.ndarray: + """Algorithm to remove heart artefact from EEG data (array of length n_times).""" + # set to baseline + data = data - np.mean(data) + + # Allocate memory for artifact which will be subtracted from the data + fitted_art = np.zeros(data.shape) + + # Extract QRS event indices which are within our data timeframe + peak_idx = qrs[qrs < len(data)] + peak_count = len(peak_idx) + + ################################################################## + # Preparatory work - reserving memory, configure sizes, de-trend # + ################################################################## + # define peak range based on RR + mRR = np.median(np.diff(peak_idx)) + peak_range = round(mRR / 2) # Rounds to an integer + mid_p = peak_range + 1 + n_samples_fit = round( + peak_range / 8 + ) # sample fit for interpolation between fitted artifact windows + + # make sure array is long enough for peak_range (if not, cut off last ECG peak) + # NOTE: Here we previously checked for the last part of the window to be big enough.
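In practice qrs_times would come from an ECG channel; a hedged end-to-end sketch (the file path is a placeholder, and it assumes the recording has EEG plus an ECG channel that mne.preprocessing.find_ecg_events can detect R-peaks on):

    import mne

    # Placeholder file; any Raw with EEG + ECG channels works here.
    raw = mne.io.read_raw_fif("sample_raw.fif", preload=True)
    events, _, _ = mne.preprocessing.find_ecg_events(raw)
    # Convert R-peak sample indices to times relative to the start of raw.
    qrs_times = (events[:, 0] - raw.first_samp) / raw.info["sfreq"]
    raw_clean = mne.preprocessing.apply_pca_obs(raw, picks="eeg", qrs_times=qrs_times)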
+ while peak_idx[peak_count - 1] + peak_range > len(data): + peak_count = peak_count - 1 # reduce number of QRS complexes detected + + # build PCA matrix (heart-beat epochs x window length) + pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] + # picking out heartbeat epochs + for p in range(1, peak_count): + pcamat[p - 1, :] = data[peak_idx[p] - peak_range : peak_idx[p] + peak_range + 1] + + # detrending matrix (twice) + pcamat = detrend( + pcamat, type="constant", axis=1 + ) # [epoch x time] - detrended along the epoch + mean_effect: np.ndarray = np.mean( + pcamat, axis=0 + ) # [1 x time], contains the mean over all epochs + dpcamat = detrend(pcamat, type="constant", axis=1) # [epoch x time] + + ############################ + # Perform PCA with sklearn # + ############################ + # run PCA, perform singular value decomposition (SVD) + pca = _PCA() + pca.fit(dpcamat) + factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) + + # define selected number of components using profile likelihood + + ##################################### + # Make template of the ECG artefact # + ##################################### + mean_effect = mean_effect.reshape(-1, 1) + pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]] + + ################ + # Data Fitting # + ################ + window_start_idx = [] + window_end_idx = [] + post_idx_next_peak = None + + for p in range(peak_count): + # if the current peak doesn't have enough data at the + # start of the peak_range, skip fitting the artifact + if peak_idx[p] - peak_range < 0: + continue + + # Deals with start portion of data + if p == 0: + pre_range = peak_range + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if post_range > peak_range: + post_range = peak_range + + fitted_art, post_idx_next_peak = _fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_next_peak, + n_samples_fit=n_samples_fit, + ) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + + # Deals with last edge of data + elif p == peak_count - 1: + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = _fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_next_peak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + + # Deals with middle portion of data + else: + # ---------------- Processing of central data -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range + + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : + ] + fitted_art, post_idx_next_peak = _fit_ecg_template( + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range,
post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_next_peak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + + # Actually subtract the artefact; the return needs to be the same shape as the input data + data -= fitted_art + return data + + +def _fit_ecg_template( + data: np.ndarray, + pca_template: np.ndarray, + a_peak_idx: int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: int | None, + n_samples_fit: int, +) -> tuple[np.ndarray, int]: + """ + Fit the heartbeat artefact found in the data. + + Returns the fitted artefact and the end index of the fitted window. + + Parameters + ---------- + data (ndarray): Data from one channel of the raw signal (n_times,) + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix + a_peak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + mid_p (float): Sample index marking middle of the median RR interval + in the signal. Used to extract relevant part of pca_template. + fitted_art (ndarray): The heartbeat artefact computed to + remove from the data + post_idx_previous_peak (optional int): Sample index where the previous + fitted window ended + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. Helps reduce sharp edges at end of fitted heartbeat events + + Returns + ------- + tuple[np.ndarray, int]: the fitted artifact and the end index of the fitted window + """ + # post_idx_next_peak is passed in by _pca_obs, used here as post_idx_previous_peak + # Then next_peak is returned at the end and the process repeats + # select window of template + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] + + # select window of data and detrend it + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") + + # map the data onto the template basis, then back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range + ].T + + # first fitted peak: no previous window to interpolate from, so return + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx + post_range + + # interpolate time between peaks + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( + int + ) # interpolation window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window + # endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) + x_fit = np.concatenate( + [
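The np.linalg.lstsq step above is an ordinary least-squares projection of each detrended heartbeat window onto the OBS template; in isolation (synthetic shapes, illustration only):

    import numpy as np

    rng = np.random.default_rng(0)
    window_len, n_basis = 201, 5  # mean heartbeat + 4 principal components
    template = rng.standard_normal((window_len, n_basis))
    segment = rng.standard_normal(window_len)  # detrended data around one R-peak
    # Solve min ||template @ coef - segment||^2, then rebuild the artifact estimate.
    coef, *_ = np.linalg.lstsq(template, segment, rcond=None)
    artifact = template @ coef
    print(artifact.shape)  # (201,)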
+ np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), + ] + ) + y_fit = fitted_art[x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol + + return fitted_art, a_peak_idx + post_range diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index 0a4c8b6a24d..8674d6e22b3 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -213,7 +213,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "ω >= %5.1f°/s (max: %0.1f°/s)", + "Omitting %5.1f%% (%3d segments): ω >= %5.1f°/s (max: %0.1f°/s)", bad_pct, len(onsets), rotation_velocity_limit, @@ -233,7 +233,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "v >= %5.4fm/s (max: %5.4fm/s)", + "Omitting %5.1f%% (%3d segments): v >= %5.4fm/s (max: %5.4fm/s)", bad_pct, len(onsets), translation_velocity_limit, @@ -286,7 +286,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "disp >= %5.4fm (max: %5.4fm)", + "Omitting %5.1f%% (%3d segments): disp >= %5.4fm (max: %5.4fm)", bad_pct, len(onsets), mean_distance_limit, @@ -539,7 +539,7 @@ def annotate_break( if ignore: logger.info( f"Ignoring annotations with descriptions starting " - f'with: {", ".join(ignore)}' + f"with: {', '.join(ignore)}" ) else: annotations = annotations_from_events( diff --git a/mne/preprocessing/eog.py b/mne/preprocessing/eog.py index 20e5481f89c..13b6f2ef672 100644 --- a/mne/preprocessing/eog.py +++ b/mne/preprocessing/eog.py @@ -213,12 +213,12 @@ def _get_eog_channel_index(ch_name, inst): if not_found: raise ValueError( f"The specified EOG channel{_pl(not_found)} " - f'cannot be found: {", ".join(not_found)}' + f"cannot be found: {', '.join(not_found)}" ) eog_inds = pick_channels(inst.ch_names, include=ch_names) - logger.info(f'Using EOG channel{_pl(ch_names)}: {", ".join(ch_names)}') + logger.info(f"Using EOG channel{_pl(ch_names)}: {', '.join(ch_names)}") return eog_inds diff --git a/mne/preprocessing/hfc.py b/mne/preprocessing/hfc.py index f8a65510a9a..41bf6bbd232 100644 --- a/mne/preprocessing/hfc.py +++ b/mne/preprocessing/hfc.py @@ -68,8 +68,7 @@ def compute_proj_hfc( n_chs = len(coils[5]) if n_chs != info["nchan"]: raise ValueError( - f'Only {n_chs}/{info["nchan"]} picks could be interpreted ' - "as MEG channels." + f"Only {n_chs}/{info['nchan']} picks could be interpreted as MEG channels." 
) S = _sss_basis(exp, coils) del coils diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 3ea11e0531e..f35fe24c1ee 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -560,7 +560,7 @@ def __repr__(self): """ICA fit information.""" infos = self._get_infos_for_repr() - s = f'{infos.fit_on or "no"} decomposition, method: {infos.fit_method}' + s = f"{infos.fit_on or 'no'} decomposition, method: {infos.fit_method}" if infos.fit_on is not None: s += ( @@ -568,8 +568,8 @@ f"{infos.fit_n_samples} samples), " f"{infos.fit_n_components} ICA components " f"({infos.fit_n_pca_components} PCA components available), " - f'channel types: {", ".join(infos.ch_types)}, ' - f'{len(infos.excludes) or "no"} sources marked for exclusion' + f"channel types: {', '.join(infos.ch_types)}, " + f"{len(infos.excludes) or 'no'} sources marked for exclusion" ) return f"<ICA | {s}>" @@ -698,7 +698,7 @@ def fit( warn( f"The following parameters passed to ICA.fit() will be " f"ignored, as they only affect raw data (and it appears " - f'you passed epochs): {", ".join(ignored_params)}' + f"you passed epochs): {', '.join(ignored_params)}" ) picks = _picks_to_idx( @@ -875,7 +875,7 @@ def _do_proj(self, data, log_suffix=""): logger.info( f" Applying projection operator with {nproj} " f"vector{_pl(nproj)}" - f'{" " if log_suffix else ""}{log_suffix}' + f"{' ' if log_suffix else ''}{log_suffix}" ) if self.noise_cov is None: # otherwise it's in pre_whitener_ data = proj @ data @@ -1162,7 +1162,7 @@ def get_explained_variance_ratio(self, inst, *, components=None, ch_type=None): raise ValueError( f"You requested operation on the channel type " f'"{ch_type}", but only the following channel types are ' - f'supported: {", ".join(allowed_ch_types)}' + f"supported: {', '.join(allowed_ch_types)}" ) del ch_type @@ -2393,8 +2393,7 @@ def _pick_sources(self, data, include, exclude, n_pca_components): unmixing = np.dot(unmixing, pca_components) logger.info( - f" Projecting back using {_n_pca_comp} " - f"PCA component{_pl(_n_pca_comp)}" + f" Projecting back using {_n_pca_comp} PCA component{_pl(_n_pca_comp)}" ) mixing = np.eye(_n_pca_comp) mixing[: self.n_components_, : self.n_components_] = self.mixing_matrix_ @@ -3368,8 +3367,7 @@ def corrmap( is_subject = False else: raise ValueError( - "`template` must be a length-2 tuple or an array the " - "size of the ICA maps." + "`template` must be a length-2 tuple or an array the size of the ICA maps." ) template_fig, labelled_ics = None, None diff --git a/mne/preprocessing/ieeg/_volume.py b/mne/preprocessing/ieeg/_volume.py index b4997b2e3f8..af2dcf4328b 100644 --- a/mne/preprocessing/ieeg/_volume.py +++ b/mne/preprocessing/ieeg/_volume.py @@ -109,7 +109,7 @@ def _warn_missing_chs(info, dig_image, after_warp=False): if missing_ch: warn( f"Channel{_pl(missing_ch)} " - f'{", ".join(repr(ch) for ch in missing_ch)} not assigned ' + f"{', '.join(repr(ch) for ch in missing_ch)} not assigned " "voxels " + (f" after applying {after_warp}" if after_warp else "") ) diff --git a/mne/preprocessing/infomax_.py b/mne/preprocessing/infomax_.py index f0722ce5267..b445ac7116c 100644 --- a/mne/preprocessing/infomax_.py +++ b/mne/preprocessing/infomax_.py @@ -320,8 +320,7 @@ def infomax( if l_rate > min_l_rate: if verbose: logger.info( - f"... lowering learning rate to {l_rate:g}" - "\n... re-starting..." + f"... lowering learning rate to {l_rate:g}\n... re-starting..." 
) else: raise ValueError( diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index e1fb548caa5..8c9c0a93957 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -10,7 +10,7 @@ import numpy as np from scipy import linalg -from scipy.special import lpmv, sph_harm +from scipy.special import lpmv from .. import __version__ from .._fiff.compensator import make_compensator @@ -24,7 +24,7 @@ from ..annotations import _annotations_starts_stops from ..bem import _check_origin from ..channels.channels import _get_T1T2_mag_inds, fix_mag_coil_types -from ..fixes import _safe_svd, bincount +from ..fixes import _safe_svd, bincount, sph_harm_y from ..forward import _concatenate_coils, _create_meg_coils, _prep_meg_channels from ..io import BaseRaw, RawArray from ..surface import _normalize_vectors @@ -436,7 +436,7 @@ def _prep_maxwell_filter( # we purposefully stay away from shorthand notation in both and use # explicit terms (like 'azimuth' and 'polar') to avoid confusion. # See mathworld.wolfram.com/SphericalHarmonic.html for more discussion. - # Our code follows the same standard that ``scipy`` uses for ``sph_harm``. + # Our code follows the same standard that ``scipy`` uses for ``sph_harm_y``. # triage inputs ASAP to avoid late-thrown errors _validate_type(raw, BaseRaw, "raw") @@ -507,7 +507,7 @@ def _prep_maxwell_filter( extended_proj_.append(proj["data"]["data"][:, idx]) extended_proj = np.concatenate(extended_proj_) logger.info( - " Extending external SSS basis using %d projection " "vectors", + " Extending external SSS basis using %d projection vectors", len(extended_proj), ) @@ -566,8 +566,8 @@ def _prep_maxwell_filter( dist = np.sqrt(np.sum(_sq(diff))) if dist > 25.0: warn( - f'Head position change is over 25 mm ' - f'({", ".join(f"{x:0.1f}" for x in diff)}) = {dist:0.1f} mm' + f"Head position change is over 25 mm " + f"({', '.join(f'{x:0.1f}' for x in diff)}) = {dist:0.1f} mm" ) # Reconstruct raw file object with spatiotemporal processed data @@ -1487,7 +1487,7 @@ def _sss_basis_basic(exp, coils, mag_scale=100.0, method="standard"): S_in_out = list() grads_in_out = list() # Same spherical harmonic is used for both internal and external - sph = sph_harm(order, degree, az, pol) + sph = sph_harm_y(degree, order, pol, az) sph_norm = _sph_harm_norm(order, degree) # Compute complex gradient for all integration points # in spherical coordinates (Eq. 6). The gradient for rad, az, pol @@ -2579,7 +2579,7 @@ def find_bad_channels_maxwell( freq_loc = "below" if raw.info["lowpass"] < h_freq else "equal to" msg = ( f"The input data has already been low-pass filtered with a " - f'{raw.info["lowpass"]} Hz cutoff frequency, which is ' + f"{raw.info['lowpass']} Hz cutoff frequency, which is " f"{freq_loc} the requested cutoff of {h_freq} Hz. Not " f"applying low-pass filter." 
) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index 92a2e55b9fb..c17cf31110c 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -76,7 +76,7 @@ def beer_lambert_law(raw, ppf=6.0): for ki, kind in zip((ii, jj), ("hbo", "hbr")): ch = raw.info["chs"][ki] ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL) - new_name = f'{ch["ch_name"].split(" ")[0]} {kind}' + new_name = f"{ch['ch_name'].split(' ')[0]} {kind}" rename[ch["ch_name"]] = new_name raw.rename_channels(rename) diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 02c596bf4bc..8b45208e848 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -20,7 +20,7 @@ ) from mne.preprocessing.tests.test_maxwell import _assert_shielding from mne.transforms import _angle_dist_between_rigid -from mne.utils import object_diff +from mne.utils import catch_logging, object_diff # Define fine calibration filepaths data_path = testing.data_path(download=False) @@ -231,7 +231,7 @@ def test_fine_cal_systems(system, tmp_path): err_limit = 6000 n_ref = 28 corrs = (0.19, 0.41, 0.49) - sfs = [0.5, 0.7, 0.9, 1.5] + sfs = [0.5, 0.7, 0.9, 1.55] corr_tol = 0.55 elif system == "fil": raw = read_raw_fil(fil_fname, verbose="error") @@ -289,3 +289,15 @@ def test_fine_cal_systems(system, tmp_path): got_corrs = np.corrcoef([raw_data, raw_sss_data, raw_sss_cal_data]) got_corrs = got_corrs[np.triu_indices(3, 1)] assert_allclose(got_corrs, corrs, atol=corr_tol) + if system == "fil": + with catch_logging(verbose=True) as log: + compute_fine_calibration( + raw.copy().crop(0, 0.12).pick(raw.ch_names[:12]), + t_window=0.06, # 2 segments + angle_limit=angle_limit, + err_limit=err_limit, + ext_order=2, + verbose=True, + ) + log = log.getvalue() + assert "(averaging over 2 time intervals)" in log, log diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index 102900bb1fc..002d4555ff8 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -11,7 +11,6 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal from scipy import sparse -from scipy.special import sph_harm import mne from mne import compute_raw_covariance, concatenate_raws, pick_info, pick_types @@ -19,6 +18,7 @@ from mne.annotations import _annotations_starts_stops from mne.chpi import filter_chpi, read_head_pos from mne.datasets import testing +from mne.fixes import sph_harm_y from mne.forward import _prep_meg_channels, use_coil_def from mne.io import ( BaseRaw, @@ -431,9 +431,9 @@ def test_spherical_conversions(): az, pol = np.meshgrid(np.linspace(0, 2 * np.pi, 30), np.linspace(0, np.pi, 20)) for degree in range(1, int_order): for order in range(0, degree + 1): - sph = sph_harm(order, degree, az, pol) + sph = sph_harm_y(degree, order, pol, az) # ensure that we satisfy the conjugation property - assert_allclose(_sh_negate(sph, order), sph_harm(-order, degree, az, pol)) + assert_allclose(_sh_negate(sph, order), sph_harm_y(degree, -order, pol, az)) # ensure our conversion functions work sph_real_pos = _sh_complex_to_real(sph, order) sph_real_neg = _sh_complex_to_real(sph, -order) @@ -980,9 +980,9 @@ def _assert_shielding(raw_sss, erm_power, min_factor, max_factor=np.inf, meg="ma sss_power = raw_sss[picks][0].ravel() sss_power = np.sqrt(np.sum(sss_power * sss_power)) factor = erm_power / 
sss_power - assert ( - min_factor <= factor < max_factor - ), f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" + assert min_factor <= factor < max_factor, ( + f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" + ) @buggy_mkl_svd diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py new file mode 100644 index 00000000000..ee2568a2080 --- /dev/null +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -0,0 +1,107 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from pathlib import Path + +import numpy as np +import pytest + +from mne.io import read_raw_fif +from mne.io.fiff.raw import Raw +from mne.preprocessing import apply_pca_obs + +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_path / "test_raw.fif" + + +@pytest.fixture() +def short_raw_data(): + """Create a short raw instance from the testing data.""" + return read_raw_fif(raw_fname, preload=True) + + +def test_heart_artifact_removal(short_raw_data: Raw): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + pd = pytest.importorskip("pandas") + + # copy the original data; the heart artifact is removed in place + orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) + + # fake some evenly spaced qrs events within the window of the raw data + # (drop the first and last points so all events fall strictly inside it) + ecg_event_times = np.linspace(0, orig_df["time"].iloc[-1], 20)[1:-1] + + # perform heart artifact removal + short_raw_data = apply_pca_obs( + raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1 + ) + + # compare processed df to original df + removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() + + # ensure all column names remain the same + pd.testing.assert_index_equal( + orig_df.columns, + removed_heart_artifact_df.columns, + ) + + # ensure every column starting with EEG has been altered + altered_cols = [c for c in orig_df.columns if c.startswith("EEG")] + for col in altered_cols: + with pytest.raises( + AssertionError + ): # make sure that error is raised when we check equal + pd.testing.assert_series_equal( + orig_df[col], + removed_heart_artifact_df[col], + ) + + # ensure every column not starting with EEG has not been altered + unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")] + pd.testing.assert_frame_equal( + orig_df[unaltered_cols], + removed_heart_artifact_df[unaltered_cols], + ) + + +# test that various nonsensical inputs raise the proper errors +@pytest.mark.parametrize( + ("picks", "qrs_times", "error", "exception"), + [ + ( + ["eeg"], + np.array([[0, 1], [2, 3]]), + "qrs_times must be a 1d array", + ValueError, + ), + ( + ["eeg"], + [2, 3, 4], + "qrs_times must be an instance of ndarray, got <class 'list'> instead.", + TypeError, + ), + ( + ["eeg"], + np.array([None, "foo", 2]), + "qrs_times must be an array of either integers or floats", + ValueError, + ), + ( + ["eeg"], + np.array([-1, 0, 3]), + "qrs_times must be strictly positive", + ValueError, + ), + ], +) +def test_pca_obs_bad_input( + short_raw_data: Raw, + picks: list[str], + qrs_times: np.ndarray, + error: str, + exception: type[Exception], +): + """Test if bad input data raises the proper errors in the function sanity checks.""" + with pytest.raises(exception, match=error): + apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times) diff --git a/mne/preprocessing/xdawn.py 
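For orientation, here is a minimal usage sketch of the apply_pca_obs API exercised by the new test file above. It is a hedged illustration, not part of the diff: the file name is a placeholder, and using find_ecg_events to obtain R-peak times is an assumption (any detector producing times in seconds would do).

import mne
from mne.preprocessing import apply_pca_obs, find_ecg_events

raw = mne.io.read_raw_fif("my_raw.fif", preload=True)  # placeholder recording

# find_ecg_events returns (events, ecg_channel_idx, average_pulse); the first
# column of `events` holds R-peak sample indices
events, _, _ = find_ecg_events(raw)
qrs_times = (events[:, 0] - raw.first_samp) / raw.info["sfreq"]  # seconds

# qrs_times must be a strictly positive 1d ndarray of ints or floats,
# matching the sanity checks covered by test_pca_obs_bad_input above
raw = apply_pca_obs(raw, picks=["eeg"], qrs_times=qrs_times, n_jobs=1)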
index 0b1132761b1..606b49370df 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -198,8 +198,7 @@ def _fit_xdawn( evals, evecs = linalg.eigh(evo_cov, signal_cov) except np.linalg.LinAlgError as exp: raise ValueError( - "Could not compute eigenvalues, ensure " - f"proper regularization ({exp})" + f"Could not compute eigenvalues, ensure proper regularization ({exp})" ) evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs) diff --git a/mne/report/report.py b/mne/report/report.py index 732c1a5c8b3..852feebc638 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -324,7 +324,7 @@ def _check_tags(tags) -> tuple[str]: raise TypeError( f"All tags must be strings without spaces or special characters, " f"but got the following instead: " - f'{", ".join([str(tag) for tag in bad_tags])}' + f"{', '.join([str(tag) for tag in bad_tags])}" ) # Check for invalid characters @@ -338,7 +338,7 @@ def _check_tags(tags) -> tuple[str]: if bad_tags: raise ValueError( f"The following tags contained invalid characters: " - f'{", ".join(repr(tag) for tag in bad_tags)}' + f"{', '.join(repr(tag) for tag in bad_tags)}" ) return tags @@ -429,8 +429,7 @@ def _fig_to_img( output = BytesIO() dpi = fig.get_dpi() logger.debug( - f"Saving figure with dimension {fig.get_size_inches()} inches with " - f"{dpi} dpi" + f"Saving figure with dimension {fig.get_size_inches()} inches with {dpi} dpi" ) # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html @@ -913,7 +912,7 @@ def __repr__(self): if len(titles) > 0: titles = [f" {t}" for t in titles] # indent tr = max(len(s), 50) # trim to larger of opening str and 50 - titles = [f"{t[:tr - 2]} …" if len(t) > tr else t for t in titles] + titles = [f"{t[: tr - 2]} …" if len(t) > tr else t for t in titles] # then trim to the max length of all of these tr = max(len(title) for title in titles) tr = max(tr, len(s)) @@ -2761,9 +2760,7 @@ def _init_render(self, verbose=None): if inc_fname.endswith(".js"): include.append( - f'" + f'' ) elif inc_fname.endswith(".css"): include.append(f'') @@ -3649,7 +3646,7 @@ def _add_evoked_joint( ) ) - title = f'Time course ({_handle_default("titles")[ch_type]})' + title = f"Time course ({_handle_default('titles')[ch_type]})" self._add_figure( fig=fig, title=title, @@ -4121,7 +4118,7 @@ def _add_epochs( assert "eeg" in ch_type title_start = "ERP image" - title = f'{title_start} ({_handle_default("titles")[ch_type]})' + title = f"{title_start} ({_handle_default('titles')[ch_type]})" self._add_figure( fig=fig, diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 39205e28af2..deeb3a43ede 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -764,6 +764,7 @@ def plot( transparent=True, alpha=1.0, time_viewer="auto", + *, subjects_dir=None, figure=None, views="auto", @@ -1387,8 +1388,7 @@ def transform(self, func, idx=None, tmin=None, tmax=None, copy=False): ] else: raise ValueError( - "copy must be True if transformed data has " - "more than 2 dimensions" + "copy must be True if transformed data has more than 2 dimensions" ) else: # return new or overwritten stc @@ -2256,6 +2256,7 @@ def plot( vector_alpha=1.0, scale_factor=None, time_viewer="auto", + *, subjects_dir=None, figure=None, views="lateral", @@ -2267,6 +2268,7 @@ def plot( foreground=None, initial_time=None, time_unit="s", + title=None, show_traces="auto", src=None, volume_options=1.0, @@ -2299,6 +2301,7 @@ def plot( foreground=foreground, 
initial_time=initial_time, time_unit=time_unit, + title=title, show_traces=show_traces, src=src, volume_options=volume_options, @@ -2767,6 +2770,7 @@ def plot_3d( vector_alpha=1.0, scale_factor=None, time_viewer="auto", + *, subjects_dir=None, figure=None, views="axial", @@ -2778,6 +2782,7 @@ def plot_3d( foreground=None, initial_time=None, time_unit="s", + title=None, show_traces="auto", src=None, volume_options=1.0, @@ -2810,6 +2815,7 @@ def plot_3d( foreground=foreground, initial_time=initial_time, time_unit=time_unit, + title=title, show_traces=show_traces, src=src, volume_options=volume_options, @@ -3626,7 +3632,7 @@ def _volume_labels(src, labels, mri_resolution): ] nnz = sum(len(v) != 0 for v in vertices) logger.info( - "%d/%d atlas regions had at least one vertex " "in the source space", + "%d/%d atlas regions had at least one vertex in the source space", nnz, len(out_labels), ) @@ -3999,7 +4005,7 @@ def stc_near_sensors( min_dist = pdist(pos).min() * 1000 logger.info( - f' Minimum {"projected " if project else ""}intra-sensor distance: ' + f" Minimum {'projected ' if project else ''}intra-sensor distance: " f"{min_dist:0.1f} mm" ) @@ -4027,7 +4033,7 @@ def stc_near_sensors( if len(missing): warn( f"Channel{_pl(missing)} missing in STC: " - f'{", ".join(evoked.ch_names[mi] for mi in missing)}' + f"{', '.join(evoked.ch_names[mi] for mi in missing)}" ) nz_data = w @ evoked.data diff --git a/mne/source_space/_source_space.py b/mne/source_space/_source_space.py index f5e8b76a1fa..d64989961cf 100644 --- a/mne/source_space/_source_space.py +++ b/mne/source_space/_source_space.py @@ -743,7 +743,7 @@ def export_volume( # generate use warnings for clipping if n_diff > 0: warn( - f'{n_diff} {src["type"]} vertices lay outside of volume ' + f"{n_diff} {src['type']} vertices lay outside of volume " f"space. Consider using a larger volume space." ) # get surface id or use default value @@ -1546,7 +1546,7 @@ def setup_source_space( # pre-load ico/oct surf (once) for speed, if necessary if stype not in ("spacing", "all"): logger.info( - f'Doing the {dict(ico="icosa", oct="octa")[stype]}hedral vertex picking...' + f"Doing the {dict(ico='icosa', oct='octa')[stype]}hedral vertex picking..." 
) for hemi, surf in zip(["lh", "rh"], surfs): logger.info(f"Loading {surf}...") @@ -2916,8 +2916,7 @@ def _get_vertex_map_nn( raise RuntimeError(f"vertex {one} would be used multiple times.") one = one[0] logger.info( - "Source space vertex moved from %d to %d because of " - "double occupation.", + "Source space vertex moved from %d to %d because of double occupation.", was, one, ) @@ -3167,8 +3166,7 @@ def _compare_source_spaces(src0, src1, mode="exact", nearest=True, dist_tol=1.5e assert_array_equal( s["vertno"], np.where(s["inuse"])[0], - f'src{ii}[{si}]["vertno"] != ' - f'np.where(src{ii}[{si}]["inuse"])[0]', + f'src{ii}[{si}]["vertno"] != np.where(src{ii}[{si}]["inuse"])[0]', ) assert_equal(len(s0["vertno"]), len(s1["vertno"])) agreement = np.mean(s0["inuse"] == s1["inuse"]) diff --git a/mne/surface.py b/mne/surface.py index 21432e7edfd..9e24147a080 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -214,7 +214,7 @@ def get_meg_helmet_surf(info, trans=None, *, verbose=None): ] ) logger.info( - "Getting helmet for system %s (derived from %d MEG " "channel locations)", + "Getting helmet for system %s (derived from %d MEG channel locations)", system, len(rr), ) @@ -733,7 +733,7 @@ def __init__(self, surf, *, mode="old", verbose=None): else: self._init_old() logger.debug( - f'Setting up {mode} interior check for {len(self.surf["rr"])} ' + f"Setting up {mode} interior check for {len(self.surf['rr'])} " f"points took {(time.time() - t0) * 1000:0.1f} ms" ) @@ -761,8 +761,7 @@ def _init_pyvista(self): def __call__(self, rr, n_jobs=None, verbose=None): n_orig = len(rr) logger.info( - f"Checking surface interior status for " - f'{n_orig} point{_pl(n_orig, " ")}...' + f"Checking surface interior status for {n_orig} point{_pl(n_orig, ' ')}..." ) t0 = time.time() if self.mode == "pyvista": @@ -770,7 +769,7 @@ def __call__(self, rr, n_jobs=None, verbose=None): else: inside = self._call_old(rr, n_jobs) n = inside.sum() - logger.info(f' Total {n}/{n_orig} point{_pl(n, " ")} inside the surface') + logger.info(f" Total {n}/{n_orig} point{_pl(n, ' ')} inside the surface") logger.info(f"Interior check completed in {(time.time() - t0) * 1000:0.1f} ms") return inside @@ -792,7 +791,7 @@ def _call_old(self, rr, n_jobs): n = (in_mask).sum() n_pad = str(n).rjust(prec) logger.info( - f' Found {n_pad}/{n_orig} point{_pl(n, " ")} ' + f" Found {n_pad}/{n_orig} point{_pl(n, ' ')} " f"inside an interior sphere of radius " f"{1000 * self.inner_r:6.1f} mm" ) @@ -801,7 +800,7 @@ def _call_old(self, rr, n_jobs): n = (out_mask).sum() n_pad = str(n).rjust(prec) logger.info( - f' Found {n_pad}/{n_orig} point{_pl(n, " ")} ' + f" Found {n_pad}/{n_orig} point{_pl(n, ' ')} " f"outside an exterior sphere of radius " f"{1000 * self.outer_r:6.1f} mm" ) @@ -818,7 +817,7 @@ def _call_old(self, rr, n_jobs): n_pad = str(n).rjust(prec) check_pad = str(len(del_outside)).rjust(prec) logger.info( - f' Found {n_pad}/{check_pad} point{_pl(n, " ")} outside using ' + f" Found {n_pad}/{check_pad} point{_pl(n, ' ')} outside using " "surface Qhull" ) @@ -828,7 +827,7 @@ def _call_old(self, rr, n_jobs): n_pad = str(n).rjust(prec) check_pad = str(len(solid_outside)).rjust(prec) logger.info( - f' Found {n_pad}/{check_pad} point{_pl(n, " ")} outside using ' + f" Found {n_pad}/{check_pad} point{_pl(n, ' ')} outside using " "solid angles" ) inside[idx[solid_outside]] = False diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6b1356ae107..4d0db170e2a 100644 --- a/mne/tests/test_annotations.py +++ 
b/mne/tests/test_annotations.py @@ -1450,8 +1450,7 @@ def test_repr(): # long annotation repr (> 79 characters, will be shortened) r = repr(Annotations(range(14), [0] * 14, list("abcdefghijklmn"))) assert r == ( - "" + "" ) # empty Annotations diff --git a/mne/tests/test_dipole.py b/mne/tests/test_dipole.py index e93d4031646..f230eaa4256 100644 --- a/mne/tests/test_dipole.py +++ b/mne/tests/test_dipole.py @@ -214,9 +214,9 @@ def test_dipole_fitting(tmp_path): # Sanity check: do our residuals have less power than orig data? data_rms = np.sqrt(np.sum(evoked.data**2, axis=0)) resi_rms = np.sqrt(np.sum(residual.data**2, axis=0)) - assert ( - data_rms > resi_rms * 0.95 - ).all(), f"{(data_rms / resi_rms).min()} (factor: {0.95})" + assert (data_rms > resi_rms * 0.95).all(), ( + f"{(data_rms / resi_rms).min()} (factor: {0.95})" + ) # Compare to original points transform_surface_to(fwd["src"][0], "head", fwd["mri_head_t"]) diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index c94da5e5ab8..64f80f50b74 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -222,8 +222,7 @@ def test_tabs(): continue source = inspect.getsource(mod) assert "\t" not in source, ( - f'"{modname}" has tabs, please remove them ' - "or add it to the ignore list" + f'"{modname}" has tabs, please remove them or add it to the ignore list' ) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 079a2b53ec9..aa11082238f 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" - assert all( - isinstance(s, str) for log in drop_log for s in log - ), "drop_log[ii][jj] should be str" + assert all(isinstance(log, tuple) for log in drop_log), ( + "drop_log[ii] should be tuple" + ) + assert all(isinstance(s, str) for log in drop_log for s in log), ( + "drop_log[ii][jj] should be str" + ) def test_reject(): diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index e259ececbce..537f1930f45 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -90,9 +90,9 @@ def test_estimate_ringing(): (0.0001, (30000, 60000)), ): # 37993 n_ring = estimate_ringing_samples(butter(3, thresh, output=kind)) - assert ( - lims[0] <= n_ring <= lims[1] - ), f"{kind} {thresh}: {lims[0]} <= {n_ring} <= {lims[1]}" + assert lims[0] <= n_ring <= lims[1], ( + f"{kind} {thresh}: {lims[0]} <= {n_ring} <= {lims[1]}" + ) with pytest.warns(RuntimeWarning, match="properly estimate"): assert estimate_ringing_samples(butter(4, 0.00001)) == 100000 diff --git a/mne/tests/test_parallel.py b/mne/tests/test_parallel.py index f72f3281a59..a780f32f911 100644 --- a/mne/tests/test_parallel.py +++ b/mne/tests/test_parallel.py @@ -26,7 +26,7 @@ def test_parallel_func(n_jobs): """Test Parallel wrapping.""" joblib = pytest.importorskip("joblib") if os.getenv("MNE_FORCE_SERIAL", "").lower() in ("true", "1"): - pytest.skip("MNE_FORCE_SERIAL cannot be set") + pytest.skip("MNE_FORCE_SERIAL is set") def fun(x): return x * 2 diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index 0faeb7263d8..6b53c39a98b 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -11,6 +11,8 @@ __all__ = [ "RawTFRArray", "Spectrum", 
"SpectrumArray", + "combine_spectrum", + "combine_tfr", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -61,6 +63,7 @@ from .spectrum import ( EpochsSpectrumArray, Spectrum, SpectrumArray, + combine_spectrum, read_spectrum, ) from .tfr import ( @@ -71,6 +74,7 @@ from .tfr import ( EpochsTFRArray, RawTFR, RawTFRArray, + combine_tfr, fwhm, morlet, read_tfrs, diff --git a/mne/time_frequency/_stft.py b/mne/time_frequency/_stft.py index 8fb80b43fcc..a6b6f23fff7 100644 --- a/mne/time_frequency/_stft.py +++ b/mne/time_frequency/_stft.py @@ -59,8 +59,7 @@ def stft(x, wsize, tstep=None, verbose=None): if (wsize % tstep) or (tstep % 2): raise ValueError( - "The step size must be a multiple of 2 and a " - "divider of the window length." + "The step size must be a multiple of 2 and a divider of the window length." ) if tstep > wsize / 2: diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index c858dd52e57..4ddaa0ac6a3 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -224,8 +224,7 @@ def sum(self, fmin=None, fmax=None): """ if self._is_sum: raise RuntimeError( - "This CSD matrix already represents a mean or " - "sum across frequencies." + "This CSD matrix already represents a mean or sum across frequencies." ) # Deal with the various ways in which fmin and fmax can be specified @@ -1372,7 +1371,7 @@ def _execute_csd_function( logger.info("[done]") if ch_names is None: - ch_names = [f"SERIES{i+1:03}" for i in range(n_channels)] + ch_names = [f"SERIES{i + 1:03}" for i in range(n_channels)] return CrossSpectralDensity( csds_mean, diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..98705e838c2 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -471,6 +471,7 @@ def tfr_array_multitaper( output="complex", n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -504,6 +505,11 @@ def tfr_array_multitaper( coherence across trials. %(n_jobs)s The parallelization is implemented across channels. + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. + + .. versionadded:: 1.10.0 %(verbose)s Returns @@ -520,6 +526,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + ``return_weights=True``. See Also -------- @@ -550,6 +559,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index a70697fd57c..03a57010061 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -311,7 +311,7 @@ def __init__( if np.isfinite(fmax) and (fmax > self.sfreq / 2): raise ValueError( f"Requested fmax ({fmax} Hz) must not exceed ½ the sampling " - f'frequency of the data ({0.5 * inst.info["sfreq"]} Hz).' + f"frequency of the data ({0.5 * inst.info['sfreq']} Hz)." 
) # method self._inst_type = type(inst) @@ -442,7 +442,7 @@ def _check_values(self): if bad_value.any(): chs = np.array(self.ch_names)[bad_value].tolist() s = _pl(bad_value.sum()) - warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) + warn(f"Zero value in spectrum for channel{s} {', '.join(chs)}", UserWarning) def _returns_complex_tapers(self, **method_kw): return self.method == "multitaper" and method_kw.get("output") == "complex" @@ -1536,7 +1536,7 @@ def average(self, method="mean"): state["nave"] = state["data"].shape[0] state["data"] = method(state["data"]) state["dims"] = state["dims"][1:] - state["data_type"] = f'Averaged {state["data_type"]}' + state["data_type"] = f"Averaged {state['data_type']}" defaults = dict( method=None, fmin=None, @@ -1643,6 +1643,74 @@ def __init__( ) +def combine_spectrum(all_spectrum, weights="nave"): + """Merge spectral data by weighted addition. + + Create a new :class:`mne.time_frequency.Spectrum` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., ``[1, -1]``). + Data must have the same channels and the same frequencies. + + Parameters + ---------- + all_spectrum : list of Spectrum + The Spectrum objects. + weights : list of float | str + The weights to apply to the data of each :class:`~mne.time_frequency.Spectrum` + instance, or a string describing the weighting strategy to apply: 'nave' + computes sum-to-one weights proportional to each object’s nave attribute; + 'equal' weights each :class:`~mne.time_frequency.Spectrum` by + ``1 / len(all_spectrum)``. + + Returns + ------- + spectrum : Spectrum + The new spectral data. + + Notes + ----- + .. versionadded:: 1.10.0 + """ + spectrum = all_spectrum[0].copy() + if isinstance(weights, str): + if weights not in ("nave", "equal"): + raise ValueError('Weights must be a list of float, or "nave" or "equal"') + if weights == "nave": + for s_ in all_spectrum: + if s_.nave is None: + raise ValueError(f"The 'nave' attribute is not specified for {s_}") + weights = np.array([e.nave for e in all_spectrum], float) + weights /= weights.sum() + else: # == 'equal' + weights = [1.0 / len(all_spectrum)] * len(all_spectrum) + weights = np.array(weights, float) + if weights.ndim != 1 or weights.size != len(all_spectrum): + raise ValueError("Weights must be the same size as all_spectrum") + + ch_names = spectrum.ch_names + for s_ in all_spectrum[1:]: + assert s_.ch_names == ch_names, ( + f"{spectrum} and {s_} do not contain the same channels" + ) + assert np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7, ( + f"{spectrum} and {s_} do not contain the same frequencies" + ) + + # use union of bad channels + bads = list( + set(spectrum.info["bads"]).union(*(s_.info["bads"] for s_ in all_spectrum[1:])) + ) + spectrum.info["bads"] = bads + + # combine spectral data + spectrum._data = sum(w * s_.data for w, s_ in zip(weights, all_spectrum)) + if spectrum.nave is not None: + spectrum._nave = max( + int(1.0 / sum(w**2 / s_.nave for w, s_ in zip(weights, all_spectrum))), 1 + ) + return spectrum + + def read_spectrum(fname): """Load a :class:`mne.time_frequency.Spectrum` object from disk. 
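A minimal sketch of the combine_spectrum function added above, assuming two placeholder recordings with identical channels and sampling so that the channel and frequency axes match (a requirement the function checks):

import mne
from mne.time_frequency import combine_spectrum

spec_a = mne.io.read_raw_fif("raw_a.fif", preload=True).compute_psd(fmax=40.0)
spec_b = mne.io.read_raw_fif("raw_b.fif", preload=True).compute_psd(fmax=40.0)

mean_spec = combine_spectrum([spec_a, spec_b], weights="equal")  # (A + B) / 2
diff_spec = combine_spectrum([spec_a, spec_b], weights=[1, -1])  # A - B

Passing weights="nave" instead weights by each object's nave attribute, which must then be set on every input (otherwise the function raises ValueError, as the docstring above describes).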
diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 162d89b1c25..927c22360c5 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -14,7 +14,11 @@ from mne.io import RawArray from mne.time_frequency import read_spectrum from mne.time_frequency.multitaper import _psd_from_mt -from mne.time_frequency.spectrum import EpochsSpectrumArray, SpectrumArray +from mne.time_frequency.spectrum import ( + EpochsSpectrumArray, + SpectrumArray, + combine_spectrum, +) from mne.utils import _record_warnings @@ -190,6 +194,55 @@ def test_spectrum_copy(raw_spectrum): assert raw_spectrum.freqs is not None +@pytest.mark.parametrize("weights", ["nave", "equal", [1, -1]]) +def test_combine_spectrum(raw_spectrum, weights): + """Test `combine_spectrum()` works.""" + spectrum1 = raw_spectrum.copy() + spectrum2 = raw_spectrum.copy() + if weights == "nave": + spectrum1._nave = 1 + spectrum2._nave = 2 + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * (5 / 3)) + elif weights == "equal": + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * 1.5) + else: + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, 0) + + +def test_combine_spectrum_error_catch(raw_spectrum): + """Test `combine_spectrum()` catches errors.""" + # Test bad weights + with pytest.raises( + ValueError, match='Weights must be a list of float, or "nave" or "equal"' + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights="foo") + with pytest.raises( + ValueError, match="Weights must be the same size as all_spectrum" + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights=[1, 1, 1]) + + # Test bad nave + with pytest.raises(ValueError, match="The 'nave' attribute is not specified"): + combine_spectrum([raw_spectrum, raw_spectrum], weights="nave") + + # Test inconsistent channels + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2.drop_channels(raw_spectrum2.ch_names[0]) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + # Test inconsistent frequencies + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2._freqs = raw_spectrum2._freqs + 1 + with pytest.raises(AssertionError, match=".* do not contain the same frequencies"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + def test_spectrum_reject_by_annot(raw): """Test rejecting by annotation. diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index e68ea9e6e18..6adb4e361e1 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -255,20 +255,25 @@ def test_tfr_morlet(): # computed within the method. 
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data) - # test that averaging power across tapers when multitaper with + # test that aggregating power across tapers when multitaper with # output='complex' gives the same as output='power' epoch_data = epochs.get_data() multitaper_power = tfr_array_multitaper( epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power" ) - multitaper_complex = tfr_array_multitaper( - epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex" + multitaper_complex, weights = tfr_array_multitaper( + epoch_data, + epochs.info["sfreq"], + freqs, + n_cycles, + output="complex", + return_weights=True, ) - taper_dim = 2 - power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean( - axis=taper_dim - ) + weights = np.expand_dims(weights, axis=(0, 1, -1)) # match shape of complex data + tfr = weights * multitaper_complex + tfr = (tfr * tfr.conj()).real.sum(axis=2) + power_from_complex = tfr * (2 / (weights * weights.conj()).real.sum(axis=2)) assert_allclose(power_from_complex, multitaper_power) print(itc) # test repr @@ -432,17 +437,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): @@ -664,6 +673,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): with tfr.info._unlock(): tfr.info["meas_date"] = want assert tfr_loaded == tfr + # test with taper dimension and weights + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs + state = tfr.__getstate__() + state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim + state["weights"] = weights # add weights + state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims + tfr = EpochsTFR(inst=state) + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr_loaded == tfr # test overwrite with pytest.raises(OSError, match="Destination file exists."): tfr.save(fname, overwrite=False) @@ -722,17 +742,31 @@ def test_average_tfr_init(full_evoked): AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) -def test_epochstfr_init_errors(epochs_tfr): - """Test __init__ for EpochsTFR.""" - state = epochs_tfr.__getstate__() - with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_init_errors(inst, request, average_tfr): + """Test __init__ for {Raw,Epochs,Average}TFR.""" + # Load data + inst = _get_inst(inst, request, average_tfr=average_tfr) + state = inst.__getstate__() + # Prepare for TFRArray object instantiation + inst_name = inst.__class__.__name__ + class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) + ndims_mapping = dict( + RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 
4D") + ) + TFR = class_mapping[inst_name] + allowed_ndims = ndims_mapping[inst_name] + # Check errors caught + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=inst.data[..., 0])) + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): - EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1])) + TFR(inst=state | dict(times=inst.times[:-1])) with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): - EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + TFR(inst=state | dict(freqs=inst.freqs[:-1])) @pytest.mark.parametrize( @@ -830,6 +864,25 @@ def test_plot(): plt.close("all") +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_plot_multitaper_complex_phase(output): + """Test TFR plotting of data with a taper dimension.""" + # Create example data with a taper dimension + n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) + data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) + if output == "complex": + data = data + np.random.rand(*data.shape) * 1j # add imaginary data + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, weights=weights + ) + # Check that plotting works + tfr.plot() + + @pytest.mark.parametrize( "timefreqs,title,combine", ( @@ -1154,6 +1207,15 @@ def test_averaging_epochsTFR(): ): power.average(method=np.mean) + # Check it doesn't run for taper spectra + tapered = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" + ) + with pytest.raises( + NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." + ): + tapered.average() + def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" @@ -1258,12 +1320,15 @@ def test_to_data_frame(): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) srate = 1000.0 - freqs = np.arange(5) + freqs = np.arange(n_freqs) + tapers = np.arange(n_tapers) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 5 + n_epos) @@ -1276,6 +1341,7 @@ def test_to_data_frame(): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) # test index checking with pytest.raises(ValueError, match="options. 
Valid index options are"): @@ -1287,10 +1353,21 @@ def test_to_data_frame(): # test wide format df_wide = tfr.to_data_frame() assert all(np.isin(tfr.ch_names, df_wide.columns)) - assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) + assert all( + np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) + ) # test long format df_long = tfr.to_data_frame(long_format=True) - expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") + expected = ( + "condition", + "epoch", + "freq", + "time", + "channel", + "ch_type", + "value", + "taper", + ) assert set(expected) == set(df_long.columns) assert set(tfr.ch_names) == set(df_long["channel"]) assert len(df_long) == tfr.data.size @@ -1298,21 +1375,29 @@ def test_to_data_frame(): df_long = tfr.to_data_frame(long_format=True, index=["freq"]) del df_wide, df_long # test whether data is in correct shape - df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) + df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) data = tfr.data assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) # compare arbitrary observation: assert ( - df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] - == data[1, 3, 1, 2] + df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] + == data[1, 3, 1, 1, 2] ) # Check also for AverageTFR: + # (remove taper dimension before averaging) + state = tfr.__getstate__() + state["data"] = state["data"][:, :, 0] + state["dims"] = ("epoch", "channel", "freq", "time") + state["weights"] = None + tfr = EpochsTFR(inst=state) tfr = tfr.average() with pytest.raises(ValueError, match="options. Valid index options are"): tfr.to_data_frame(index=["epoch", "condition"]) with pytest.raises(ValueError, match='"epoch" is not a valid option'): tfr.to_data_frame(index="epoch") + with pytest.raises(ValueError, match='"taper" is not a valid option'): + tfr.to_data_frame(index="taper") with pytest.raises(TypeError, match="index must be `None` or a string "): tfr.to_data_frame(index=np.arange(400)) # test wide format @@ -1348,11 +1433,13 @@ def test_to_data_frame_index(index): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - freqs = np.arange(5) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) @@ -1365,6 +1452,7 @@ def test_to_data_frame_index(index): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1372,7 +1460,7 @@ def test_to_data_frame_index(index): index = [index] assert list(df.index.names) == index # test that non-indexed data were present as columns - non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) + non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) if len(non_index): assert all(np.isin(non_index, df.columns)) @@ -1538,7 +1626,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): """Test 
Epochs.compute_tfr(output="complex"/"phase").""" tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) - assert len(tfr.shape) == 5 + assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time + assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match @pytest.mark.parametrize("copy", (False, True)) @@ -1550,6 +1639,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) +@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(obj_type): + """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" + # Create example data with a taper dimension + n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6) + data_shape = (n_chans, n_tapers, n_freqs, n_times) + if obj_type == "epochs": + data_shape = (n_epochs,) + data_shape + data = np.random.rand(*data_shape) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + # Prepare for TFRArray object instantiation + defaults = dict(info=info, data=data, times=times, freqs=freqs) + class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray) + TFRArray = class_mapping[obj_type] + # Check TFRArray instantiation runs with good data + TFRArray(**defaults, weights=weights) + # Check taper dimension but no weights caught + with pytest.raises( + ValueError, match="Taper dimension in data, but no weights found." + ): + TFRArray(**defaults) + # Check mismatching n_taper in weights caught + with pytest.raises( + ValueError, match=r"Taper axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:-1]) + # Check mismatching n_freq in weights caught + with pytest.raises( + ValueError, match=r"Frequency axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:, :-1]) + + def test_tfr_proj(epochs): """Test `compute_tfr(proj=True)`.""" epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) @@ -1731,3 +1856,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): assert re.match( rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() ) + + +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): + """Test plot_joint/topo/topomap() for data with a taper dimension.""" + # Compute TFR with taper dimension + tfr = evoked.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output + ) + # Check that plotting works + tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed + tfr.plot_topo() + tfr.plot_topomap() + + +def test_combine_tfr_error_catch(average_tfr): + """Test combine_tfr() catches errors.""" + # check unrecognised weights string caught + with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): + combine_tfr([average_tfr, average_tfr], weights="foo") + # check bad weights size caught + with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): + combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) + # check different channel names caught + state = average_tfr.__getstate__() + new_info = average_tfr.info.copy() + average_tfr_bad = AverageTFR( + inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) + ) + with pytest.raises(AssertionError, match=".* do not contain the same 
channels"): + combine_tfr([average_tfr, average_tfr_bad]) + # check different times caught + average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) + with pytest.raises( + AssertionError, match=".* do not contain the same time instants" + ): + combine_tfr([average_tfr, average_tfr_bad]) + # check taper dim caught + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs + state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) + state["weights"] = weights + state["dims"] = ("channel", "taper", "freq", "time") + average_tfr_taper = AverageTFR(inst=state) + with pytest.raises( + NotImplementedError, + match="Aggregating multitaper tapers across TFR datasets is not supported.", + ): + combine_tfr([average_tfr_taper, average_tfr_taper]) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 12d45d5d572..fc60802f61b 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -478,7 +486,9 @@ def _compute_tfr( * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. - + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +505,9 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. 
""" # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +529,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -529,15 +545,18 @@ def _compute_tfr( if method == "morlet": W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean) Ws = [W] # to have same dimensionality as the 'multitaper' case + weights = None # no tapers for Morlet estimates elif method == "multitaper": - Ws = _make_dpss( + Ws, weights = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=True, # required for converting complex → power ) + weights = np.asarray(weights) # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -560,7 +579,7 @@ def _compute_tfr( if ("avg_" in output) or ("itc" in output): out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": - out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -571,7 +590,7 @@ def _compute_tfr( # Parallelization is applied across channels. tfrs = parallel( - my_cwt(channel, Ws, output, use_fft, "same", decim, method) + my_cwt(channel, Ws, output, use_fft, "same", decim, weights) for channel in epoch_data.transpose(1, 0, 2) ) @@ -581,10 +600,10 @@ def _compute_tfr( if ("avg_" not in output) and ("itc" not in output): # This is to enforce that the first dimension is for epochs - if output in ["complex", "phase"] and method == "multitaper": - out = out.transpose(2, 0, 1, 3, 4) - else: - out = out.transpose(1, 0, 2, 3) + out = np.moveaxis(out, 1, 0) + + if return_weights: + return out, weights return out @@ -598,8 +617,7 @@ def _check_tfr_param( freqs = np.asarray(freqs, dtype=float) if freqs.ndim != 1: raise ValueError( - f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} " - "instead." + f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} instead." ) # Check sfreq @@ -658,7 +676,7 @@ def _check_tfr_param( return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim -def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): +def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None): """Aux. function to _compute_tfr. Loops time-frequency transform across wavelets and epochs. @@ -685,9 +703,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): See numpy.convolve. decim : slice The decimation slice: e.g. power[:, decim] - method : str | None - Used only for multitapering to create tapers dimension in the output - if ``output in ['complex', 'phase']``. + weights : array, shape (n_tapers, n_wavelets) | None + Concentration weights for each taper in the wavelets, if present. 
""" # Set output type dtype = np.float64 @@ -701,10 +718,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): n_freqs = len(Ws[0]) if ("avg_" in output) or ("itc" in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) - elif output in ["complex", "phase"] and method == "multitaper": - tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) + elif output in ["complex", "phase"] and weights is not None: + tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype) else: tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype) + if weights is not None: + weights = np.expand_dims(weights, axis=-1) # add singleton time dimension # Loops across tapers. for taper_idx, W in enumerate(Ws): @@ -719,6 +738,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Loop across epochs for epoch_idx, tfr in enumerate(coefs): # Transform complex values + if output not in ["complex", "phase"] and weights is not None: + tfr = weights[taper_idx] * tfr # weight each taper estimate if output in ["power", "avg_power"]: tfr = (tfr * tfr.conj()).real # power elif output == "phase": @@ -734,8 +755,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Stack or add if ("avg_" in output) or ("itc" in output): tfrs += tfr - elif output in ["complex", "phase"] and method == "multitaper": - tfrs[taper_idx, epoch_idx] += tfr + elif output in ["complex", "phase"] and weights is not None: + tfrs[epoch_idx, taper_idx] += tfr else: tfrs[epoch_idx] += tfr @@ -749,9 +770,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): if ("avg_" in output) or ("itc" in output): tfrs /= n_epochs - # Normalization by number of taper - if n_tapers > 1 and output not in ["complex", "phase"]: - tfrs /= n_tapers + # Normalization by taper weights + if n_tapers > 1 and output not in ["complex", "phase", "itc"]: + if "avg_" not in output: # add singleton epochs dimension to weights + weights = np.expand_dims(weights, axis=0) + tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3) + if output == "avg_power_itc": # weight itc by the number of tapers + tfrs.imag = tfrs.imag / n_tapers + return tfrs @@ -1184,8 +1210,8 @@ def __init__( classname = "EpochsTFR" # end TODO raise ValueError( - f'{classname} got unsupported parameter value{_pl(problem)} ' - f'{" and ".join(problem)}.' + f"{classname} got unsupported parameter value{_pl(problem)} " + f"{' and '.join(problem)}." 
) # check method valid_methods = ["morlet", "multitaper"] @@ -1200,6 +1226,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1221,6 +1250,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1379,6 +1409,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1389,7 +1420,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1407,12 +1437,13 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1465,18 +1496,29 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) + if n_taper is not None and self._weights is None: + raise ValueError("Taper dimension in data, but no weights found.") if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) + elif n_taper is not None and n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_taper is not None and n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) else: return raise ValueError(msg) @@ -1496,7 +1538,7 @@ def _check_values(self, negative_ok=False): s = _pl(negative_values.sum()) warn( f"Negative value in time-frequency decomposition for channel{s} " - f'{", ".join(chs)}', + f"{', '.join(chs)}", UserWarning, ) @@ -1513,6 +1555,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert 
np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1613,6 +1659,7 @@ def _onselect( fmax=fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) # average over times and freqs @@ -1691,6 +1738,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. @@ -1785,6 +1837,7 @@ def get_data( tmax=None, return_times=False, return_freqs=False, + return_tapers=False, ): """Get time-frequency data in NumPy array format. @@ -1800,6 +1853,10 @@ def get_data( return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. + return_tapers : bool + Whether to return the taper numbers. Default is ``False``. + + .. versionadded:: 1.10.0 Returns ------- @@ -1811,6 +1868,9 @@ def get_data( freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. + tapers : array | None + The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be + ``None`` if a taper dimension is not present in the data. Notes ----- @@ -1848,7 +1908,13 @@ def get_data( if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) - if not return_times and not return_freqs: + if return_tapers: + if "taper" in self._dims: + tapers = np.arange(self.shape[self._dims.index("taper")]) + else: + tapers = None + out.append(tapers) + if not return_times and not return_freqs and not return_tapers: return out[0] return tuple(out) @@ -1960,6 +2026,7 @@ def plot( baseline=baseline, mode=mode, dB=dB, + taper_weights=self.weights, verbose=verbose, ) # shape @@ -1970,6 +2037,9 @@ def plot( want_shape[ch_axis] = len(idx_picks) if combine is None else 1 want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = [ + n for dim, n in zip(self._dims, want_shape) if dim != "taper" + ] # tapers must be aggregated over by now want_shape = tuple(want_shape) # combine combine_was_none = combine is None @@ -2313,6 +2383,7 @@ def plot_joint( fmax=_fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) _data = _data.mean(axis=(-1, -2)) # avg over times and freqs @@ -2461,23 +2532,23 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks - # TODO this is the only remaining call to _preproc_tfr; should be refactored - # (to use _prep_data_for_plot?) 
-        data, times, freqs, vmin, vmax = _preproc_tfr(
+        # baseline, crop, convert complex to power, aggregate tapers, and dB scaling
+        data, times, freqs = _prep_data_for_plot(
             data,
             times,
             freqs,
-            tmin,
-            tmax,
-            fmin,
-            fmax,
-            mode,
-            baseline,
-            vmin,
-            vmax,
-            dB,
-            info["sfreq"],
+            tmin=tmin,
+            tmax=tmax,
+            fmin=fmin,
+            fmax=fmax,
+            baseline=baseline,
+            mode=mode,
+            dB=dB,
+            taper_weights=self.weights,
+            verbose=verbose,
         )
+        # get vlims
+        vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)

         if layout is None:
             from mne import find_layout
@@ -2624,21 +2695,21 @@ def to_data_frame(
     ):
         """Export data in tabular structure as a pandas DataFrame.

-        Channels are converted to columns in the DataFrame. By default,
-        additional columns ``'time'``, ``'freq'``, ``'epoch'``, and
-        ``'condition'`` (epoch event description) are added, unless ``index``
-        is not ``None`` (in which case the columns specified in ``index`` will
-        be used to form the DataFrame's index instead). ``'epoch'``, and
-        ``'condition'`` are not supported for ``AverageTFR``.
+        Channels are converted to columns in the DataFrame. By default, additional
+        columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'``
+        (epoch event description) are added, unless ``index`` is not ``None`` (in which
+        case the columns specified in ``index`` will be used to form the DataFrame's
+        index instead). ``'epoch'`` and ``'condition'`` are not supported for
+        ``AverageTFR``. ``'taper'`` is only supported when a taper dimension is
+        present, such as for complex or phase multitaper data.

         Parameters
         ----------
         %(picks_all)s
         %(index_df_epo)s
-            Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and
-            ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'``
-            for ``AverageTFR``.
-            Defaults to ``None``.
+            Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``,
+            and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and
+            ``'taper'`` for ``AverageTFR``. Defaults to ``None``.
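A quick illustration of what the new ``'taper'`` column means in practice. This is a sketch, not code from the patch: it assumes an existing `epochs` (an `mne.Epochs` instance), illustrative frequency values, and the complex multitaper path this patch enables:

    import numpy as np

    # sketch: complex multitaper TFR of pre-existing epochs (illustrative values)
    freqs = np.arange(8.0, 13.0)
    tfr = epochs.compute_tfr("multitaper", freqs=freqs, output="complex", average=False)
    df = tfr.to_data_frame()
    # with a taper dimension present, the frame gains a "taper" column next to
    # "condition", "epoch", "freq", and "time"
    print(df.columns)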
         %(long_format_df_epo)s
         %(time_format_df)s

@@ -2651,42 +2722,58 @@ def to_data_frame(
         """
         # check pandas once here, instead of in each private utils function
         pd = _check_pandas_installed()  # noqa
+        # triage for Epoch-derived or unaggregated spectra
+        from_epo = isinstance(self, EpochsTFR)
+        unagg_mt = "taper" in self._dims
         # arg checking
         valid_index_args = ["time", "freq"]
-        if isinstance(self, EpochsTFR):
+        if from_epo:
             valid_index_args.extend(["epoch", "condition"])
+        if unagg_mt:
+            valid_index_args.append("taper")
         valid_time_formats = ["ms", "timedelta"]
         index = _check_pandas_index_arguments(index, valid_index_args)
         time_format = _check_time_format(time_format, valid_time_formats)
         # get data
         picks = _picks_to_idx(self.info, picks, "all", exclude=())
-        data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
-        axis = self._dims.index("channel")
-        if not isinstance(self, EpochsTFR):
+        data, times, freqs, tapers = self.get_data(
+            picks, return_times=True, return_freqs=True, return_tapers=True
+        )
+        ch_axis = self._dims.index("channel")
+        if not from_epo:
             data = data[np.newaxis]  # add singleton "epochs" axis
-            axis += 1
-        n_epochs, n_picks, n_freqs, n_times = data.shape
-        # reshape to (epochs*freqs*times) x signals
-        data = np.moveaxis(data, axis, -1)
-        data = data.reshape(n_epochs * n_freqs * n_times, n_picks)
+            ch_axis += 1
+        if not unagg_mt:
+            data = np.expand_dims(data, -3)  # add singleton "tapers" axis
+        n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape
+        # reshape to (epochs*tapers*freqs*times) x signals
+        data = np.moveaxis(data, ch_axis, -1)
+        data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks)
         # prepare extra columns / multiindex
         mindex = list()
+        default_index = list()
         times = _convert_times(times, time_format, self.info["meas_date"])
-        times = np.tile(times, n_epochs * n_freqs)
-        freqs = np.tile(np.repeat(freqs, n_times), n_epochs)
+        times = np.tile(times, n_epochs * n_freqs * n_tapers)
+        freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers)
         mindex.append(("time", times))
         mindex.append(("freq", freqs))
-        if isinstance(self, EpochsTFR):
-            mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs)))
+        if from_epo:
+            mindex.append(
+                ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers))
+            )
             rev_event_id = {v: k for k, v in self.event_id.items()}
             conditions = [rev_event_id[k] for k in self.events[:, 2]]
-            mindex.append(("condition", np.repeat(conditions, n_times * n_freqs)))
+            mindex.append(
+                ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
+            )
+            default_index.extend(["condition", "epoch"])
+        if unagg_mt:
+            tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times)
+            mindex.append(("taper", tapers))
+            default_index.append("taper")
+        default_index.extend(["freq", "time"])
         assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
         # build DataFrame
-        if isinstance(self, EpochsTFR):
-            default_index = ["condition", "epoch", "freq", "time"]
-        else:
-            default_index = ["freq", "time"]
         df = _build_data_frame(
             self, data, picks, long_format, mindex, index, default_index=default_index
         )
@@ -2733,6 +2820,7 @@ class AverageTFR(BaseTFR):
     %(nave_tfr_attr)s
     %(sfreq_tfr_attr)s
     %(shape_tfr_attr)s
+    %(weights_tfr_attr)s

     See Also
     --------
@@ -2849,6 +2937,15 @@ def __getstate__(self):

     def __setstate__(self, state):
         """Unpack AverageTFR from serialized format."""
+        if state["data"].ndim not in [3, 4]:
+            raise ValueError(
+                f"AverageTFR data should be 3D or 4D, got {state['data'].ndim}."
+ ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2892,6 +2989,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2904,6 +3002,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2914,12 +3013,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -2962,6 +3071,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3041,8 +3151,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3152,7 +3269,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." 
+ ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3524,6 +3650,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3540,6 +3667,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3562,6 +3690,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3572,6 +3701,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3614,6 +3744,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3663,6 +3794,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3728,6 +3872,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3738,6 +3883,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3755,20 +3901,23 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) def combine_tfr(all_tfr, weights="nave"): """Merge AverageTFR data by weighted addition. - Create a new AverageTFR instance, using a combination of the supplied - instances as its data. By default, the mean (weighted by trials) is used. - Subtraction can be performed by passing negative weights (e.g., [1, -1]). - Data must have the same channels and the same time instants. + Create a new :class:`mne.time_frequency.AverageTFR` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., [1, -1]). Data + must have the same channels and the same time instants. Parameters ---------- @@ -3786,8 +3935,16 @@ def combine_tfr(all_tfr, weights="nave"): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ + if any("taper" in tfr._dims for tfr in all_tfr): + raise NotImplementedError( + "Aggregating multitaper tapers across TFR datasets is not supported." 
+ ) + tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): @@ -3803,10 +3960,10 @@ def combine_tfr(all_tfr, weights="nave"): ch_names = tfr.ch_names for t_ in all_tfr[1:]: - assert t_.ch_names == ch_names, ValueError( + assert t_.ch_names == ch_names, ( f"{tfr} and {t_} do not contain the same channels" ) - assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ValueError( + assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ( f"{tfr} and {t_} do not contain the same time instants" ) @@ -3861,62 +4018,6 @@ def _centered(arr, newsize): return arr[tuple(myslice)] -def _preproc_tfr( - data, - times, - freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - sfreq, - copy=None, -): - """Aux Function to prepare tfr computation.""" - if copy is None: - copy = baseline is not None - data = rescale(data, times, baseline, mode, copy=copy) - - if np.iscomplexobj(data): - # complex amplitude → real power (for plotting); if data are - # real-valued they should already be power - data = (data * data.conj()).real - - # crop time - itmin, itmax = None, None - idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0] - if tmin is not None: - itmin = idx[0] - if tmax is not None: - itmax = idx[-1] + 1 - - times = times[itmin:itmax] - - # crop freqs - ifmin, ifmax = None, None - idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0] - if fmin is not None: - ifmin = idx[0] - if fmax is not None: - ifmax = idx[-1] + 1 - - freqs = freqs[ifmin:ifmax] - - # crop data - data = data[:, ifmin:ifmax, itmin:itmax] - - if dB: - data = 10 * np.log10(data) - - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) - return data, times, freqs, vmin, vmax - - def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") @@ -4061,7 +4162,7 @@ def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): if len(out) == 0: raise ValueError( f'Cannot find condition "{condition}" in this file. ' - f'The file contains conditions {", ".join(keys)}' + f"The file contains conditions {', '.join(keys)}" ) if len(out) == 1: out = out[0] @@ -4151,6 +4252,7 @@ def _prep_data_for_plot( baseline=None, mode=None, dB=False, + taper_weights=None, verbose=None, ): # baseline @@ -4164,9 +4266,43 @@ def _prep_data_for_plot( freqs = freqs[freq_mask] # crop data data = data[..., freq_mask, :][..., time_mask] - # complex amplitude → real power; real-valued data is already power (or ITC) + # handle unaggregated multitaper (complex or phase multitaper data) + if taper_weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + if np.iscomplexobj(data): # complex coefficients → power + data = _tfr_from_mt(data, taper_weights) + else: # tapered phase data → weighted phase data + # channels, tapers, freqs, time + assert data.ndim == 4 + # weights as a function of (tapers, freqs) + assert taper_weights.ndim == 2 + data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = (data * data.conj()).real if dB: data = 10 * np.log10(data) return data, times, freqs + + +def _tfr_from_mt(x_mt, weights): + """Aggregate complex multitaper coefficients over tapers and convert to power. + + Parameters + ---------- + x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + The complex-valued multitaper coefficients. 
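As a reading aid for the arithmetic in `_tfr_from_mt`: with complex coefficients X_k(f, t) and real per-taper weights w_k(f) (notation mine, not from the patch), the returned power is

    P(f, t) = \frac{2 \sum_{k=1}^{K} \lvert w_k(f)\, X_k(f, t) \rvert^2}{\sum_{k=1}^{K} w_k(f)^2}

which matches the weighted sum over the taper axis and the final ``2 / (weights * weights.conj()).real.sum(...)`` rescaling in the function body.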
+ weights : array, shape (n_tapers, n_freqs) + The weights to use to combine the tapered estimates. + + Returns + ------- + tfr : array, shape (n_channels, n_freqs, n_times) + The time-frequency power estimates. + """ + weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + tfr = weights * x_mt + tfr *= tfr.conj() + tfr = tfr.real.sum(axis=1) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + return tfr diff --git a/mne/transforms.py b/mne/transforms.py index c85c31964b6..7072ea25124 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -12,14 +12,13 @@ import numpy as np from scipy import linalg from scipy.spatial.distance import cdist -from scipy.special import sph_harm from ._fiff.constants import FIFF from ._fiff.open import fiff_open from ._fiff.tag import read_tag from ._fiff.write import start_and_end_file, write_coord_trans from .defaults import _handle_default -from .fixes import _get_img_fdata, jit +from .fixes import _get_img_fdata, jit, sph_harm_y from .utils import ( _check_fname, _check_option, @@ -926,7 +925,7 @@ def _compute_sph_harm(order, az, pol): # _deg_ord_idx(0, 0) = -1 so we're actually okay to use it here for degree in range(order + 1): for order_ in range(degree + 1): - sph = sph_harm(order_, degree, az, pol) + sph = sph_harm_y(degree, order_, pol, az) out[:, _deg_ord_idx(degree, order_)] = _sh_complex_to_real(sph, order_) if order_ > 0: out[:, _deg_ord_idx(degree, -order_)] = _sh_complex_to_real( diff --git a/mne/utils/_logging.py b/mne/utils/_logging.py index 68963feaf61..f4d19655bbf 100644 --- a/mne/utils/_logging.py +++ b/mne/utils/_logging.py @@ -511,7 +511,7 @@ def _frame_info(n): except KeyError: # in our verbose dec pass else: - infos.append(f'{name.lstrip("mne.")}:{frame.f_lineno}') + infos.append(f"{name.lstrip('mne.')}:{frame.f_lineno}") frame = frame.f_back if frame is None: break diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 323b530a641..63e0d1036b9 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -179,9 +179,9 @@ def assert_and_remove_boundary_annot(annotations, n=1): annotations.delete(idx) -def assert_object_equal(a, b, *, err_msg="Object mismatch"): +def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False): """Assert two objects are equal.""" - d = object_diff(a, b) + d = object_diff(a, b, allclose=allclose) assert d == "", f"{err_msg}\n{d}" diff --git a/mne/utils/check.py b/mne/utils/check.py index dc337574c3f..d318cd03a04 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -317,8 +317,7 @@ def _check_subject( _validate_type(second, "str", "subject input") if first is not None and first != second: raise ValueError( - f"{first_kind} ({repr(first)}) did not match " - f"{second_kind} ({second})" + f"{first_kind} ({repr(first)}) did not match {second_kind} ({second})" ) return second elif first is not None: @@ -1076,8 +1075,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): del ch_pos["FPz"] elif "Fpz" not in ch_pos and "Oz" in ch_pos: logger.info( - "Approximating Fpz location by mirroring Oz along " - "the X and Y axes." + "Approximating Fpz location by mirroring Oz along the X and Y axes." 
) # This assumes Fpz and Oz have the same Z coordinate ch_pos["Fpz"] = ch_pos["Oz"] * [-1, -1, 1] @@ -1087,7 +1085,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): msg = ( f'sphere="eeglab" requires digitization points of ' f"the following electrode locations in the data: " - f'{", ".join(horizon_ch_names)}, but could not find: ' + f"{', '.join(horizon_ch_names)}, but could not find: " f"{ch_name}" ) if ch_name == "Fpz": @@ -1268,8 +1266,7 @@ def _to_rgb(*args, name="color", alpha=False): except ValueError: args = args[0] if len(args) == 1 else args raise ValueError( - f'Invalid RGB{"A" if alpha else ""} argument(s) for {name}: ' - f"{repr(args)}" + f"Invalid RGB{'A' if alpha else ''} argument(s) for {name}: {repr(args)}" ) from None @@ -1293,5 +1290,5 @@ def _check_method_kwargs(func, kwargs, msg=None): if msg is None: msg = f'function "{func}"' raise TypeError( - f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} for {msg}.' + f"Got unexpected keyword argument{s} {', '.join(invalid_kw)} for {msg}." ) diff --git a/mne/utils/config.py b/mne/utils/config.py index a817886c3f0..c28373fcb93 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -185,8 +185,7 @@ def set_memmap_min_size(memmap_min_size): "triggers automated memory mapping, e.g., 1M or 0.5G" ), "MNE_REPR_HTML": ( - "bool, represent some of our objects with rich HTML in a notebook " - "environment" + "bool, represent some of our objects with rich HTML in a notebook environment" ), "MNE_SKIP_NETWORK_TESTS": ( "bool, used in a test decorator (@requires_good_network) to skip " @@ -203,8 +202,7 @@ def set_memmap_min_size(memmap_min_size): ), "MNE_USE_CUDA": "bool, use GPU for filtering/resampling", "MNE_USE_NUMBA": ( - "bool, use Numba just-in-time compiler for some of our intensive " - "computations" + "bool, use Numba just-in-time compiler for some of our intensive computations" ), "SUBJECTS_DIR": "path-like, directory of freesurfer MRI files for each subject", } @@ -583,9 +581,9 @@ def _get_numpy_libs(): for pool in pools: if pool["internal_api"] in ("openblas", "mkl"): return ( - f'{rename[pool["internal_api"]]} ' - f'{pool["version"]} with ' - f'{pool["num_threads"]} thread{_pl(pool["num_threads"])}' + f"{rename[pool['internal_api']]} " + f"{pool['version']} with " + f"{pool['num_threads']} thread{_pl(pool['num_threads'])}" ) return bad_lib @@ -874,7 +872,7 @@ def sys_info( pre = "│ " else: pre = " | " - out(f'\n{pre}{" " * ljust}{op.dirname(mod.__file__)}') + out(f"\n{pre}{' ' * ljust}{op.dirname(mod.__file__)}") out("\n") if not mne_version_good: diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 60e02432c7b..54cc6845e58 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_fmt_support_epochs"] = """\ Supported formats: - - EEGLAB (``.set``, uses :mod:`eeglabio`) + +- EEGLAB (``.set``, uses :mod:`eeglabio`) """ docdict["export_fmt_support_evoked"] = """\ Supported formats: - - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) + +- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ docdict["export_fmt_support_raw"] = """\ Supported formats: - - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `edfio `_) + +- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) +- EEGLAB (``.set``, uses :mod:`eeglabio`) +- EDF (``.edf``, uses `edfio `_) """ # noqa: E501 docdict["export_warning"] = 
"""\ @@ -4656,6 +4659,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The title of the generated figure. If ``None`` (default), no title is displayed. """ + +docdict["title_stc"] = """ +title : str | None + Title for the figure window. If ``None``, the subject name will be used. +""" + docdict["title_tfr_plot"] = """ title : str | 'auto' | None Title for the plot. If ``"auto"``, will use the channel name (if ``combine`` is @@ -5008,6 +5017,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ +docdict["weights_tfr_array"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. + + .. versionadded:: 1.10.0 +""" +docdict["weights_tfr_attr"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights used for each taper in the time-frequency estimates. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. diff --git a/mne/utils/misc.py b/mne/utils/misc.py index bb3e3ee5cab..343761aee24 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -379,7 +379,7 @@ def _assert_no_instances(cls, when=""): check = False if check: if cls.__name__ == "Brain": - ref.append(f'Brain._cleaned = {getattr(obj, "_cleaned", None)}') + ref.append(f"Brain._cleaned = {getattr(obj, '_cleaned', None)}") rr = gc.get_referrers(obj) count = 0 for r in rr: diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index c287fb42305..5029e8fbeca 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -515,55 +515,65 @@ def _freq_mask(freqs, sfreq, fmin=None, fmax=None, raise_error=True): def grand_average(all_inst, interpolate_bads=True, drop_bads=True): - """Make grand average of a list of Evoked or AverageTFR data. + """Make grand average of a list of Evoked, AverageTFR, or Spectrum data. - For :class:`mne.Evoked` data, the function interpolates bad channels based - on the ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, - the grand average file will contain good channels and the bad channels - interpolated from the good MEG/EEG channels. - For :class:`mne.time_frequency.AverageTFR` data, the function takes the - subset of channels not marked as bad in any of the instances. + For :class:`mne.Evoked` data, the function interpolates bad channels based on the + ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, the grand average + file will contain good channels and the bad channels interpolated from the good + MEG/EEG channels. + For :class:`mne.time_frequency.AverageTFR` and :class:`mne.time_frequency.Spectrum` + data, the function takes the subset of channels not marked as bad in any of the + instances. - The ``grand_average.nave`` attribute will be equal to the number - of evoked datasets used to calculate the grand average. + The ``grand_average.nave`` attribute will be equal to the number of datasets used to + calculate the grand average. - .. note:: A grand average evoked should not be used for source - localization. + .. note:: A grand average evoked should not be used for source localization. Parameters ---------- - all_inst : list of Evoked or AverageTFR - The evoked datasets. + all_inst : list of Evoked, AverageTFR or Spectrum + The datasets. + + .. versionchanged:: 1.10.0 + Added support for :class:`~mne.time_frequency.Spectrum` objects. 
+ interpolate_bads : bool If True, bad MEG and EEG channels are interpolated. Ignored for - AverageTFR. + :class:`~mne.time_frequency.AverageTFR` and + :class:`~mne.time_frequency.Spectrum` data. drop_bads : bool - If True, drop all bad channels marked as bad in any data set. - If neither interpolate_bads nor drop_bads is True, in the output file, - every channel marked as bad in at least one of the input files will be - marked as bad, but no interpolation or dropping will be performed. + If True, drop all bad channels marked as bad in any data set. If neither + ``interpolate_bads`` nor ``drop_bads`` is `True`, in the output file, every + channel marked as bad in at least one of the input files will be marked as bad, + but no interpolation or dropping will be performed. Returns ------- - grand_average : Evoked | AverageTFR + grand_average : Evoked | AverageTFR | Spectrum The grand average data. Same type as input. Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ # check if all elements in the given list are evoked data from ..channels.channels import equalize_channels from ..evoked import Evoked - from ..time_frequency import AverageTFR + from ..time_frequency import AverageTFR, Spectrum if not all_inst: - raise ValueError("Please pass a list of Evoked or AverageTFR objects.") + raise ValueError( + "Please pass a list of Evoked, AverageTFR, or Spectrum objects." + ) elif len(all_inst) == 1: warn("Only a single dataset was passed to mne.grand_average().") inst_type = type(all_inst[0]) - _validate_type(all_inst[0], (Evoked, AverageTFR), "All elements") + _validate_type(all_inst[0], (Evoked, AverageTFR, Spectrum), "All elements") for inst in all_inst: _validate_type(inst, inst_type, "All elements", "of the same type") @@ -578,6 +588,8 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): for inst in all_inst ] from ..evoked import combine_evoked as combine + elif isinstance(all_inst[0], Spectrum): + from ..time_frequency.spectrum import combine_spectrum as combine else: # isinstance(all_inst[0], AverageTFR): from ..time_frequency.tfr import combine_tfr as combine @@ -588,9 +600,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): inst.drop_channels(bads) equalize_channels(all_inst, copy=False) - # make grand_average object using combine_[evoked/tfr] + # make grand_average object using combine_[evoked/tfr/spectrum] grand_average = combine(all_inst, weights="equal") - # change the grand_average.nave to the number of Evokeds + # change the grand_average.nave to the number of datasets grand_average.nave = len(all_inst) # change comment field grand_average.comment = f"Grand average (n = {grand_average.nave})" @@ -859,6 +871,9 @@ def fit_transform(self, X, y=None): return U + def fit(self, X): + self._fit(X) + def _fit(self, X): if self.n_components is None: n_components = min(X.shape) diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 73f0065a58d..9c558d32a51 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -2354,6 +2354,7 @@ def plot_source_estimates( transparent=True, alpha=1.0, time_viewer="auto", + *, subjects_dir=None, figure=None, views="auto", @@ -2463,8 +2464,7 @@ def plot_source_estimates( Defaults to 'oct6'. .. versionadded:: 0.15.0 - title : str | None - Title for the figure. If None, the subject name will be used. + %(title_stc)s .. 
versionadded:: 0.17.0 %(show_traces)s @@ -2543,6 +2543,7 @@ def plot_source_estimates( view_layout=view_layout, add_data_kwargs=add_data_kwargs, brain_kwargs=brain_kwargs, + title=title, **kwargs, ) @@ -2578,6 +2579,7 @@ def _plot_stc( view_layout, add_data_kwargs, brain_kwargs, + title, ): from ..source_estimate import _BaseVolSourceEstimate from .backends.renderer import _get_3d_backend, get_brain_class @@ -2620,7 +2622,9 @@ def _plot_stc( if overlay_alpha == 0: smoothing_steps = 1 # Disable smoothing to save time. - title = subject if len(hemis) > 1 else f"{subject} - {hemis[0]}" + sub_info = subject if len(hemis) > 1 else f"{subject} - {hemis[0]}" + title = title if title is not None else sub_info + kwargs = { "subject": subject, "hemi": hemi, @@ -3251,6 +3255,7 @@ def plot_vector_source_estimates( vector_alpha=1.0, scale_factor=None, time_viewer="auto", + *, subjects_dir=None, figure=None, views="lateral", @@ -3262,6 +3267,7 @@ def plot_vector_source_estimates( foreground=None, initial_time=None, time_unit="s", + title=None, show_traces="auto", src=None, volume_options=1.0, @@ -3339,6 +3345,9 @@ def plot_vector_source_estimates( time_unit : 's' | 'ms' Whether time is represented in seconds ("s", default) or milliseconds ("ms"). + %(title_stc)s + + .. versionadded:: 1.9 %(show_traces)s %(src_volume_options)s %(view_layout)s @@ -3385,6 +3394,7 @@ def plot_vector_source_estimates( cortex=cortex, foreground=foreground, size=size, + title=title, scale_factor=scale_factor, show_traces=show_traces, src=src, diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 247c0840858..778700c99a7 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -4072,28 +4072,28 @@ def _update_monotonic(lims, fmin, fmid, fmax): if fmin is not None: lims["fmin"] = fmin if lims["fmax"] < fmin: - logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmin}') + logger.debug(f" Bumping fmax = {lims['fmax']} to {fmin}") lims["fmax"] = fmin if lims["fmid"] < fmin: - logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmin}') + logger.debug(f" Bumping fmid = {lims['fmid']} to {fmin}") lims["fmid"] = fmin assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmid is not None: lims["fmid"] = fmid if lims["fmin"] > fmid: - logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmid}') + logger.debug(f" Bumping fmin = {lims['fmin']} to {fmid}") lims["fmin"] = fmid if lims["fmax"] < fmid: - logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmid}') + logger.debug(f" Bumping fmax = {lims['fmax']} to {fmid}") lims["fmax"] = fmid assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmax is not None: lims["fmax"] = fmax if lims["fmin"] > fmax: - logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmax}') + logger.debug(f" Bumping fmin = {lims['fmin']} to {fmax}") lims["fmin"] = fmax if lims["fmid"] > fmax: - logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmax}') + logger.debug(f" Bumping fmid = {lims['fmid']} to {fmax}") lims["fmid"] = fmax assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 2a1c943250b..5d092c21713 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -850,14 +850,6 @@ def tiny(tmp_path): def test_brain_screenshot(renderer_interactive_pyvistaqt, tmp_path, brain_gc): """Test time viewer screenshot.""" # This is broken on Conda + GHA for some reason - from qtpy import API_NAME - - if ( - os.getenv("CONDA_PREFIX", "") != "" - and os.getenv("GITHUB_ACTIONS", "") == 
"true" - or API_NAME.lower() == "pyside6" - ): - pytest.skip("Test is unreliable on GitHub Actions conda runs and pyside6") tiny_brain, ratio = tiny(tmp_path) img_nv = tiny_brain.screenshot(time_viewer=False) want = (_TINY_SIZE[1] * ratio, _TINY_SIZE[0] * ratio, 3) @@ -875,9 +867,9 @@ def _assert_brain_range(brain, rng): for key, mesh in layerer._overlays.items(): if key == "curv": continue - assert ( - mesh._rng == rng - ), f"_layered_meshes[{repr(hemi)}][{repr(key)}]._rng != {rng}" + assert mesh._rng == rng, ( + f"_layered_meshes[{repr(hemi)}][{repr(key)}]._rng != {rng}" + ) @testing.requires_testing_data @@ -1245,9 +1237,9 @@ def test_brain_scraper(renderer_interactive_pyvistaqt, brain_gc, tmp_path): w = img.shape[1] w0 = size[0] # On Linux+conda we get a width of 624, similar tweak in test_brain_init above - assert np.isclose(w, w0, atol=30) or np.isclose( - w, w0 * 2, atol=30 - ), f"w ∉ {{{w0}, {2 * w0}}}" # HiDPI + assert np.isclose(w, w0, atol=30) or np.isclose(w, w0 * 2, atol=30), ( + f"w ∉ {{{w0}, {2 * w0}}}" + ) # HiDPI @testing.requires_testing_data diff --git a/mne/viz/_proj.py b/mne/viz/_proj.py index 5d21afb0594..6e0cb9a4143 100644 --- a/mne/viz/_proj.py +++ b/mne/viz/_proj.py @@ -90,8 +90,7 @@ def plot_projs_joint( missing = (~used.astype(bool)).sum() if missing: warn( - f"{missing} projector{_pl(missing)} had no channel names " - "present in epochs" + f"{missing} projector{_pl(missing)} had no channel names present in epochs" ) del projs ch_types = list(proj_by_type) # reduce to number we actually need diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index c415d83e456..467f5cb15e7 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -317,8 +317,7 @@ def _qt_get_stylesheet(theme): file = open(theme) except OSError: warn( - "Requested theme file not found, will use light instead: " - f"{repr(theme)}" + f"Requested theme file not found, will use light instead: {repr(theme)}" ) else: with file as fid: diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 10ec5459e02..b047de4ea32 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -27,6 +27,7 @@ _clean_names, _is_numeric, _pl, + _time_mask, _to_rgb, _validate_type, fill_doc, @@ -1988,10 +1989,18 @@ def plot_evoked_joint( contours = topomap_args.get("contours", 6) ch_type = ch_types.pop() # set should only contain one element # Since the data has all the ch_types, we get the limits from the plot. 
- vmin, vmax = ts_ax.get_ylim() + vmin, vmax = (None, None) norm = ch_type == "grad" vmin = 0 if norm else vmin - vmin, vmax = _setup_vmin_vmax(evoked.data, vmin, vmax, norm) + time_idx = [ + np.where( + _time_mask(evoked.times, tmin=t, tmax=None, sfreq=evoked.info["sfreq"]) + )[0][0] + for t in times_sec + ] + scalings = topomap_args["scalings"] if "scalings" in topomap_args else None + scaling = _handle_default("scalings", scalings)[ch_type] + vmin, vmax = _setup_vmin_vmax(evoked.data[:, time_idx] * scaling, vmin, vmax, norm) if not isinstance(contours, list | np.ndarray): locator, contours = _set_contour_locator(vmin, vmax, contours) else: diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index 2a93febba4e..cf5a9996216 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -380,45 +380,36 @@ def _configure_dock(self): # Fieldline configuration layout = r._dock_add_group_box("Fieldlines") - if self._show_density: - r._dock_add_label(value="max value", align=True, layout=layout) - - @_auto_weakref - def _callback(vmax, kind, scaling): - self.set_vmax(vmax / scaling, kind=kind) - - for surf_map in self._surf_maps: - if surf_map["map_kind"] == "meg": - scaling = DEFAULTS["scalings"]["grad"] - else: - scaling = DEFAULTS["scalings"]["eeg"] - rng = [0, np.max(np.abs(surf_map["data"])) * scaling] - hlayout = r._dock_add_layout(vertical=False) - - self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = ( - r._dock_add_slider( - name=surf_map["map_kind"].upper(), - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, kind=surf_map["map_kind"], scaling=scaling - ), - double=True, - layout=hlayout, - ) - ) - self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = ( - r._dock_add_spin_box( - name="", - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, kind=surf_map["map_kind"], scaling=scaling - ), - layout=hlayout, - ) - ) - r._layout_add_widget(layout, hlayout) + r._dock_add_label(value="max value", align=True, layout=layout) + + @_auto_weakref + def _callback(vmax, kind, scaling): + self.set_vmax(vmax / scaling, kind=kind) + + for surf_map in self._surf_maps: + if surf_map["map_kind"] == "meg": + scaling = DEFAULTS["scalings"]["grad"] + else: + scaling = DEFAULTS["scalings"]["eeg"] + rng = [0, np.max(np.abs(surf_map["data"])) * scaling] + hlayout = r._dock_add_layout(vertical=False) + + self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = r._dock_add_slider( + name=surf_map["map_kind"].upper(), + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial(_callback, kind=surf_map["map_kind"], scaling=scaling), + double=True, + layout=hlayout, + ) + self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = r._dock_add_spin_box( + name="", + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial(_callback, kind=surf_map["map_kind"], scaling=scaling), + layout=hlayout, + ) + r._layout_add_widget(layout, hlayout) hlayout = r._dock_add_layout(vertical=False) r._dock_add_label( diff --git a/mne/viz/misc.py b/mne/viz/misc.py index ed2636d3961..c83a4dfe717 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -443,7 +443,7 @@ def _plot_mri_contours( if src[0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise ValueError( "Source space must be in MRI coordinates, got " - f'{_frame_to_str[src[0]["coord_frame"]]}' + f"{_frame_to_str[src[0]['coord_frame']]}" ) for src_ in src: points = src_["rr"][src_["inuse"].astype(bool)] @@ -708,8 +708,7 @@ def plot_bem( src = read_source_spaces(src) elif src is not None and 
not isinstance(src, SourceSpaces): raise TypeError( - "src needs to be None, path-like or SourceSpaces instance, " - f"not {repr(src)}" + f"src needs to be None, path-like or SourceSpaces instance, not {repr(src)}" ) if len(surfaces) == 0: diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 6f109b9490b..34022d59768 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -893,7 +893,7 @@ def test_plot_alignment_fnirs(renderer, tmp_path): with catch_logging() as log: fig = plot_alignment(info, **kwargs) log = log.getvalue() - assert f'fnirs_cw_amplitude: {info["nchan"]}' in log + assert f"fnirs_cw_amplitude: {info['nchan']}" in log _assert_n_actors(fig, renderer, info["nchan"]) fig = plot_alignment(info, fnirs=["channels", "sources", "detectors"], **kwargs) diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index 035008dd87f..964acae2b31 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -339,8 +339,10 @@ def test_plot_evoked_image(): ch_names = evoked.ch_names[3:5] picks = [evoked.ch_names.index(ch) for ch in ch_names] - evoked.plot_image(show_names="all", time_unit="s", picks=picks) - yticklabels = plt.gca().get_yticklabels() + fig = evoked.plot_image(show_names="all", time_unit="s", picks=picks) + fig.canvas.draw_idle() + yticklabels = fig.axes[0].get_yticklabels() + assert len(yticklabels) == len(ch_names) for tick_target, tick_observed in zip(ch_names, yticklabels): assert tick_target in str(tick_observed) evoked.plot_image(show_names=True, time_unit="s") diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index f6b16fe27ad..89e0a7c543d 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -862,16 +862,14 @@ def test_remove_annotations(raw, hide_which, browser_backend): assert len(raw.annotations) == len(hide_which) -def test_merge_annotations(raw, browser_backend): +def test_merge_annotations(raw, pg_backend): """Test merging of annotations in the Qt backend. Let's not bother in figuring out on which sample the _fake_click actually dropped the annotation, especially with the 600.614 Hz weird sampling rate. 
-> atol = 10 / raw.info["sfreq"] """ - if browser_backend.name == "matplotlib": - pytest.skip("The MPL backend does not support draggable annotations.") - elif not check_version("mne_qt_browser", "0.5.3"): + if not check_version("mne_qt_browser", "0.5.3"): pytest.xfail("mne_qt_browser < 0.5.3 does not merge annotations properly") annot = Annotations( onset=[1, 3, 4, 5, 7, 8], diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index afa9341c00e..b87d0d39f89 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -44,7 +44,7 @@ compute_bridged_electrodes, compute_current_source_density, ) -from mne.time_frequency.tfr import AverageTFRArray +from mne.time_frequency.tfr import AverageTFR, AverageTFRArray from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap from mne.viz.tests.test_raw import _proj_status from mne.viz.topomap import ( @@ -610,6 +610,29 @@ def test_plot_tfr_topomap(): ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) + # test data with taper dimension (real) + data = np.expand_dims(data, axis=1) + weights = np.random.rand(1, n_freqs) + tfr = AverageTFRArray( + info=info, + data=data, + times=times, + freqs=np.arange(n_freqs), + nave=nave, + weights=weights, + ) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # test data with taper dimension (complex) + state = tfr.__getstate__() + tfr = AverageTFR(inst=state | dict(data=data * (1 + 1j))) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # remove taper dim before proceeding + data = data[:, 0] + # test real numbers tfr = AverageTFRArray( info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index dd63a626683..bb180a3f299 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -910,8 +910,7 @@ def _get_pos_outlines(info, picks, sphere, to_sphere=True): orig_sphere = sphere sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type) logger.debug( - "Generating pos outlines with sphere " - f"{sphere} from {orig_sphere} for {ch_type}" + f"Generating pos outlines with sphere {sphere} from {orig_sphere} for {ch_type}" ) pos = _find_topomap_coords( info, picks, ignore_overlap=True, to_sphere=to_sphere, sphere=sphere @@ -1262,7 +1261,7 @@ def _plot_topomap( if len(data) != len(pos): raise ValueError( "Data and pos need to be of same length. Got data of " - f"length {len(data)}, pos of length { len(pos)}" + f"length {len(data)}, pos of length {len(pos)}" ) norm = min(data) >= 0 @@ -1409,8 +1408,7 @@ def _plot_ica_topomap( sphere = _check_sphere(sphere, ica.info) if not isinstance(axes, Axes): raise ValueError( - "axis has to be an instance of matplotlib Axes, " - f"got {type(axes)} instead." + f"axis has to be an instance of matplotlib Axes, got {type(axes)} instead." 
) ch_type = _get_plot_ch_type(ica, ch_type, allow_ref_meg=ica.allow_ref_meg) if ch_type == "ref_meg": @@ -1882,7 +1880,7 @@ def plot_tfr_topomap( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - data = tfr.data[picks, :, :] + data = tfr.data[picks] # merging grads before rescaling makes ERDs visible if merge_channels: @@ -1890,6 +1888,18 @@ def plot_tfr_topomap( data = rescale(data, tfr.times, baseline, mode, copy=True) + # handle unaggregated multitaper (complex or phase multitaper data) + if tfr.weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims + data = weights * data + if np.iscomplexobj(data): # complex coefficients → power + data *= data.conj() + data = data.real.sum(axis=1) + data *= 2 / (weights * weights.conj()).real.sum(axis=1) + else: # tapered phase data → weighted phase data + data = data.mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = np.sqrt((data * data.conj()).real) @@ -2104,6 +2114,22 @@ def plot_evoked_topomap( :ref:`gridspec ` interface to adjust the colorbar size yourself. + The defaults for ``contours`` and ``vlim`` are handled as follows: + + * When neither ``vlim`` nor a list of ``contours`` is passed, MNE sets + ``vlim`` at ± the maximum absolute value of the data and then chooses + contours within those bounds. + + * When ``vlim`` but not a list of ``contours`` is passed, MNE chooses + contours to be within the ``vlim``. + + * When a list of ``contours`` but not ``vlim`` is passed, MNE chooses + ``vlim`` to encompass the ``contours`` and the maximum absolute value of the + data. + + * When both a list of ``contours`` and ``vlim`` are passed, MNE uses them + as-is. + When ``time=="interactive"``, the figure will publish and subscribe to the following UI events: @@ -2179,8 +2205,7 @@ def plot_evoked_topomap( space = 1 / (2.0 * evoked.info["sfreq"]) if max(times) > max(evoked.times) + space or min(times) < min(evoked.times) - space: raise ValueError( - f"Times should be between {evoked.times[0]:0.3} and " - f"{evoked.times[-1]:0.3}." + f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}." ) # create axes want_axes = n_times + int(colorbar) @@ -2287,11 +2312,17 @@ def plot_evoked_topomap( _vlim = [ _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times) ] - _vlim = (np.min(_vlim), np.max(_vlim)) + _vlim = [np.min(_vlim), np.max(_vlim)] cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0) # set up contours if not isinstance(contours, list | np.ndarray): _, contours = _set_contour_locator(*_vlim, contours) + else: + if vlim[0] is None and np.any(contours < _vlim[0]): + _vlim[0] = contours[0] + if vlim[1] is None and np.any(contours > _vlim[1]): + _vlim[1] = contours[-1] + # prepare for main loop over times kwargs = dict( sensors=sensors, @@ -2779,8 +2810,7 @@ def plot_psds_topomap( # convert legacy list-of-tuple input to a dict bands = {band[-1]: band[:-1] for band in bands} logger.info( - "converting legacy list-of-tuples input to a dict for the " - "`bands` parameter" + "converting legacy list-of-tuples input to a dict for the `bands` parameter" ) # upconvert single freqs to band upper/lower edges as needed bin_spacing = np.diff(freqs)[0] @@ -3340,6 +3370,7 @@ def _set_contour_locator(vmin, vmax, contours): # correct number of bins is equal to contours + 1. 
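For intuition about the endpoint trimming added just below: `tick_values` can return levels at or beyond the color limits, where contour lines would be invisible or sit on the colorbar extremes. A sketch with made-up limits:

    from matplotlib import ticker

    # illustrative numbers only
    locator = ticker.MaxNLocator(nbins=7)
    levels = locator.tick_values(-2.3, 2.3)
    print(levels)        # may include endpoints outside (-2.3, 2.3), e.g. -3. and 3.
    print(levels[1:-1])  # keep interior contour levels only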
locator = ticker.MaxNLocator(nbins=contours + 1) contours = locator.tick_values(vmin, vmax) + contours = contours[1:-1] return locator, contours diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 00458bf3908..a09da17de7d 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -2356,7 +2356,7 @@ def _gfp(data): except KeyError: raise ValueError( f'"combine" must be None, a callable, or one of "{", ".join(valid)}"; ' - f'got {combine}' + f"got {combine}" ) return combine diff --git a/pyproject.toml b/pyproject.toml index 12ad873c3dd..fe3a258dffc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ doc = [ "mne-gui-addons", "neo", "numpydoc", + "openneuro-py", "psutil", "pydata_sphinx_theme >= 0.15.2", "pygments >= 2.13", @@ -270,7 +271,6 @@ addopts = """--durations=20 --doctest-modules -rfEXs --cov-report= --tb=short \ --ignore=mne/gui/_*.py --ignore=mne/icons --ignore=tools \ --ignore=mne/report/js_and_css \ --color=yes --capture=sys""" -junit_family = "xunit2" [tool.rstcheck] ignore_directives = [ diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh index 56ec04b490d..8880e6478fa 100755 --- a/tools/azure_dependencies.sh +++ b/tools/azure_dependencies.sh @@ -9,6 +9,7 @@ if [ "${TEST_MODE}" == "pip" ]; then elif [ "${TEST_MODE}" == "pip-pre" ]; then ${SCRIPT_DIR}/install_pre_requirements.sh python -m pip install $STD_ARGS --pre -e .[test_extra] + echo "##vso[task.setvariable variable=MNE_TEST_ALLOW_SKIP].*(Requires (spm|brainstorm) dataset|Requires MNE-C|CUDA not|Numba not| on Windows|MNE_FORCE_SERIAL|PySide6 causes segfaults).*" else echo "Unknown run type ${TEST_MODE}" exit 1 diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index 2ecc9718ab2..dd3216ebf06 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -11,6 +11,6 @@ python -m pip install --upgrade --progress-bar off \ alphaCSC autoreject bycycle conpy emd fooof meggie \ mne-ari mne-bids-pipeline mne-faster mne-features \ mne-icalabel mne-lsl mne-microstates mne-nirs mne-rsa \ - neurodsp neurokit2 niseq nitime openneuro-py pactools \ + neurodsp neurokit2 niseq nitime pactools \ plotly pycrostates pyprep pyriemann python-picard sesameeg \ sleepecg tensorpac yasa meegkit eeg_positions diff --git a/tools/dev/ensure_headers.py b/tools/dev/ensure_headers.py index b5b425b5900..a4095d82b42 100644 --- a/tools/dev/ensure_headers.py +++ b/tools/dev/ensure_headers.py @@ -156,15 +156,15 @@ def _ensure_copyright(lines, path): lines[insert] = COPYRIGHT_LINE else: lines.insert(insert, COPYRIGHT_LINE) - assert ( - lines.count(COPYRIGHT_LINE) == 1 - ), f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + assert lines.count(COPYRIGHT_LINE) == 1, ( + f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + ) def _ensure_blank(lines, path): - assert ( - lines.count(COPYRIGHT_LINE) == 1 - ), f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + assert lines.count(COPYRIGHT_LINE) == 1, ( + f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + ) insert = lines.index(COPYRIGHT_LINE) + 1 if lines[insert].strip(): # actually has content lines.insert(insert, "") diff --git a/tools/get_minimal_commands.sh b/tools/get_minimal_commands.sh index 4e28fdf9e7b..8190f331075 100755 --- a/tools/get_minimal_commands.sh +++ b/tools/get_minimal_commands.sh @@ -11,7 +11,7 @@ export MNE_ROOT="${PWD}/minimal_cmds" export PATH=${MNE_ROOT}/bin:$PATH if [ "${GITHUB_ACTIONS}" == "true" ]; then echo "Setting MNE_ROOT for GHA" - echo "MNE_ROOT=${MNE_ROOT}" >> $GITHUB_ENV; + echo "MNE_ROOT=${MNE_ROOT}" | tee -a 
$GITHUB_ENV; echo "${MNE_ROOT}/bin" >> $GITHUB_PATH; elif [ "${AZURE_CI}" == "true" ]; then echo "Setting MNE_ROOT for Azure" @@ -33,9 +33,9 @@ if [[ "${CI_OS_NAME}" != "macos"* ]]; then export NEUROMAG2FT_ROOT="${PWD}/minimal_cmds/bin" export FREESURFER_HOME="${MNE_ROOT}" if [ "${GITHUB_ACTIONS}" == "true" ]; then - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" >> "$GITHUB_ENV"; - echo "NEUROMAG2FT_ROOT=${NEUROMAG2FT_ROOT}" >> "$GITHUB_ENV"; - echo "FREESURFER_HOME=${FREESURFER_HOME}" >> "$GITHUB_ENV"; + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "$GITHUB_ENV"; + echo "NEUROMAG2FT_ROOT=${NEUROMAG2FT_ROOT}" | tee -a "$GITHUB_ENV"; + echo "FREESURFER_HOME=${FREESURFER_HOME}" | tee -a "$GITHUB_ENV"; fi; if [ "${AZURE_CI}" == "true" ]; then echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]${LD_LIBRARY_PATH}" @@ -57,7 +57,7 @@ else export DYLD_LIBRARY_PATH=${MNE_ROOT}/lib:$DYLD_LIBRARY_PATH if [ "${GITHUB_ACTIONS}" == "true" ]; then echo "Setting variables for GHA" - echo "DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}" >> "$GITHUB_ENV"; + echo "DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}" | tee -a "$GITHUB_ENV"; set -x wget https://github.com/XQuartz/XQuartz/releases/download/XQuartz-2.7.11/XQuartz-2.7.11.dmg sudo hdiutil attach XQuartz-2.7.11.dmg diff --git a/tools/get_testing_version.sh b/tools/get_testing_version.sh index 44ff28addb4..aaf703dbddd 100755 --- a/tools/get_testing_version.sh +++ b/tools/get_testing_version.sh @@ -6,7 +6,7 @@ TESTING_VERSION=`grep -o "testing=\"[0-9.]\+\"" mne/datasets/config.py | cut -d # This can be incremented to start fresh when the cache misbehaves, e.g.: # TESTING_VERSION=${TESTING_VERSION}-1 if [ ! -z $GITHUB_ENV ]; then - echo "TESTING_VERSION="$TESTING_VERSION >> $GITHUB_ENV + echo "TESTING_VERSION="$TESTING_VERSION | tee -a $GITHUB_ENV elif [ ! -z $AZURE_CI ]; then echo "##vso[task.setvariable variable=testing_version]$TESTING_VERSION" elif [ ! -z $CIRCLECI ]; then diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index 149f5a194da..d47d9070f8b 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -23,7 +23,7 @@ if [ ! 
-z "$CONDA_ENV" ]; then elif [[ "${MNE_CI_KIND}" == "pip" ]]; then # Only used for 3.13 at the moment, just get test deps plus a few extras # that we know are available - INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser" + INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml numba" INSTALL_KIND="test" else test "${MNE_CI_KIND}" == "pip-pre" diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh index 3f9322d14c3..9f424ae5f48 100755 --- a/tools/github_actions_env_vars.sh +++ b/tools/github_actions_env_vars.sh @@ -5,26 +5,30 @@ set -eo pipefail -x if [[ "$MNE_CI_KIND" == "pip"* ]]; then echo "Setting pip env vars for $MNE_CI_KIND" if [[ "$MNE_CI_KIND" == "pip-pre" ]]; then - echo "MNE_QT_BACKEND=PyQt6" >> $GITHUB_ENV + echo "MNE_QT_BACKEND=PyQt6" | tee -a $GITHUB_ENV # We should test an eager import somewhere, might as well be here - echo "EAGER_IMPORT=true" >> $GITHUB_ENV + echo "EAGER_IMPORT=true" | tee -a $GITHUB_ENV + # Make sure nothing unexpected is skipped + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|Numba not|PySide6 causes segfaults).*" | tee -a $GITHUB_ENV else - echo "MNE_QT_BACKEND=PySide6" >> $GITHUB_ENV + echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV fi else # conda-like echo "Setting conda env vars for $MNE_CI_KIND" if [[ "$MNE_CI_KIND" == "old" ]]; then - echo "CONDA_ENV=tools/environment_old.yml" >> $GITHUB_ENV - echo "MNE_IGNORE_WARNINGS_IN_TESTS=true" >> $GITHUB_ENV - echo "MNE_SKIP_NETWORK_TESTS=1" >> $GITHUB_ENV - echo "MNE_QT_BACKEND=PyQt5" >> $GITHUB_ENV + echo "CONDA_ENV=tools/environment_old.yml" | tee -a $GITHUB_ENV + echo "MNE_IGNORE_WARNINGS_IN_TESTS=true" | tee -a $GITHUB_ENV + echo "MNE_SKIP_NETWORK_TESTS=1" | tee -a $GITHUB_ENV + echo "MNE_QT_BACKEND=PyQt5" | tee -a $GITHUB_ENV elif [[ "$MNE_CI_KIND" == "minimal" ]]; then - echo "CONDA_ENV=tools/environment_minimal.yml" >> $GITHUB_ENV - echo "MNE_QT_BACKEND=PySide6" >> $GITHUB_ENV + echo "CONDA_ENV=tools/environment_minimal.yml" | tee -a $GITHUB_ENV + echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV else # conda, mamba (use warning level for completeness) - echo "CONDA_ENV=environment.yml" >> $GITHUB_ENV - echo "MNE_LOGGING_LEVEL=warning" >> $GITHUB_ENV - echo "MNE_QT_BACKEND=PySide6" >> $GITHUB_ENV + echo "CONDA_ENV=environment.yml" | tee -a $GITHUB_ENV + echo "MNE_LOGGING_LEVEL=warning" | tee -a $GITHUB_ENV + echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV + # TODO: Also need "|unreliable on GitHub Actions conda" on macOS, but omit for now to make sure the failure actually shows up + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults|Accelerate|Flakey verbose behavior).*" | tee -a $GITHUB_ENV fi fi set +x diff --git a/tools/github_actions_test.sh b/tools/github_actions_test.sh index 4cdd202223f..4fe8756bd50 100755 --- a/tools/github_actions_test.sh +++ b/tools/github_actions_test.sh @@ -13,19 +13,25 @@ else USE_DIRS="mne/" fi JUNIT_PATH="junit-results.xml" -if [[ ! -z "$CONDA_ENV" ]] && [[ "${RUNNER_OS}" != "Windows" ]]; then - JUNIT_PATH="$(pwd)/${JUNIT_PATH}" +if [[ ! -z "$CONDA_ENV" ]] && [[ "${RUNNER_OS}" != "Windows" ]] && [[ "${MNE_CI_KIND}" != "minimal" ]] && [[ "${MNE_CI_KIND}" != "old" ]]; then + PROJ_PATH="$(pwd)" + JUNIT_PATH="$PROJ_PATH/${JUNIT_PATH}" # Use the installed version after adding all (excluded) test files - cd .. 
+ cd ~ # so that "import mne" doesn't just import the checked-out data
INSTALL_PATH=$(python -c "import mne, pathlib; print(str(pathlib.Path(mne.__file__).parents[1]))")
- echo "Copying tests from $(pwd)/mne-python/mne/ to ${INSTALL_PATH}/mne/"
- echo "::group::rsync"
- rsync -a --partial --progress --prune-empty-dirs --exclude="*.pyc" --include="**/" --include="**/tests/*" --include="**/tests/data/**" --exclude="**" ./mne-python/mne/ ${INSTALL_PATH}/mne/
+ echo "Copying tests from ${PROJ_PATH}/mne-python/mne/ to ${INSTALL_PATH}/mne/"
+ echo "::group::rsync mne"
+ rsync -a --partial --progress --prune-empty-dirs --exclude="*.pyc" --include="**/" --include="**/tests/*" --include="**/tests/data/**" --exclude="**" ${PROJ_PATH}/mne/ ${INSTALL_PATH}/mne/
echo "::endgroup::"
+ echo "::group::rsync doc"
+ mkdir -p ${INSTALL_PATH}/doc/
+ rsync -a --partial --progress --prune-empty-dirs --include="**/" --include="**/api/*" --exclude="**" ${PROJ_PATH}/doc/ ${INSTALL_PATH}/doc/
+ test -f ${INSTALL_PATH}/doc/api/reading_raw_data.rst
cd $INSTALL_PATH
- echo "Executing from $(pwd)"
+ cp -av $PROJ_PATH/pyproject.toml .
+ echo "::endgroup::"
fi
set -x
-pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml --color=yes --junit-xml=$JUNIT_PATH -vv ${USE_DIRS}
-set +x
+pytest -m "${CONDITION}" --cov=mne --cov-report xml --color=yes --continue-on-collection-errors --junit-xml=$JUNIT_PATH -vv ${USE_DIRS}
+echo "Exited with code $?"
diff --git a/tools/hooks/update_environment_file.py b/tools/hooks/update_environment_file.py
index 8cac6193959..0b5380a16b5 100755
--- a/tools/hooks/update_environment_file.py
+++ b/tools/hooks/update_environment_file.py
@@ -4,14 +4,14 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

-import difflib
import re
from pathlib import Path

import tomllib

repo_root = Path(__file__).resolve().parents[2]
-pyproj = tomllib.loads((repo_root / "pyproject.toml").read_text("utf-8"))
+with open(repo_root / "pyproject.toml", "rb") as fid:
+    pyproj = tomllib.load(fid)

# Get our "full" dependencies from `pyproject.toml`, but actually ignore the
# "full" section as it's just "full-noqt" plus PyQt6, and for conda we need PySide
@@ -22,7 +22,7 @@
deps |= set(section_deps)
recursive_deps = set(d for d in deps if d.startswith("mne["))
deps -= recursive_deps
-deps |= {"pip"}
+deps |= {"pip", "mamba", "nomkl"}


def remove_spaces(version_spec):
@@ -48,11 +48,6 @@ def split_dep(dep):
translations = dict(neo="python-neo")
pip_deps = set()
conda_deps = set()
-check_old = (
-    "numpy scipy matplotlib pandas scikit-learn nibabel tqdm pooch decorator "
-    "packaging jinja2 lazy_loader"
-).split()
-old_deps = [None] * len(check_old)
for dep in deps:
    package_name, version_spec = split_dep(dep)
    # handle package name differences
@@ -61,6 +56,9 @@ def split_dep(dep):
    # `environment.yaml` breaks the solver
    if package_name == "PySide6":
        version_spec = version_spec.replace("!=6.7.0,", "")
+    elif package_name == "vtk":
+        # TODO VERSION remove once we support VTK 9.4
+        version_spec = "=9.3.1=qt_*"
    # rstrip output line in case `version_spec` == ""
    line = f" - {package_name} {version_spec}".rstrip()
    # use pip for packages needing e.g.
`platform_system` or `python_version` triaging @@ -68,12 +66,6 @@ def split_dep(dep): pip_deps.add(f" {line}") else: conda_deps.add(line) - # old deps - if package_name in check_old: - # Pull out >= part, change to =, remove < (which should be after comma) - old_deps[check_old.index(package_name)] = line.replace(">=", "=").split(",")[0] -for di, dep in enumerate(old_deps): - assert dep is not None, f"Missing {check_old[di]}" # TODO: temporary workaround while we wait for a release containing the fix for # https://github.com/mamba-org/mamba/issues/3467 @@ -87,33 +79,14 @@ def split_dep(dep): """ pip_section = pip_section if len(pip_deps) else "" # prepare the env file -header = f"""\ -# THIS FILE IS AUTO-GENERATED BY {'/'.join(Path(__file__).parts[-3:])} AND WILL BE OVERWRITTEN +env = f"""\ +# THIS FILE IS AUTO-GENERATED BY {"/".join(Path(__file__).parts[-3:])} AND WILL BE OVERWRITTEN name: mne channels: - conda-forge -dependencies:""" # noqa: E501 -env = f"""{header} +dependencies: - python {req_python} {newline.join(sorted(conda_deps, key=str.casefold))} -{pip_section}""" - -env_file = repo_root / "environment.yml" -old_env = env_file.read_text("utf-8") -if old_env != env: - diff = "\n".join(difflib.unified_diff(old_env.splitlines(), env.splitlines())) - print(f"Updating {env_file} with diff:\n{diff}") - env_file.write_text(env, encoding="utf-8") - -# Now we also updated tools/environment_old.yml -env_file = repo_root / "tools" / "environment_old.yml" -old_env = env_file.read_text("utf-8") -use_python = req_python.replace(">=", "=") -env = f"""{header} - - python {use_python} -{newline.join(old_deps)} -""" -if old_env != env: - diff = "\n".join(difflib.unified_diff(old_env.splitlines(), env.splitlines())) - print(f"Updating {env_file} with diff:\n{diff}") - env_file.write_text(env, encoding="utf-8") +{pip_section}""" # noqa: E501 + +(repo_root / "environment.yml").write_text(env) diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index 74868b0a435..c717b1b477b 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -18,10 +18,11 @@ python -m pip install $STD_ARGS pip setuptools packaging \ py-cpuinfo blosc2 hatchling echo "NumPy/SciPy/pandas etc." 
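# The stable NumPy wheel is uninstalled first, presumably so that the nightly
# build pulled from the scientific-python-nightly-wheels index below is the
# only copy left on the path.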
python -m pip uninstall -yq numpy +python -m pip install --upgrade matplotlib # TODO: Until https://github.com/matplotlib/matplotlib/pull/29427 lands python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \ --index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \ "numpy>=2.1.0.dev0" "scikit-learn>=1.6.dev0" "scipy>=1.15.0.dev0" \ - "pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \ + "pandas>=3.0.0.dev0" \ "h5py>=3.12.1" "dipy>=1.10.0.dev0" "pyarrow>=19.0.0.dev0" "tables>=3.10.2.dev0" # statsmodels requires formulaic@main so we need to use --extra-index-url @@ -48,7 +49,7 @@ python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https:/ python -c "import vtk" echo "PyVista" -python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista" +python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista" trame trame-vtk trame-vuetify jupyter ipyevents ipympl echo "picard" python -m pip install $STD_ARGS git+https://github.com/pierreablin/picard @@ -57,7 +58,7 @@ echo "pyvistaqt" pip install $STD_ARGS git+https://github.com/pyvista/pyvistaqt echo "imageio-ffmpeg, xlrd, mffpy" -pip install $STD_ARGS imageio-ffmpeg xlrd mffpy traitlets pybv eeglabio +pip install $STD_ARGS imageio-ffmpeg xlrd mffpy traitlets pybv eeglabio defusedxml antio echo "mne-qt-browser" pip install $STD_ARGS git+https://github.com/mne-tools/mne-qt-browser @@ -76,13 +77,11 @@ echo "edfio" # https://github.com/mne-tools/mne-python/pull/12609#issuecomment-2115639369 GIT_CLONE_PROTECTION_ACTIVE=false pip install $STD_ARGS git+https://github.com/the-siesta-group/edfio -if [[ "${PLATFORM}" == "Linux" ]]; then - echo "h5io" - pip install $STD_ARGS git+https://github.com/h5io/h5io +echo "h5io" +pip install $STD_ARGS git+https://github.com/h5io/h5io - echo "pysnirf2" - pip install $STD_ARGS git+https://github.com/BUNPC/pysnirf2 -fi +echo "pysnirf2" +pip install $STD_ARGS git+https://github.com/BUNPC/pysnirf2 # Make sure we're on a NumPy 2.0 variant echo "Checking NumPy version" diff --git a/tutorials/forward/20_source_alignment.py b/tutorials/forward/20_source_alignment.py index dd26f610907..c8cf981dce9 100644 --- a/tutorials/forward/20_source_alignment.py +++ b/tutorials/forward/20_source_alignment.py @@ -115,11 +115,11 @@ mne.viz.set_3d_view(fig, 45, 90, distance=0.6, focalpoint=(0.0, 0.0, 0.0)) print( "Distance from head origin to MEG origin: " - f"{1000 * np.linalg.norm(raw.info["dev_head_t"]["trans"][:3, 3]):.1f} mm" + f"{1000 * np.linalg.norm(raw.info['dev_head_t']['trans'][:3, 3]):.1f} mm" ) print( "Distance from head origin to MRI origin: " - f"{1000 * np.linalg.norm(trans["trans"][:3, 3]):.1f} mm" + f"{1000 * np.linalg.norm(trans['trans'][:3, 3]):.1f} mm" ) dists = mne.dig_mri_distances(raw.info, trans, "sample", subjects_dir=subjects_dir) print( diff --git a/tutorials/forward/30_forward.py b/tutorials/forward/30_forward.py index 6c55d0bfe3c..72731982962 100644 --- a/tutorials/forward/30_forward.py +++ b/tutorials/forward/30_forward.py @@ -255,7 +255,7 @@ # or ``inv['src']`` so that this removal is adequately accounted for. 
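# The vertex counts printed below differ: source points discarded during
# forward computation (e.g., those lying too close to the inner skull
# surface) remain in ``src`` but are absent from ``fwd['src']``.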
print(f"Before: {src}") -print(f'After: {fwd["src"]}') +print(f"After: {fwd['src']}") # %% # We can explore the content of ``fwd`` to access the numpy array that contains diff --git a/tutorials/intro/15_inplace.py b/tutorials/intro/15_inplace.py index 0c68843d4c8..01e8c1f7eb0 100644 --- a/tutorials/intro/15_inplace.py +++ b/tutorials/intro/15_inplace.py @@ -60,9 +60,9 @@ # Another group of methods where data is modified in-place are the # channel-picking methods. For example: -print(f'original data had {original_raw.info["nchan"]} channels.') +print(f"original data had {original_raw.info['nchan']} channels.") original_raw.pick("eeg") # selects only the EEG channels -print(f'after picking, it has {original_raw.info["nchan"]} channels.') +print(f"after picking, it has {original_raw.info['nchan']} channels.") # %% diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 5eeb7b79d64..257b1f85051 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -291,8 +291,7 @@ # This time, print as percentage. ratio_percent = round(100 * explained_var_ratio["eeg"]) print( - f"Fraction of variance in EEG signal explained by first component: " - f"{ratio_percent}%" + f"Fraction of variance in EEG signal explained by first component: {ratio_percent}%" ) # %% diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 57be25803d5..28dee357f9a 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -390,6 +390,13 @@ # # See the documentation of each function for further details. # +# .. note:: +# In situations only limited electrodes are available for analysis, removing the +# cardiac artefact using techniques which rely on the availability of spatial +# information (such as SSP) may not be possible. In these instances, it may be of +# use to consider algorithms which require information only regarding heartbeat +# instances in the time domain, such as :func:`mne.preprocessing.apply_pca_obs`. +# # # Repairing EOG artifacts with SSP # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -520,7 +527,7 @@ evoked_eeg.plot(proj=proj, axes=ax, spatial_colors=True) parts = ax.get_title().split("(") ylabel = ( - f'{parts[0]} ({ax.get_ylabel()})\n{parts[1].replace(")", "")}' + f"{parts[0]} ({ax.get_ylabel()})\n{parts[1].replace(')', '')}" if pi == 0 else "" ) @@ -535,6 +542,7 @@ # reduced the amplitude of our signals in sensor space, but that it should not # bias the amplitudes in source space. # +# # References # ^^^^^^^^^^ #