From c5dc859a8d06649f7e05060706ae76e4b38a45c6 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 5 Dec 2023 21:37:11 -0500 Subject: [PATCH] Ran black and isort and ruff Signed-off-by: Adam Li --- .codespellignore | 2 + .github/workflows/code-style.yaml | 2 +- .pre-commit-config.yaml | 8 +- .yamllint.yaml => .yamllint.yml | 0 doc/_templates/autosummary/class.rst | 10 +- doc/_templates/autosummary/function.rst | 10 +- doc/conf.py | 316 +++++++++++------- ...mpare_connectivity_over_time_over_trial.py | 138 +++++--- examples/connectivity_classes.py | 52 ++- examples/cwt_sensor_connectivity.py | 51 +-- examples/dpli_wpli_pli.py | 124 +++++-- examples/dynamic/mne_var_connectivity.py | 59 ++-- examples/granger_causality.py | 119 ++++--- examples/handling_ragged_arrays.py | 30 +- examples/mic_mim.py | 130 +++---- examples/mixed_source_space_connectivity.py | 141 ++++---- examples/mne_inverse_coherence_epochs.py | 93 ++++-- examples/mne_inverse_connectivity_spectrum.py | 74 ++-- examples/mne_inverse_envelope_correlation.py | 91 ++--- ...mne_inverse_envelope_correlation_volume.py | 54 +-- examples/mne_inverse_label_connectivity.py | 112 ++++--- examples/mne_inverse_psi_visual.py | 85 +++-- examples/sensor_connectivity.py | 48 ++- pyproject.toml | 53 ++- 24 files changed, 1125 insertions(+), 677 deletions(-) rename .yamllint.yaml => .yamllint.yml (100%) diff --git a/.codespellignore b/.codespellignore index b0e151c8..b47c3552 100644 --- a/.codespellignore +++ b/.codespellignore @@ -2,3 +2,5 @@ raison fro nd manuel +ba +master \ No newline at end of file diff --git a/.github/workflows/code-style.yaml b/.github/workflows/code-style.yaml index 7b81d5f3..e74270be 100644 --- a/.github/workflows/code-style.yaml +++ b/.github/workflows/code-style.yaml @@ -44,4 +44,4 @@ jobs: - name: Run toml-sort run: toml-sort pyproject.toml --check - name: Run yamllint - run: yamllint . -c .yamllint.yaml --strict + run: yamllint . -c .yamllint.yml --strict diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 25d15b21..edd7bc18 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,14 +5,14 @@ repos: - id: black args: [--quiet] - # Ruff mne + # Ruff mne_connectivity - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.6 hooks: - id: ruff - name: ruff mne + name: ruff mne_connectivity args: ["--fix"] - files: ^mne/ + files: ^mne_connectivity/ # Ruff tutorials and examples - repo: https://github.com/astral-sh/ruff-pre-commit @@ -32,7 +32,7 @@ repos: - id: codespell additional_dependencies: - tomli - files: ^mne/|^doc/|^examples/|^tutorials/ + files: ^mne_connectivity/|^doc/|^examples/|^tutorials/ types_or: [python, bib, rst, inc] # yamllint diff --git a/.yamllint.yaml b/.yamllint.yml similarity index 100% rename from .yamllint.yaml rename to .yamllint.yml diff --git a/doc/_templates/autosummary/class.rst b/doc/_templates/autosummary/class.rst index 6056ea9c..fe474401 100644 --- a/doc/_templates/autosummary/class.rst +++ b/doc/_templates/autosummary/class.rst @@ -1,10 +1,12 @@ -{{ fullname }} -{{ underline }} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} - :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__,__hash__ + :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__ :members: -.. include:: {{module}}.{{objname}}.examples +.. _sphx_glr_backreferences_{{ fullname }}: + +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst index bdde2420..bd78b8e8 100644 --- a/doc/_templates/autosummary/function.rst +++ b/doc/_templates/autosummary/function.rst @@ -1,12 +1,10 @@ -{{ fullname }} -{{ underline }} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autofunction:: {{ objname }} -.. include:: {{module}}.{{objname}}.examples +.. _sphx_glr_backreferences_{{ fullname }}: -.. raw:: html - -
+.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/conf.py b/doc/conf.py index 3dd9883f..076557a3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -8,6 +8,7 @@ from sphinx_gallery.sorting import ExampleTitleSortKey import mne + sys.path.insert(0, os.path.abspath("..")) import mne_connectivity # noqa: E402 @@ -15,32 +16,32 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. curdir = os.path.dirname(__file__) -sys.path.append(os.path.abspath(os.path.join(curdir, '..'))) -sys.path.append(os.path.abspath(os.path.join(curdir, '..', 'mne_connectivity'))) -sys.path.append(os.path.abspath(os.path.join(curdir, 'sphinxext'))) +sys.path.append(os.path.abspath(os.path.join(curdir, ".."))) +sys.path.append(os.path.abspath(os.path.join(curdir, "..", "mne_connectivity"))) +sys.path.append(os.path.abspath(os.path.join(curdir, "sphinxext"))) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. # -needs_sphinx = '4.0' +needs_sphinx = "4.0" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx_autodoc_typehints', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx_gallery.gen_gallery', - 'sphinxcontrib.bibtex', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx_autodoc_typehints", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx_gallery.gen_gallery", + "sphinxcontrib.bibtex", "sphinx_issues", - 'numpydoc', - 'sphinx_copybutton', + "numpydoc", + "sphinx_copybutton", ] # configure sphinx-issues @@ -54,8 +55,8 @@ # -- sphinx.ext.autosummary autosummary_generate = True -autodoc_default_options = {'inherited-members': None} -autodoc_typehints = 'signature' +autodoc_default_options = {"inherited-members": None} +autodoc_typehints = "signature" # prevent jupyter notebooks from being run even if empty cell # nbsphinx_execute = 'never' @@ -63,14 +64,15 @@ error_ignores = { # These we do not live by: - 'GL01', # Docstring should start in the line immediately after the quotes - 'EX01', 'EX02', # examples failed (we test them separately) - 'ES01', # no extended summary - 'SA01', # no see also - 'YD01', # no yields section - 'SA04', # no description in See Also - 'PR04', # Parameter "shape (n_channels" has no type - 'RT02', # The first line of the Returns section should contain only the type, unless multiple values are being returned # noqa + "GL01", # Docstring should start in the line immediately after the quotes + "EX01", + "EX02", # examples failed (we test them separately) + "ES01", # no extended summary + "SA01", # no see also + "YD01", # no yields section + "SA04", # no description in See Also + "PR04", # Parameter "shape (n_channels" has no type + "RT02", # The first line of the Returns section should contain only the type, unless multiple values are being returned # noqa # XXX should also verify that | is used rather than , to separate params # XXX should maybe also restore the parameter-desc-length < 800 char check } @@ -83,84 +85,147 @@ numpydoc_use_blockquotes = True numpydoc_xref_ignore = { # words - 'instance', 'instances', 'of', 'default', 'shape', 'or', - 'with', 'length', 'pair', 'matplotlib', 'optional', 'kwargs', 'in', - 'dtype', 'object', 'self.verbose', + "instance", + "instances", + "of", + "default", + "shape", + "or", + "with", + "length", + "pair", + "matplotlib", + "optional", + "kwargs", + "in", + "dtype", + "object", + "self.verbose", # shapes - 'n_times', 'obj', 'n_chan', 'n_epochs', 'n_picks', 'n_ch_groups', - 'n_node_names', 'n_tapers', 'n_signals', 'n_step', 'n_freqs', - 'epochs', 'freqs', 'times', 'arrays', 'lists', 'func', 'n_nodes', - 'n_estimated_nodes', 'n_samples', 'n_channels', 'Renderer', - 'n_ytimes', 'n_ychannels', 'n_events', 'n_cons', 'max_n_chans', - 'n_unique_seeds', 'n_unique_targets', 'variable' + "n_times", + "obj", + "n_chan", + "n_epochs", + "n_picks", + "n_ch_groups", + "n_node_names", + "n_tapers", + "n_signals", + "n_step", + "n_freqs", + "epochs", + "freqs", + "times", + "arrays", + "lists", + "func", + "n_nodes", + "n_estimated_nodes", + "n_samples", + "n_channels", + "Renderer", + "n_ytimes", + "n_ychannels", + "n_events", + "n_cons", + "max_n_chans", + "n_unique_seeds", + "n_unique_targets", + "variable", } numpydoc_xref_aliases = { # Python - 'file-like': ':term:`file-like `', + "file-like": ":term:`file-like `", # Matplotlib - 'colormap': ':doc:`colormap `', - 'color': ':doc:`color `', - 'collection': ':doc:`collections `', - 'Axes': 'matplotlib.axes.Axes', - 'Figure': 'matplotlib.figure.Figure', - 'Axes3D': 'mpl_toolkits.mplot3d.axes3d.Axes3D', - 'PolarAxes': 'matplotlib.projections.polar.PolarAxes', - 'ColorbarBase': 'matplotlib.colorbar.ColorbarBase', + "colormap": ":doc:`colormap `", + "color": ":doc:`color `", + "collection": ":doc:`collections `", + "Axes": "matplotlib.axes.Axes", + "Figure": "matplotlib.figure.Figure", + "Axes3D": "mpl_toolkits.mplot3d.axes3d.Axes3D", + "PolarAxes": "matplotlib.projections.polar.PolarAxes", + "ColorbarBase": "matplotlib.colorbar.ColorbarBase", # joblib - 'joblib.Parallel': 'joblib.Parallel', + "joblib.Parallel": "joblib.Parallel", # nibabel - 'Nifti1Image': 'nibabel.nifti1.Nifti1Image', - 'Nifti2Image': 'nibabel.nifti2.Nifti2Image', - 'SpatialImage': 'nibabel.spatialimages.SpatialImage', + "Nifti1Image": "nibabel.nifti1.Nifti1Image", + "Nifti2Image": "nibabel.nifti2.Nifti2Image", + "SpatialImage": "nibabel.spatialimages.SpatialImage", # MNE - 'Label': 'mne.Label', 'Forward': 'mne.Forward', 'Evoked': 'mne.Evoked', - 'Info': 'mne.Info', 'SourceSpaces': 'mne.SourceSpaces', - 'SourceMorph': 'mne.SourceMorph', - 'Epochs': 'mne.Epochs', 'Layout': 'mne.channels.Layout', - 'EvokedArray': 'mne.EvokedArray', 'BiHemiLabel': 'mne.BiHemiLabel', - 'AverageTFR': 'mne.time_frequency.AverageTFR', - 'EpochsTFR': 'mne.time_frequency.EpochsTFR', - 'Raw': 'mne.io.Raw', 'ICA': 'mne.preprocessing.ICA', + "Label": "mne.Label", + "Forward": "mne.Forward", + "Evoked": "mne.Evoked", + "Info": "mne.Info", + "SourceSpaces": "mne.SourceSpaces", + "SourceMorph": "mne.SourceMorph", + "Epochs": "mne.Epochs", + "Layout": "mne.channels.Layout", + "EvokedArray": "mne.EvokedArray", + "BiHemiLabel": "mne.BiHemiLabel", + "AverageTFR": "mne.time_frequency.AverageTFR", + "EpochsTFR": "mne.time_frequency.EpochsTFR", + "Raw": "mne.io.Raw", + "ICA": "mne.preprocessing.ICA", # MNE-Connectivity - 'Connectivity': 'mne_connectivity.Connectivity', + "Connectivity": "mne_connectivity.Connectivity", # dipy - 'dipy.align.AffineMap': 'dipy.align.imaffine.AffineMap', - 'dipy.align.DiffeomorphicMap': 'dipy.align.imwarp.DiffeomorphicMap', + "dipy.align.AffineMap": "dipy.align.imaffine.AffineMap", + "dipy.align.DiffeomorphicMap": "dipy.align.imwarp.DiffeomorphicMap", } numpydoc_validate = True -numpydoc_validation_checks = {'all'} | set(error_ignores) +numpydoc_validation_checks = {"all"} | set(error_ignores) numpydoc_validation_exclude = { # set of regex # dict subclasses - r'\.clear', r'\.get$', r'\.copy$', r'\.fromkeys', r'\.items', r'\.keys', - r'\.pop', r'\.popitem', r'\.setdefault', r'\.update', r'\.values', + r"\.clear", + r"\.get$", + r"\.copy$", + r"\.fromkeys", + r"\.items", + r"\.keys", + r"\.pop", + r"\.popitem", + r"\.setdefault", + r"\.update", + r"\.values", # list subclasses - r'\.append', r'\.count', r'\.extend', r'\.index', r'\.insert', r'\.remove', - r'\.sort', + r"\.append", + r"\.count", + r"\.extend", + r"\.index", + r"\.insert", + r"\.remove", + r"\.sort", # we currently don't document these properly (probably okay) - r'\.__getitem__', r'\.__contains__', r'\.__hash__', r'\.__mul__', - r'\.__sub__', r'\.__add__', r'\.__iter__', r'\.__div__', r'\.__neg__', - r'plot_circle' + r"\.__getitem__", + r"\.__contains__", + r"\.__hash__", + r"\.__mul__", + r"\.__sub__", + r"\.__add__", + r"\.__iter__", + r"\.__div__", + r"\.__neg__", + r"plot_circle", } -default_role = 'py:obj' +default_role = "py:obj" # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'MNE-Connectivity' +project = "MNE-Connectivity" td = date.today() -copyright = '2021-%s, MNE Developers. Last updated on %s' % (td.year, - td.isoformat()) +copyright = "2021-%s, MNE Developers. Last updated on %s" % (td.year, td.isoformat()) -author = 'Adam Li' +author = "Adam Li" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -173,7 +238,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', "**.ipynb_checkpoints"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] # HTML options (e.g., theme) # see: https://sphinx-bootstrap-theme.readthedocs.io/en/latest/README.html @@ -181,12 +246,12 @@ html_show_sourcelink = False html_copy_source = False -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] -html_static_path = ['_static'] -html_css_files = ['style.css'] +templates_path = ["_templates"] +html_static_path = ["_static"] +html_css_files = ["style.css"] switcher_version_match = "dev" if "dev" in release else version @@ -194,15 +259,17 @@ # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'icon_links': [ - dict(name='GitHub', - url='https://github.com/mne-tools/mne-connectivity', - icon='fab fa-github-square'), + "icon_links": [ + dict( + name="GitHub", + url="https://github.com/mne-tools/mne-connectivity", + icon="fab fa-github-square", + ), ], - 'use_edit_page_button': False, - 'navigation_with_keys': False, - 'show_toc_level': 1, - 'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'], + "use_edit_page_button": False, + "navigation_with_keys": False, + "show_toc_level": 1, + "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"], "switcher": { "json_url": "https://mne.tools/mne-connectivity/dev/_static/versions.json", "version_match": switcher_version_match, @@ -210,20 +277,22 @@ } # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'mne': ('https://mne.tools/dev', None), - 'mne-bids': ('https://mne.tools/mne-bids/dev/', None), - 'numpy': ('https://numpy.org/devdocs', None), - 'scipy': ('https://scipy.github.io/devdocs', None), - 'matplotlib': ('https://matplotlib.org/stable', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/dev', None), - 'sklearn': ('https://scikit-learn.org/stable', None), - 'pyvista': ('https://docs.pyvista.org', None), - 'joblib': ('https://joblib.readthedocs.io/en/latest', None), - 'nibabel': ('https://nipy.org/nibabel', None), - 'nilearn': ('http://nilearn.github.io/stable', None), - 'dipy': ('https://dipy.org/documentation/1.4.0./', - 'https://dipy.org/documentation/1.4.0./objects.inv/'), + "python": ("https://docs.python.org/3", None), + "mne": ("https://mne.tools/dev", None), + "mne-bids": ("https://mne.tools/mne-bids/dev/", None), + "numpy": ("https://numpy.org/devdocs", None), + "scipy": ("https://scipy.github.io/devdocs", None), + "matplotlib": ("https://matplotlib.org/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/dev", None), + "sklearn": ("https://scikit-learn.org/stable", None), + "pyvista": ("https://docs.pyvista.org", None), + "joblib": ("https://joblib.readthedocs.io/en/latest", None), + "nibabel": ("https://nipy.org/nibabel", None), + "nilearn": ("http://nilearn.github.io/stable", None), + "dipy": ( + "https://dipy.org/documentation/1.4.0./", + "https://dipy.org/documentation/1.4.0./objects.inv/", + ), } intersphinx_timeout = 5 @@ -233,13 +302,13 @@ # instead of in the root." # we will store dev docs in a `dev` subdirectory and all other docs in a # directory "v" + version_str. E.g., "v0.3" -if 'dev' in version: - filepath_prefix = 'dev' +if "dev" in version: + filepath_prefix = "dev" else: - filepath_prefix = 'v{}'.format(version) + filepath_prefix = "v{}".format(version) -os.environ['_MNE_BUILDING_DOC'] = 'true' -scrapers = ('matplotlib',) +os.environ["_MNE_BUILDING_DOC"] = "true" +scrapers = ("matplotlib",) try: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -249,34 +318,35 @@ except Exception: pass else: - scrapers += ('pyvista',) -if 'pyvista' in scrapers: + scrapers += ("pyvista",) +if "pyvista" in scrapers: import mne.viz._brain + brain_scraper = mne.viz._brain._BrainScraper() scrapers = list(scrapers) - scrapers.insert(scrapers.index('pyvista'), brain_scraper) + scrapers.insert(scrapers.index("pyvista"), brain_scraper) scrapers = tuple(scrapers) sphinx_gallery_conf = { - 'doc_module': 'mne_connectivity', - 'reference_url': { - 'mne_connectivity': None, + "doc_module": "mne_connectivity", + "reference_url": { + "mne_connectivity": None, }, - 'backreferences_dir': 'generated', - 'plot_gallery': 'True', # Avoid annoying Unicode/bool default warning - 'within_subsection_order': ExampleTitleSortKey, - 'examples_dirs': ['../examples'], - 'gallery_dirs': ['auto_examples'], - 'filename_pattern': '^((?!sgskip).)*$', - 'matplotlib_animations': True, - 'compress_images': ('images', 'thumbnails'), - 'image_scrapers': scrapers, + "backreferences_dir": "generated", + "plot_gallery": "True", # Avoid annoying Unicode/bool default warning + "within_subsection_order": ExampleTitleSortKey, + "examples_dirs": ["../examples"], + "gallery_dirs": ["auto_examples"], + "filename_pattern": "^((?!sgskip).)*$", + "matplotlib_animations": True, + "compress_images": ("images", "thumbnails"), + "image_scrapers": scrapers, } # sphinxcontrib-bibtex -bibtex_bibfiles = ['./references.bib'] -bibtex_style = 'unsrt' -bibtex_footbibliography_header = '' +bibtex_bibfiles = ["./references.bib"] +bibtex_style = "unsrt" +bibtex_footbibliography_header = "" # Enable nitpicky mode - which ensures that all references in the docs diff --git a/examples/compare_connectivity_over_time_over_trial.py b/examples/compare_connectivity_over_time_over_trial.py index 96beae8a..cd5afc09 100644 --- a/examples/compare_connectivity_over_time_over_trial.py +++ b/examples/compare_connectivity_over_time_over_trial.py @@ -78,8 +78,7 @@ import matplotlib.pyplot as plt import mne -from mne_connectivity import (spectral_connectivity_epochs, - spectral_connectivity_time) +from mne_connectivity import spectral_connectivity_epochs, spectral_connectivity_time from mne_connectivity.viz import plot_sensors_connectivity from mne.datasets import sample @@ -118,9 +117,7 @@ # First we compute connectivity over trials. # Freq bands of interest -Freq_Bands = {"theta": [4.0, 8.0], - "alpha": [8.0, 13.0], - "beta": [13.0, 30.0]} +Freq_Bands = {"theta": [4.0, 8.0], "alpha": [8.0, 13.0], "beta": [13.0, 30.0]} n_freq_bands = len(Freq_Bands) min_freq = np.min(list(Freq_Bands.values())) max_freq = np.max(list(Freq_Bands.values())) @@ -137,16 +134,22 @@ n_con_methods = len(connectivity_methods) # Pre-allocatate memory for the connectivity matrices -con_epochs_array = np.zeros((n_con_methods, n_channels, n_channels, - n_freq_bands, n_times)) +con_epochs_array = np.zeros( + (n_con_methods, n_channels, n_channels, n_freq_bands, n_times) +) con_epochs_array[con_epochs_array == 0] = np.nan # nan matrix # Compute connectivity over trials -con_epochs = spectral_connectivity_epochs(data_epoch, - method=connectivity_methods, - sfreq=sfreq, mode="cwt_morlet", - cwt_freqs=freqs, fmin=fmin, - fmax=fmax, faverage=True) +con_epochs = spectral_connectivity_epochs( + data_epoch, + method=connectivity_methods, + sfreq=sfreq, + mode="cwt_morlet", + cwt_freqs=freqs, + fmin=fmin, + fmax=fmax, + faverage=True, +) # Get data as connectivity matrices for c in range(n_con_methods): @@ -175,21 +178,23 @@ def plot_con_matrix(con_data, n_con_methods): fig, ax = plt.subplots(1, n_con_methods, figsize=(6 * n_con_methods, 6)) for c in range(n_con_methods): # Plot with imshow - con_plot = ax[c].imshow(con_data[c, :, :, foi], - cmap="binary", vmin=0, vmax=1) + con_plot = ax[c].imshow(con_data[c, :, :, foi], cmap="binary", vmin=0, vmax=1) # Set title ax[c].set_title(connectivity_methods[c]) # Add colorbar - fig.colorbar(con_plot, ax=ax[c], shrink=0.7, label='Connectivity') + fig.colorbar(con_plot, ax=ax[c], shrink=0.7, label="Connectivity") # Fix labels ax[c].set_xticks(range(len(ch_names))) ax[c].set_xticklabels(ch_names) ax[c].set_yticks(range(len(ch_names))) ax[c].set_yticklabels(ch_names) - print(f"Connectivity method: {connectivity_methods[c]}\n" + - f"{con_data[c,:,:,foi]}") + print( + f"Connectivity method: {connectivity_methods[c]}\n" + + f"{con_data[c,:,:,foi]}" + ) return fig + plot_con_matrix(con_epochs_array, n_con_methods) ############################################################################### @@ -199,16 +204,22 @@ def plot_con_matrix(con_data, n_con_methods): # We will now compute connectivity over time. # Pre-allocatate memory for the connectivity matrices -con_time_array = np.zeros((n_con_methods, n_epochs, n_channels, - n_channels, n_freq_bands)) +con_time_array = np.zeros( + (n_con_methods, n_epochs, n_channels, n_channels, n_freq_bands) +) con_time_array[con_time_array == 0] = np.nan # nan matrix # Compute connectivity over time -con_time = spectral_connectivity_time(data_epoch, freqs, - method=connectivity_methods, - sfreq=sfreq, mode="cwt_morlet", - fmin=fmin, fmax=fmax, - faverage=True) +con_time = spectral_connectivity_time( + data_epoch, + freqs, + method=connectivity_methods, + sfreq=sfreq, + mode="cwt_morlet", + fmin=fmin, + fmax=fmax, + faverage=True, +) # Get data as connectivity matrices for c in range(n_con_methods): @@ -244,8 +255,11 @@ def plot_con_matrix(con_data, n_con_methods): epoch_len = n_times / sfreq phase = rng.random(1) * 10 # Introduce random phase for each channel # Generate sinus wave - x = np.linspace(-wave_freq * epoch_len * np.pi + phase, - wave_freq * epoch_len * np.pi + phase, n_times) + x = np.linspace( + -wave_freq * epoch_len * np.pi + phase, + wave_freq * epoch_len * np.pi + phase, + n_times, + ) data[i, c] = np.squeeze(np.sin(x)) # overwrite to data data_epoch = mne.EpochsArray(data, info) # create EpochsArray @@ -257,16 +271,22 @@ def plot_con_matrix(con_data, n_con_methods): # First we compute connectivity over trials. # Pre-allocatate memory for the connectivity matrices -con_epochs_array = np.zeros((n_con_methods, n_channels, n_channels, - n_freq_bands, n_times)) +con_epochs_array = np.zeros( + (n_con_methods, n_channels, n_channels, n_freq_bands, n_times) +) con_epochs_array[con_epochs_array == 0] = np.nan # nan matrix # Compute connecitivty over trials -con_epochs = spectral_connectivity_epochs(data_epoch, - method=connectivity_methods, - sfreq=sfreq, mode="cwt_morlet", - cwt_freqs=freqs, fmin=fmin, - fmax=fmax, faverage=True) +con_epochs = spectral_connectivity_epochs( + data_epoch, + method=connectivity_methods, + sfreq=sfreq, + mode="cwt_morlet", + cwt_freqs=freqs, + fmin=fmin, + fmax=fmax, + faverage=True, +) # Get data as connectivity matrices for c in range(n_con_methods): @@ -285,14 +305,20 @@ def plot_con_matrix(con_data, n_con_methods): # We will now compute connectivity over time. # Pre-allocatate memory for the connectivity matrices -con_time_array = np.zeros((n_con_methods, n_epochs, - n_channels, n_channels, n_freq_bands)) +con_time_array = np.zeros( + (n_con_methods, n_epochs, n_channels, n_channels, n_freq_bands) +) con_time_array[con_time_array == 0] = np.nan # nan matrix -con_time = spectral_connectivity_time(data_epoch, freqs, - method=connectivity_methods, - sfreq=sfreq, fmin=fmin, fmax=fmax, - faverage=True) +con_time = spectral_connectivity_time( + data_epoch, + freqs, + method=connectivity_methods, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, +) # Get data as connectivity matrices for c in range(n_con_methods): @@ -313,19 +339,19 @@ def plot_con_matrix(con_data, n_con_methods): # To finish this example, we will compute connectivity for a sample EEG data. data_path = sample.data_path() -raw_fname = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw.fif' -event_fname = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' +raw_fname = data_path / "MEG/sample/sample_audvis_filt-0-40_raw.fif" +event_fname = data_path / "MEG/sample/sample_audvis_filt-0-40_raw-eve.fif" raw = mne.io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) # Select only the EEG -picks = mne.pick_types(raw.info, meg=False, eeg=True, - stim=False, eog=False, exclude='bads') +picks = mne.pick_types( + raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads" +) # Create epochs for left visual field stimulus event_id, tmin, tmax = 3, -0.3, 1.6 -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0)) +epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0)) epochs.load_data() # load the data ############################################################################### @@ -340,7 +366,7 @@ def plot_con_matrix(con_data, n_con_methods): # N1 :footcite:`KlimeschEtAl2004`. Here, we will therefore analyze phase # connectivity in the theta band around P1 -sfreq = epochs.info['sfreq'] # the sampling frequency +sfreq = epochs.info["sfreq"] # the sampling frequency tmin = 0.0 # exclude the baseline period for connectivity estimation Freq_Bands = {"theta": [4.0, 8.0]} # frequency of interest n_freq_bands = len(Freq_Bands) @@ -358,12 +384,18 @@ def plot_con_matrix(con_data, n_con_methods): n_con_methods = len(connectivity_methods) # Compute connectivity over trials -con_epochs = spectral_connectivity_epochs(epochs, - method=connectivity_methods, - sfreq=sfreq, mode="cwt_morlet", - cwt_freqs=freqs, fmin=fmin, - fmax=fmax, faverage=True, - tmin=tmin, cwt_n_cycles=4) +con_epochs = spectral_connectivity_epochs( + epochs, + method=connectivity_methods, + sfreq=sfreq, + mode="cwt_morlet", + cwt_freqs=freqs, + fmin=fmin, + fmax=fmax, + faverage=True, + tmin=tmin, + cwt_n_cycles=4, +) ############################################################################### # Notice we have shortened the wavelets to 4 cycles since we only have 1.6s @@ -403,7 +435,7 @@ def plot_con_matrix(con_data, n_con_methods): fig = plt.figure() im = plt.imshow(con_epochs_matrix) -fig.colorbar(im, label='Connectivity') +fig.colorbar(im, label="Connectivity") plt.ylabel("Channels") plt.xlabel("Channels") plt.show() diff --git a/examples/connectivity_classes.py b/examples/connectivity_classes.py index e5679b42..a42b6d68 100644 --- a/examples/connectivity_classes.py +++ b/examples/connectivity_classes.py @@ -26,38 +26,54 @@ # %% # Set parameters data_path = sample.data_path() -raw_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_filt-0-40_raw.fif') -event_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_filt-0-40_raw-eve.fif') +raw_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif") +event_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif") # Setup for reading the raw data raw = mne.io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) # Add a bad channel -raw.info['bads'] += ['MEG 2443'] +raw.info["bads"] += ["MEG 2443"] # Pick MEG gradiometers -picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads" +) # Create epochs for the visual condition event_id, tmin, tmax = 3, -0.2, 1.5 # need a long enough epoch for 5 cycles -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), +) # Compute connectivity for the alpha band that contains the evoked response # (4-9 Hz). We exclude the baseline period: -fmin, fmax = 4., 9. +fmin, fmax = 4.0, 9.0 cwt_freqs = np.linspace(fmin, fmax, 5) -sfreq = raw.info['sfreq'] # the sampling frequency +sfreq = raw.info["sfreq"] # the sampling frequency tmin = 0.0 # exclude the baseline period -epochs.load_data().pick_types(meg='grad') # just keep MEG and no EOG now +epochs.load_data().pick_types(meg="grad") # just keep MEG and no EOG now con = spectral_connectivity_epochs( - epochs, method='pli', mode='cwt_morlet', sfreq=sfreq, fmin=fmin, fmax=fmax, - faverage=False, tmin=tmin, cwt_freqs=cwt_freqs, mt_adaptive=False, - n_jobs=1) + epochs, + method="pli", + mode="cwt_morlet", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=False, + tmin=tmin, + cwt_freqs=cwt_freqs, + mt_adaptive=False, + n_jobs=1, +) # %% # Now, we can look at different functionalities of the connectivity @@ -94,7 +110,7 @@ print(con.shape) # the 'dense' output will show the connectivity measure's N x N axis -print(con.get_data(output='dense').shape) +print(con.get_data(output="dense").shape) # %% Connectivity Measure XArray Attributes # The underlying data is stored as an xarray, so we have access @@ -102,11 +118,11 @@ # stores relevant metadata. For example, the method used in this example # is the phase-lag index ('pli'). print(con.attrs.keys()) -print(con.attrs.get('method')) +print(con.attrs.get("method")) # You can also store additional metadata relevant to your experiment, which can # easily be done, because ``attrs`` is just a dictionary. -con.attrs['experimenter'] = 'mne' +con.attrs["experimenter"] = "mne" print(con.attrs.keys()) # %% diff --git a/examples/cwt_sensor_connectivity.py b/examples/cwt_sensor_connectivity.py index f4930f1b..3be7913c 100644 --- a/examples/cwt_sensor_connectivity.py +++ b/examples/cwt_sensor_connectivity.py @@ -26,30 +26,37 @@ ############################################################################### # Set parameters data_path = sample.data_path() -raw_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_filt-0-40_raw.fif') -event_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_filt-0-40_raw-eve.fif') +raw_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif") +event_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif") # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) # Add a bad channel -raw.info['bads'] += ['MEG 2443'] +raw.info["bads"] += ["MEG 2443"] # Pick MEG gradiometers -picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads" +) # Create epochs for left-visual condition event_id, tmin, tmax = 3, -0.2, 0.5 -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6), - preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), + preload=True, +) # Use 'MEG 2343' as seed -seed_ch = 'MEG 2343' +seed_ch = "MEG 2343" picks_ch_names = [raw.ch_names[i] for i in picks] # Create seed-target indices for connectivity computation @@ -59,14 +66,20 @@ # Define wavelet frequencies and number of cycles cwt_freqs = np.arange(7, 30, 2) -cwt_n_cycles = cwt_freqs / 7. +cwt_n_cycles = cwt_freqs / 7.0 # Run the connectivity analysis using 2 parallel jobs -sfreq = raw.info['sfreq'] # the sampling frequency +sfreq = raw.info["sfreq"] # the sampling frequency con = spectral_connectivity_epochs( - epochs, indices=indices, - method='wpli2_debiased', mode='cwt_morlet', sfreq=sfreq, - cwt_freqs=cwt_freqs, cwt_n_cycles=cwt_n_cycles, n_jobs=1) + epochs, + indices=indices, + method="wpli2_debiased", + mode="cwt_morlet", + sfreq=sfreq, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + n_jobs=1, +) times = con.times freqs = con.freqs @@ -74,12 +87,12 @@ con.get_data()[np.where(indices[1] == seed)] = 1.0 # Show topography of connectivity from seed -title = 'WPLI2 - Visual - Seed %s' % seed_ch +title = "WPLI2 - Visual - Seed %s" % seed_ch -layout = mne.find_layout(epochs.info, 'meg') # use full layout +layout = mne.find_layout(epochs.info, "meg") # use full layout tfr = AverageTFR(epochs.info, con.get_data(), times, freqs, len(epochs)) -tfr.plot_topo(fig_facecolor='w', font_color='k', border='k') +tfr.plot_topo(fig_facecolor="w", font_color="k", border="k") ############################################################################### diff --git a/examples/dpli_wpli_pli.py b/examples/dpli_wpli_pli.py index 1fd5cbbc..d53096de 100644 --- a/examples/dpli_wpli_pli.py +++ b/examples/dpli_wpli_pli.py @@ -1,4 +1,4 @@ -''' +""" ============================= Comparing PLI, wPLI, and dPLI ============================= @@ -7,7 +7,7 @@ the phase lag index (PLI) :footcite:`StamEtAl2007`, weighted phase lag index (wPLI) :footcite:`VinckEtAl2011`, and directed phase lag index (dPLI) :footcite:`StamEtAl2012` on simulated data. -''' +""" # Authors: Kenji Marshall # Charlotte Maschke @@ -102,8 +102,10 @@ for ps in zip(phase_differences): sig = [] for _ in range(n_e): - sig.append(np.sin(2 * np.pi * f * t - ps) + - A * np.random.normal(0, sigma, size=t.shape)) + sig.append( + np.sin(2 * np.pi * f * t - ps) + + A * np.random.normal(0, sigma, size=t.shape) + ) data.append(sig) data = np.swapaxes(np.array(data), 0, 1) # make epochs the first dimension @@ -135,11 +137,18 @@ # %% conn = [] indices = ([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]) -for method in ['pli', 'wpli', 'dpli']: +for method in ["pli", "wpli", "dpli"]: conn.append( spectral_connectivity_epochs( - data, method=method, sfreq=fs, indices=indices, - fmin=9, fmax=11, faverage=True).get_data()[:, 0]) + data, + method=method, + sfreq=fs, + indices=indices, + fmin=9, + fmax=11, + faverage=True, + ).get_data()[:, 0] + ) conn = np.array(conn) ############################################################################### @@ -166,9 +175,9 @@ x = np.arange(5) plt.figure() -plt.bar(x - 0.2, conn[0], 0.2, align='center', label="PLI") -plt.bar(x, conn[1], 0.2, align='center', label="wPLI") -plt.bar(x + 0.2, conn[2], 0.2, align='center', label="dPLI") +plt.bar(x - 0.2, conn[0], 0.2, align="center", label="PLI") +plt.bar(x, conn[1], 0.2, align="center", label="wPLI") +plt.bar(x + 0.2, conn[2], 0.2, align="center", label="dPLI") plt.title("Connectivity Estimation Comparison") plt.xticks(x, (r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$")) @@ -226,11 +235,11 @@ sig = [] # Generate other signal for _ in range(int(n_e / 2)): # phase difference -pi/100 - sig.append(np.sin(2 * np.pi * f * t + np.pi / - 100 + A * np.random.uniform(-1, 1))) + sig.append( + np.sin(2 * np.pi * f * t + np.pi / 100 + A * np.random.uniform(-1, 1)) + ) for _ in range(int(n_e / 2), n_e): # phase difference pi/2 - sig.append(np.sin(2 * np.pi * f * t - np.pi / - 2 + A * np.random.uniform(-1, 1))) + sig.append(np.sin(2 * np.pi * f * t - np.pi / 2 + A * np.random.uniform(-1, 1))) data.append(sig) data = np.swapaxes(np.array(data), 0, 1) @@ -257,11 +266,18 @@ conn = [] indices = ([0] * n_noise, np.arange(1, n_noise + 1)) -for method in ['pli', 'wpli']: +for method in ["pli", "wpli"]: conn.append( spectral_connectivity_epochs( - data, method=method, sfreq=fs, indices=indices, - fmin=9, fmax=11, faverage=True).get_data()[:, 0]) + data, + method=method, + sfreq=fs, + indices=indices, + fmin=9, + fmax=11, + faverage=True, + ).get_data()[:, 0] + ) conn = np.array(conn) ############################################################################### @@ -298,38 +314,74 @@ # sample MEG data recorded during visual stimulation. data_path = sample.data_path() -raw_fname = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw.fif' -event_fname = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' +raw_fname = data_path / "MEG/sample/sample_audvis_filt-0-40_raw.fif" +event_fname = data_path / "MEG/sample/sample_audvis_filt-0-40_raw-eve.fif" raw = mne.io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) # Select gradiometers -picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads" +) # Create epochs event_id, tmin, tmax = 3, -0.2, 1.5 # need a long enough epoch for 5 cycles -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6)) -epochs.load_data().pick_types(meg='grad') # just keep MEG and no EOG now - -fmin, fmax = 4., 9. # compute connectivity within 4-9 Hz -sfreq = raw.info['sfreq'] # the sampling frequency +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), +) +epochs.load_data().pick_types(meg="grad") # just keep MEG and no EOG now + +fmin, fmax = 4.0, 9.0 # compute connectivity within 4-9 Hz +sfreq = raw.info["sfreq"] # the sampling frequency tmin = 0.0 # exclude the baseline period # Compute PLI, wPLI, and dPLI con_pli = spectral_connectivity_epochs( - epochs, method='pli', mode='multitaper', sfreq=sfreq, fmin=fmin, - fmax=fmax, faverage=True, tmin=tmin, mt_adaptive=False, n_jobs=1) + epochs, + method="pli", + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + tmin=tmin, + mt_adaptive=False, + n_jobs=1, +) con_wpli = spectral_connectivity_epochs( - epochs, method='wpli', mode='multitaper', sfreq=sfreq, fmin=fmin, - fmax=fmax, faverage=True, tmin=tmin, mt_adaptive=False, n_jobs=1) + epochs, + method="wpli", + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + tmin=tmin, + mt_adaptive=False, + n_jobs=1, +) con_dpli = spectral_connectivity_epochs( - epochs, method='dpli', mode='multitaper', sfreq=sfreq, fmin=fmin, - fmax=fmax, faverage=True, tmin=tmin, mt_adaptive=False, n_jobs=1) + epochs, + method="dpli", + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + tmin=tmin, + mt_adaptive=False, + n_jobs=1, +) ############################################################################### # In this example, there is strong connectivity between sensors 190-200 and @@ -346,16 +398,16 @@ # strength, as was mentioned earlier. fig, axs = plt.subplots(1, 3, figsize=(14, 5), sharey=True) -axs[0].imshow(con_pli.get_data('dense'), vmin=0, vmax=1) +axs[0].imshow(con_pli.get_data("dense"), vmin=0, vmax=1) axs[0].set_title("PLI") axs[0].set_ylabel("Sensor 1") axs[0].set_xlabel("Sensor 2") -axs[1].imshow(con_wpli.get_data('dense'), vmin=0, vmax=1) +axs[1].imshow(con_wpli.get_data("dense"), vmin=0, vmax=1) axs[1].set_title("wPLI") axs[1].set_xlabel("Sensor 2") -im = axs[2].imshow(con_dpli.get_data('dense'), vmin=0, vmax=1) +im = axs[2].imshow(con_dpli.get_data("dense"), vmin=0, vmax=1) axs[2].set_title("dPLI") axs[2].set_xlabel("Sensor 2") diff --git a/examples/dynamic/mne_var_connectivity.py b/examples/dynamic/mne_var_connectivity.py index 37c8e552..cb366d68 100644 --- a/examples/dynamic/mne_var_connectivity.py +++ b/examples/dynamic/mne_var_connectivity.py @@ -39,16 +39,22 @@ bids_root = mne.datasets.epilepsy_ecog.data_path() # first define the BIDS path -bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery', - task='ictal', datatype='ieeg', extension='vhdr') +bids_path = BIDSPath( + root=bids_root, + subject="pt1", + session="presurgery", + task="ictal", + datatype="ieeg", + extension="vhdr", +) # Then we'll use it to load in the sample dataset. Here we use a format (iEEG) # that is only available in MNE-BIDS 0.7+, so it will emit a warning on # versions <= 0.6 raw = read_raw_bids(bids_path=bids_path, verbose=False) -line_freq = raw.info['line_freq'] -print(f'Data has a power line frequency at {line_freq}.') +line_freq = raw.info["line_freq"] +print(f"Data has a power line frequency at {line_freq}.") # Pick only the ECoG channels, removing the ECG channels raw.pick_types(ecog=True) @@ -60,7 +66,7 @@ raw.notch_filter(line_freq) # drop bad channels -raw.drop_channels(raw.info['bads']) +raw.drop_channels(raw.info["bads"]) # %% # Crop the data for this example @@ -77,10 +83,10 @@ events, event_id = mne.events_from_annotations(raw) # get sample at which seizure starts -onset_id = event_id['onset'] +onset_id = event_id["onset"] onset_idx = np.argwhere(events[:, 2] == onset_id) onset_sample = events[onset_idx, 0].squeeze() -onset_sec = onset_sample / raw.info['sfreq'] +onset_sec = onset_sample / raw.info["sfreq"] # remove all data after the seizure onset raw = raw.crop(tmin=0, tmax=onset_sec, include_tmax=False) @@ -113,8 +119,7 @@ # represents a separate VAR model. Taken together, these represent a # time-varying linear system. -conn = vector_auto_regression( - data=epochs.get_data(), times=times, names=ch_names) +conn = vector_auto_regression(data=epochs.get_data(), times=times, names=ch_names) # this returns a connectivity structure over time print(conn) @@ -134,18 +139,14 @@ # visualize the residuals fig, ax = plt.subplots() -ax.plot(residuals.flatten(), '*') -ax.set( - title='Residuals of fitted VAR model', - ylabel='Magnitude' -) +ax.plot(residuals.flatten(), "*") +ax.set(title="Residuals of fitted VAR model", ylabel="Magnitude") # compute the covariance of the residuals -model_order = conn.attrs.get('model_order') +model_order = conn.attrs.get("model_order") t = residuals.shape[0] sampled_residuals = np.concatenate( - np.split(residuals[:, :, model_order:], t, 0), - axis=2 + np.split(residuals[:, :, model_order:], t, 0), axis=2 ).squeeze(0) rescov = np.cov(sampled_residuals) @@ -155,8 +156,8 @@ # should come with low covariances. fig, ax = plt.subplots() cax = fig.add_axes([0.27, 0.8, 0.5, 0.05]) -im = ax.imshow(rescov, cmap='viridis', aspect='equal', interpolation='none') -fig.colorbar(im, cax=cax, orientation='horizontal') +im = ax.imshow(rescov, cmap="viridis", aspect="equal", interpolation="none") +fig.colorbar(im, cax=cax, orientation="horizontal") # %% # Compute one VAR model using all epochs @@ -167,8 +168,8 @@ # epochs. conn = vector_auto_regression( - data=epochs.get_data(), times=times, names=ch_names, - model='avg-epochs') + data=epochs.get_data(), times=times, names=ch_names, model="avg-epochs" +) # this returns a connectivity structure over time print(conn) @@ -188,18 +189,14 @@ # visualize the residuals fig, ax = plt.subplots() -ax.plot(residuals.flatten(), '*') -ax.set( - title='Residuals of fitted VAR model', - ylabel='Magnitude' -) +ax.plot(residuals.flatten(), "*") +ax.set(title="Residuals of fitted VAR model", ylabel="Magnitude") # compute the covariance of the residuals -model_order = conn.attrs.get('model_order') +model_order = conn.attrs.get("model_order") t = residuals.shape[0] sampled_residuals = np.concatenate( - np.split(residuals[:, :, model_order:], t, 0), - axis=2 + np.split(residuals[:, :, model_order:], t, 0), axis=2 ).squeeze(0) rescov = np.cov(sampled_residuals) @@ -208,5 +205,5 @@ # with the covariances for time-varying VAR model. fig, ax = plt.subplots() cax = fig.add_axes([0.27, 0.8, 0.5, 0.05]) -im = ax.imshow(rescov, cmap='viridis', aspect='equal', interpolation='none') -fig.colorbar(im, cax=cax, orientation='horizontal') +im = ax.imshow(rescov, cmap="viridis", aspect="equal", interpolation="none") +fig.colorbar(im, cax=cax, orientation="horizontal") diff --git a/examples/granger_causality.py b/examples/granger_causality.py index 73c10a86..b7a73cc0 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -135,9 +135,9 @@ # %% -raw = mne.io.read_raw_ctf(data_path() / 'SubjectCMC.ds') -raw.pick('mag') -raw.crop(50., 110.).load_data() +raw = mne.io.read_raw_ctf(data_path() / "SubjectCMC.ds") +raw.pick("mag") +raw.crop(50.0, 110.0).load_data() raw.notch_filter(50) raw.resample(100) @@ -151,22 +151,40 @@ # %% # parietal sensors -signals_a = [idx for idx, ch_info in enumerate(epochs.info['chs']) if - ch_info['ch_name'][2] == 'P'] +signals_a = [ + idx + for idx, ch_info in enumerate(epochs.info["chs"]) + if ch_info["ch_name"][2] == "P" +] # occipital sensors -signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if - ch_info['ch_name'][2] == 'O'] +signals_b = [ + idx + for idx, ch_info in enumerate(epochs.info["chs"]) + if ch_info["ch_name"][2] == "O" +] indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality gc_ab = spectral_connectivity_epochs( - epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, - rank=(np.array([5]), np.array([5])), gc_n_lags=20) # A => B + epochs, + method=["gc"], + indices=indices_ab, + fmin=5, + fmax=30, + rank=(np.array([5]), np.array([5])), + gc_n_lags=20, +) # A => B gc_ba = spectral_connectivity_epochs( - epochs, method=['gc'], indices=indices_ba, fmin=5, fmax=30, - rank=(np.array([5]), np.array([5])), gc_n_lags=20) # B => A + epochs, + method=["gc"], + indices=indices_ba, + fmin=5, + fmax=30, + rank=(np.array([5]), np.array([5])), + gc_n_lags=20, +) # B => A freqs = gc_ab.freqs @@ -179,9 +197,9 @@ fig, axis = plt.subplots(1, 1) axis.plot(freqs, gc_ab.get_data()[0], linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Connectivity (A.U.)') -fig.suptitle('GC: [A => B]') +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +fig.suptitle("GC: [A => B]") ############################################################################### @@ -206,12 +224,11 @@ net_gc = gc_ab.get_data() - gc_ba.get_data() # [A => B] - [B => A] fig, axis = plt.subplots(1, 1) -axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle='--', - color='k') +axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle="--", color="k") axis.plot(freqs, net_gc[0], linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Connectivity (A.U.)') -fig.suptitle('Net GC: [A => B] - [B => A]') +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +fig.suptitle("Net GC: [A => B] - [B => A]") ############################################################################### @@ -263,11 +280,23 @@ # compute GC on time-reversed signals gc_tr_ab = spectral_connectivity_epochs( - epochs, method=['gc_tr'], indices=indices_ab, fmin=5, fmax=30, - rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[A => B] + epochs, + method=["gc_tr"], + indices=indices_ab, + fmin=5, + fmax=30, + rank=(np.array([5]), np.array([5])), + gc_n_lags=20, +) # TR[A => B] gc_tr_ba = spectral_connectivity_epochs( - epochs, method=['gc_tr'], indices=indices_ba, fmin=5, fmax=30, - rank=(np.array([5]), np.array([5])), gc_n_lags=20) # TR[B => A] + epochs, + method=["gc_tr"], + indices=indices_ba, + fmin=5, + fmax=30, + rank=(np.array([5]), np.array([5])), + gc_n_lags=20, +) # TR[B => A] # compute net GC on time-reversed signals (TR[A => B] - TR[B => A]) net_gc_tr = gc_tr_ab.get_data() - gc_tr_ba.get_data() @@ -289,12 +318,11 @@ # %% fig, axis = plt.subplots(1, 1) -axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle='--', - color='k') +axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle="--", color="k") axis.plot(freqs, trgc[0], linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Connectivity (A.U.)') -fig.suptitle('TRGC: net[A => B] - net time-reversed[A => B]') +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +fig.suptitle("TRGC: net[A => B] - net time-reversed[A => B]") ############################################################################### @@ -318,16 +346,22 @@ # %% gc_ab_60 = spectral_connectivity_epochs( - epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, - rank=(np.array([5]), np.array([5])), gc_n_lags=60) # A => B + epochs, + method=["gc"], + indices=indices_ab, + fmin=5, + fmax=30, + rank=(np.array([5]), np.array([5])), + gc_n_lags=60, +) # A => B fig, axis = plt.subplots(1, 1) -axis.plot(freqs, gc_ab.get_data()[0], linewidth=2, label='20 lags') -axis.plot(freqs, gc_ab_60.get_data()[0], linewidth=2, label='60 lags') -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Connectivity (A.U.)') +axis.plot(freqs, gc_ab.get_data()[0], linewidth=2, label="20 lags") +axis.plot(freqs, gc_ab_60.get_data()[0], linewidth=2, label="60 lags") +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") axis.legend() -fig.suptitle('GC: [A => B]') +fig.suptitle("GC: [A => B]") ############################################################################### @@ -376,11 +410,18 @@ try: spectral_connectivity_epochs( - epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, - gc_n_lags=20, verbose=False) # A => B - print('Success!') + epochs, + method=["gc"], + indices=indices_ab, + fmin=5, + fmax=30, + rank=None, + gc_n_lags=20, + verbose=False, + ) # A => B + print("Success!") except RuntimeError as error: - print('\nCaught the following error:\n' + repr(error)) + print("\nCaught the following error:\n" + repr(error)) ############################################################################### # Rigorous checks are implemented to identify any such instances which would diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py index 4c2d43b1..78522bc5 100644 --- a/examples/handling_ragged_arrays.py +++ b/examples/handling_ragged_arrays.py @@ -101,19 +101,23 @@ # create random data data = np.random.randn(10, 5, 200) # epochs x channels x times sfreq = 50 -ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds - [[2, 3, 4], [4]]) # targets +ragged_indices = ([[0, 1], [0, 1, 2, 3]], [[2, 3, 4], [4]]) # seeds # targets # compute connectivity con = spectral_connectivity_epochs( - data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, - verbose=False) -patterns = np.array(con.attrs['patterns']) + data, + method="mic", + indices=ragged_indices, + sfreq=sfreq, + fmin=10, + fmax=30, + verbose=False, +) +patterns = np.array(con.attrs["patterns"]) padded_indices = con.indices n_freqs = con.get_data().shape[-1] n_cons = len(ragged_indices[0]) -max_n_chans = max( - len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])) +max_n_chans = max(len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])) # show that the padded indices entries are masked assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels @@ -131,14 +135,12 @@ assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels # extract patterns for first connection using the ragged indices -seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] -target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] +seed_patterns_con1 = patterns[0, 0, : len(ragged_indices[0][0])] +target_patterns_con1 = patterns[1, 0, : len(ragged_indices[1][0])] # extract patterns for second connection using the padded, masked indices -seed_patterns_con2 = ( - patterns[0, 1, :padded_indices[0][1].count()]) -target_patterns_con2 = ( - patterns[1, 1, :padded_indices[1][1].count()]) +seed_patterns_con2 = patterns[0, 1, : padded_indices[0][1].count()] +target_patterns_con2 = patterns[1, 1, : padded_indices[1][1].count()] # show that shapes of patterns are correct assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) @@ -146,6 +148,6 @@ assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3) assert target_patterns_con2.shape == (1, n_freqs) # channels (4) -print('Assertions completed successfully!') +print("Assertions completed successfully!") # %% diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 86044969..244e045b 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -59,9 +59,9 @@ # %% -raw = mne.io.read_raw_ctf(data_path() / 'SubjectCMC.ds') -raw.pick('mag') -raw.crop(50., 110.).load_data() +raw = mne.io.read_raw_ctf(data_path() / "SubjectCMC.ds") +raw.pick("mag") +raw.crop(50.0, 110.0).load_data() raw.notch_filter(50) raw.resample(100) @@ -75,26 +75,27 @@ # %% # left hemisphere sensors -seeds = [idx for idx, ch_info in enumerate(epochs.info['chs']) if - ch_info['loc'][0] < 0] +seeds = [idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] < 0] # right hemisphere sensors -targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if - ch_info['loc'][0] > 0] +targets = [ + idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] > 0 +] multivar_indices = (np.array([seeds]), np.array([targets])) -seed_names = [epochs.info['ch_names'][idx] for idx in seeds] -target_names = [epochs.info['ch_names'][idx] for idx in targets] +seed_names = [epochs.info["ch_names"][idx] for idx in seeds] +target_names = [epochs.info["ch_names"][idx] for idx in targets] # multivariate imaginary part of coherency (mic, mim) = spectral_connectivity_epochs( - epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, - rank=None) + epochs, method=["mic", "mim"], indices=multivar_indices, fmin=5, fmax=30, rank=None +) # bivariate imaginary part of coherency (for comparison) bivar_indices = seed_target_indices(seeds, targets) imcoh = spectral_connectivity_epochs( - epochs, method='imcoh', indices=bivar_indices, fmin=5, fmax=30) + epochs, method="imcoh", indices=bivar_indices, fmin=5, fmax=30 +) ############################################################################### # By averaging across each connection between the seeds and targets, we can see @@ -104,11 +105,10 @@ # %% fig, axis = plt.subplots(1, 1) -axis.plot(imcoh.freqs, np.mean(np.abs(imcoh.get_data()), axis=0), - linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Absolute connectivity (A.U.)') -fig.suptitle('Imaginary part of coherency') +axis.plot(imcoh.freqs, np.mean(np.abs(imcoh.get_data()), axis=0), linewidth=2) +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Absolute connectivity (A.U.)") +fig.suptitle("Imaginary part of coherency") ############################################################################### @@ -148,9 +148,9 @@ fig, axis = plt.subplots(1, 1) axis.plot(mic.freqs, np.abs(mic.get_data()[0]), linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Absolute connectivity (A.U.)') -fig.suptitle('Maximised imaginary part of coherency') +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Absolute connectivity (A.U.)") +fig.suptitle("Maximised imaginary part of coherency") ############################################################################### @@ -178,14 +178,12 @@ fband_idx = [mic.freqs.index(freq) for freq in fband] # patterns have shape [seeds/targets x cons x channels x freqs (x times)] -patterns = np.array(mic.attrs['patterns']) -seed_pattern = patterns[0, :, :len(seeds)] -target_pattern = patterns[1, :, :len(targets)] +patterns = np.array(mic.attrs["patterns"]) +seed_pattern = patterns[0, :, : len(seeds)] +target_pattern = patterns[1, :, : len(targets)] # average across frequencies -seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], - axis=1) -target_pattern = np.mean(target_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], - axis=1) +seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0] : fband_idx[1] + 1], axis=1) +target_pattern = np.mean(target_pattern[0, :, fband_idx[0] : fband_idx[1] + 1], axis=1) # store the patterns for plotting seed_info = epochs.copy().pick(seed_names).info @@ -196,22 +194,38 @@ # plot the patterns fig, axes = plt.subplots(1, 4) seed_pattern.plot_topomap( - times=0, sensors='m.', units=dict(mag='A.U.'), cbar_fmt='%.1E', - axes=axes[0:2], time_format='', show=False) + times=0, + sensors="m.", + units=dict(mag="A.U."), + cbar_fmt="%.1E", + axes=axes[0:2], + time_format="", + show=False, +) target_pattern.plot_topomap( - times=0, sensors='m.', units=dict(mag='A.U.'), cbar_fmt='%.1E', - axes=axes[2:], time_format='', show=False) + times=0, + sensors="m.", + units=dict(mag="A.U."), + cbar_fmt="%.1E", + axes=axes[2:], + time_format="", + show=False, +) axes[0].set_position((0.1, 0.1, 0.35, 0.7)) axes[1].set_position((0.4, 0.3, 0.02, 0.3)) axes[2].set_position((0.5, 0.1, 0.35, 0.7)) axes[3].set_position((0.9, 0.3, 0.02, 0.3)) -axes[0].set_title('Seed spatial pattern\n13-18 Hz') -axes[2].set_title('Target spatial pattern\n13-18 Hz') +axes[0].set_title("Seed spatial pattern\n13-18 Hz") +axes[2].set_title("Target spatial pattern\n13-18 Hz") # plot the left hemisphere dipole example axes[0].plot( - [-0.01, -0.07], [-0.07, -0.03], color='lime', linewidth=2, - path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()]) + [-0.01, -0.07], + [-0.07, -0.03], + color="lime", + linewidth=2, + path_effects=[pe.Stroke(linewidth=4, foreground="k"), pe.Normal()], +) plt.show() @@ -257,13 +271,13 @@ fig, axis = plt.subplots(1, 1) axis.plot(mim.freqs, mim.get_data()[0], linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Absolute connectivity (A.U.)') -fig.suptitle('Multivariate interaction measure') +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Absolute connectivity (A.U.)") +fig.suptitle("Multivariate interaction measure") n_channels = len(seeds) + len(targets) normalised_mim = mim.get_data()[0] / n_channels -print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}') +print(f"Normalised MIM has a maximum value of {normalised_mim.max():.2f}") ############################################################################### @@ -291,18 +305,18 @@ indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) gim = spectral_connectivity_epochs( - epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, - verbose=False) + epochs, method="mim", indices=indices, fmin=5, fmax=30, rank=None, verbose=False +) fig, axis = plt.subplots(1, 1) axis.plot(gim.freqs, gim.get_data()[0], linewidth=2) -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Connectivity (A.U.)') -fig.suptitle('Global interaction measure') +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Connectivity (A.U.)") +fig.suptitle("Global interaction measure") n_channels = len(seeds) + len(targets) normalised_gim = gim.get_data()[0] / n_channels -print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}') +print(f"Normalised GIM has a maximum value of {normalised_gim.max():.2f}") ############################################################################### @@ -343,8 +357,13 @@ # %% (mic_red, mim_red) = spectral_connectivity_epochs( - epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, - rank=([25], [25])) + epochs, + method=["mic", "mim"], + indices=multivar_indices, + fmin=5, + fmax=30, + rank=([25], [25]), +) # subtract mean of scores for comparison mim_red_meansub = mim_red.get_data()[0] - mim_red.get_data()[0].mean() @@ -352,19 +371,16 @@ # compare standard and rank subspace-projected MIM fig, axis = plt.subplots(1, 1) -axis.plot(mim_red.freqs, mim_red_meansub, linewidth=2, - label='rank subspace (25) MIM') -axis.plot(mim.freqs, mim_meansub, linewidth=2, label='standard MIM') -axis.set_xlabel('Frequency (Hz)') -axis.set_ylabel('Mean-corrected connectivity (A.U.)') +axis.plot(mim_red.freqs, mim_red_meansub, linewidth=2, label="rank subspace (25) MIM") +axis.plot(mim.freqs, mim_meansub, linewidth=2, label="standard MIM") +axis.set_xlabel("Frequency (Hz)") +axis.set_ylabel("Mean-corrected connectivity (A.U.)") axis.legend() -fig.suptitle('Multivariate interaction measure (non-normalised)') +fig.suptitle("Multivariate interaction measure (non-normalised)") # no. channels equal with and without projecting to rank subspace for patterns -assert (patterns[0, 0].shape[0] == - np.array(mic_red.attrs['patterns'])[0, 0].shape[0]) -assert (patterns[1, 0].shape[0] == - np.array(mic_red.attrs['patterns'])[1, 0].shape[0]) +assert patterns[0, 0].shape[0] == np.array(mic_red.attrs["patterns"])[0, 0].shape[0] +assert patterns[1, 0].shape[0] == np.array(mic_red.attrs["patterns"])[1, 0].shape[0] ############################################################################### diff --git a/examples/mixed_source_space_connectivity.py b/examples/mixed_source_space_connectivity.py index 263ae7f2..0dcf4880 100644 --- a/examples/mixed_source_space_connectivity.py +++ b/examples/mixed_source_space_connectivity.py @@ -28,44 +28,52 @@ # Set directories data_path = sample.data_path() -subject = 'sample' -data_dir = op.join(data_path, 'MEG', subject) -subjects_dir = op.join(data_path, 'subjects') -bem_dir = op.join(subjects_dir, subject, 'bem') +subject = "sample" +data_dir = op.join(data_path, "MEG", subject) +subjects_dir = op.join(data_path, "subjects") +bem_dir = op.join(subjects_dir, subject, "bem") # Set file names -fname_aseg = op.join(subjects_dir, subject, 'mri', 'aseg.mgz') +fname_aseg = op.join(subjects_dir, subject, "mri", "aseg.mgz") -fname_model = op.join(bem_dir, '%s-5120-bem.fif' % subject) -fname_bem = op.join(bem_dir, '%s-5120-bem-sol.fif' % subject) +fname_model = op.join(bem_dir, "%s-5120-bem.fif" % subject) +fname_bem = op.join(bem_dir, "%s-5120-bem-sol.fif" % subject) -fname_raw = op.join(data_dir, 'sample_audvis_filt-0-40_raw.fif') -fname_trans = op.join(data_dir, 'sample_audvis_raw-trans.fif') -fname_cov = op.join(data_dir, 'ernoise-cov.fif') -fname_event = op.join(data_dir, 'sample_audvis_filt-0-40_raw-eve.fif') +fname_raw = op.join(data_dir, "sample_audvis_filt-0-40_raw.fif") +fname_trans = op.join(data_dir, "sample_audvis_raw-trans.fif") +fname_cov = op.join(data_dir, "ernoise-cov.fif") +fname_event = op.join(data_dir, "sample_audvis_filt-0-40_raw-eve.fif") # List of sub structures we are interested in. We select only the # sub structures we want to include in the source space -labels_vol = ['Left-Amygdala', - 'Left-Thalamus-Proper', - 'Left-Cerebellum-Cortex', - 'Brain-Stem', - 'Right-Amygdala', - 'Right-Thalamus-Proper', - 'Right-Cerebellum-Cortex'] +labels_vol = [ + "Left-Amygdala", + "Left-Thalamus-Proper", + "Left-Cerebellum-Cortex", + "Brain-Stem", + "Right-Amygdala", + "Right-Thalamus-Proper", + "Right-Cerebellum-Cortex", +] # Setup a surface-based source space, oct5 is not very dense (just used # to speed up this example; we recommend oct6 in actual analyses) -src = setup_source_space(subject, subjects_dir=subjects_dir, - spacing='oct5', add_dist=False) +src = setup_source_space( + subject, subjects_dir=subjects_dir, spacing="oct5", add_dist=False +) # Setup a volume source space # set pos=10.0 for speed, not very accurate; we recommend something smaller # like 5.0 in actual analyses: vol_src = setup_volume_source_space( - subject, mri=fname_aseg, pos=10.0, bem=fname_model, + subject, + mri=fname_aseg, + pos=10.0, + bem=fname_model, add_interpolator=False, # just for speed, usually use True - volume_label=labels_vol, subjects_dir=subjects_dir) + volume_label=labels_vol, + subjects_dir=subjects_dir, +) # Generate the mixed source space src += vol_src @@ -76,52 +84,61 @@ noise_cov = mne.read_cov(fname_cov) # compute the fwd matrix -fwd = make_forward_solution(raw.info, fname_trans, src, fname_bem, - mindist=5.0) # ignore sources<=5mm from innerskull +fwd = make_forward_solution( + raw.info, fname_trans, src, fname_bem, mindist=5.0 +) # ignore sources<=5mm from innerskull del src # Define epochs for left-auditory condition event_id, tmin, tmax = 1, -0.2, 0.5 reject = dict(mag=4e-12, grad=4000e-13, eog=150e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, - reject=reject, preload=False) +epochs = mne.Epochs(raw, events, event_id, tmin, tmax, reject=reject, preload=False) del raw # Compute inverse solution and for each epoch -snr = 1.0 # use smaller SNR for raw data -inv_method = 'dSPM' -parc = 'aparc' # the parcellation to use, e.g., 'aparc' 'aparc.a2009s' +snr = 1.0 # use smaller SNR for raw data +inv_method = "dSPM" +parc = "aparc" # the parcellation to use, e.g., 'aparc' 'aparc.a2009s' -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # Compute inverse operator inverse_operator = make_inverse_operator( - epochs.info, fwd, noise_cov, depth=None, fixed=False) + epochs.info, fwd, noise_cov, depth=None, fixed=False +) del fwd -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, inv_method, - pick_ori=None, return_generator=True) +stcs = apply_inverse_epochs( + epochs, inverse_operator, lambda2, inv_method, pick_ori=None, return_generator=True +) # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi -labels_parc = mne.read_labels_from_annot(subject, parc=parc, - subjects_dir=subjects_dir) +labels_parc = mne.read_labels_from_annot(subject, parc=parc, subjects_dir=subjects_dir) # Average the source estimates within each label of the cortical parcellation # and each sub-structure contained in the source space. # When mode = 'mean_flip', this option is used only for the cortical labels. -src = inverse_operator['src'] +src = inverse_operator["src"] label_ts = mne.extract_label_time_course( - stcs, labels_parc, src, mode='mean_flip', allow_empty=True, - return_generator=True) + stcs, labels_parc, src, mode="mean_flip", allow_empty=True, return_generator=True +) # We compute the connectivity in the alpha band and plot it using a circular # graph layout -fmin = 8. -fmax = 13. -sfreq = epochs.info['sfreq'] # the sampling frequency +fmin = 8.0 +fmax = 13.0 +sfreq = epochs.info["sfreq"] # the sampling frequency con = spectral_connectivity_epochs( - label_ts, method='pli', mode='multitaper', sfreq=sfreq, fmin=fmin, - fmax=fmax, faverage=True, mt_adaptive=True, n_jobs=1) + label_ts, + method="pli", + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + mt_adaptive=True, + n_jobs=1, +) # We create a list of Label containing also the sub structures labels_aseg = mne.get_volume_labels_from_src(src, subject, subjects_dir) @@ -132,8 +149,8 @@ # We reorder the labels based on their location in the left hemi label_names = [label.name for label in labels] -lh_labels = [name for name in label_names if name.endswith('lh')] -rh_labels = [name for name in label_names if name.endswith('rh')] +lh_labels = [name for name in label_names if name.endswith("lh")] +rh_labels = [name for name in label_names if name.endswith("rh")] # Get the y-location of the label label_ypos_lh = list() @@ -142,12 +159,12 @@ ypos = np.mean(labels[idx].pos[:, 1]) label_ypos_lh.append(ypos) try: - idx = label_names.index('Brain-Stem') + idx = label_names.index("Brain-Stem") except ValueError: pass else: ypos = np.mean(labels[idx].pos[:, 1]) - lh_labels.append('Brain-Stem') + lh_labels.append("Brain-Stem") label_ypos_lh.append(ypos) @@ -155,25 +172,33 @@ lh_labels = [label for (yp, label) in sorted(zip(label_ypos_lh, lh_labels))] # For the right hemi -rh_labels = [label[:-2] + 'rh' for label in lh_labels - if label != 'Brain-Stem' and label[:-2] + 'rh' in rh_labels] +rh_labels = [ + label[:-2] + "rh" + for label in lh_labels + if label != "Brain-Stem" and label[:-2] + "rh" in rh_labels +] # Save the plot order node_order = lh_labels[::-1] + rh_labels -node_angles = circular_layout(label_names, node_order, start_pos=90, - group_boundaries=[0, len(label_names) // 2]) +node_angles = circular_layout( + label_names, node_order, start_pos=90, group_boundaries=[0, len(label_names) // 2] +) # Plot the graph using node colors from the FreeSurfer parcellation. We only # show the 300 strongest connections. -conmat = con.get_data(output='dense')[:, :, 0] -fig, ax = plt.subplots(figsize=(8, 8), facecolor='black', - subplot_kw=dict(polar=True)) -plot_connectivity_circle(conmat, label_names, n_lines=300, - node_angles=node_angles, node_colors=node_colors, - title='All-to-All Connectivity left-Auditory ' - 'Condition (PLI)', ax=ax) +conmat = con.get_data(output="dense")[:, :, 0] +fig, ax = plt.subplots(figsize=(8, 8), facecolor="black", subplot_kw=dict(polar=True)) +plot_connectivity_circle( + conmat, + label_names, + n_lines=300, + node_angles=node_angles, + node_colors=node_colors, + title="All-to-All Connectivity left-Auditory " "Condition (PLI)", + ax=ax, +) fig.tight_layout() ############################################################################### diff --git a/examples/mne_inverse_coherence_epochs.py b/examples/mne_inverse_coherence_epochs.py index 65b775d4..0a03c862 100644 --- a/examples/mne_inverse_coherence_epochs.py +++ b/examples/mne_inverse_coherence_epochs.py @@ -15,8 +15,7 @@ import mne from mne.datasets import sample -from mne.minimum_norm import (apply_inverse, apply_inverse_epochs, - read_inverse_operator) +from mne.minimum_norm import apply_inverse, apply_inverse_epochs, read_inverse_operator from mne_connectivity import seed_target_indices, spectral_connectivity_epochs print(__doc__) @@ -30,12 +29,12 @@ # compute the event-related coherence. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -fname_inv = data_path / 'MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw.fif' -fname_event = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' -label_name_lh = 'Aud-lh' -fname_label_lh = data_path / f'MEG/sample/labels/{label_name_lh}.label' +subjects_dir = data_path / "subjects" +fname_inv = data_path / "MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = data_path / "MEG/sample/sample_audvis_filt-0-40_raw.fif" +fname_event = data_path / "MEG/sample/sample_audvis_filt-0-40_raw-eve.fif" +label_name_lh = "Aud-lh" +fname_label_lh = data_path / f"MEG/sample/labels/{label_name_lh}.label" event_id, tmin, tmax = 1, -0.2, 0.5 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) @@ -47,16 +46,24 @@ events = mne.read_events(fname_event) # Add a bad channel. -raw.info['bads'] += ['MEG 2443'] +raw.info["bads"] += ["MEG 2443"] # pick MEG channels. -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, exclude="bads" +) # Read epochs. -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), - reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) ############################################################################### # Choose channels for coherence estimation @@ -67,16 +74,15 @@ # connectivity computation. snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 evoked = epochs.average() -stc = apply_inverse(evoked, inverse_operator, lambda2, method, - pick_ori="normal") +stc = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori="normal") # Restrict the source estimate to the label in the left auditory cortex. stc_label = stc.in_label(label_lh) # Find number and index of vertex with most power. -src_pow = np.sum(stc_label.data ** 2, axis=1) +src_pow = np.sum(stc_label.data**2, axis=1) seed_vertno = stc_label.vertices[0][np.argmax(src_pow)] seed_idx = np.searchsorted(stc.vertices[0], seed_vertno) # index in orig stc @@ -90,9 +96,10 @@ # compute the coherence without having to keep all source estimates in memory. snr = 1.0 # use lower SNR for single epochs -lambda2 = 1.0 / snr ** 2 -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, - pick_ori="normal", return_generator=True) +lambda2 = 1.0 / snr**2 +stcs = apply_inverse_epochs( + epochs, inverse_operator, lambda2, method, pick_ori="normal", return_generator=True +) ############################################################################### # Compute the coherence between sources @@ -107,18 +114,26 @@ # By using faverage=True, we directly average the coherence in the alpha and # beta band, i.e., we will only get 2 frequency bins. -fmin = (8., 13.) -fmax = (13., 30.) -sfreq = raw.info['sfreq'] # the sampling frequency +fmin = (8.0, 13.0) +fmax = (13.0, 30.0) +sfreq = raw.info["sfreq"] # the sampling frequency coh = spectral_connectivity_epochs( - stcs, method='coh', mode='fourier', indices=indices, - sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True, n_jobs=1) + stcs, + method="coh", + mode="fourier", + indices=indices, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + n_jobs=1, +) freqs = coh.freqs -print('Frequencies in Hz over which coherence was averaged for alpha: ') +print("Frequencies in Hz over which coherence was averaged for alpha: ") print(freqs[0]) -print('Frequencies in Hz over which coherence was averaged for beta: ') +print("Frequencies in Hz over which coherence was averaged for beta: ") print(freqs[1]) ############################################################################### @@ -136,12 +151,20 @@ tmin = np.mean(freqs[0]) tstep = np.mean(freqs[1]) - tmin coh_stc = mne.SourceEstimate( - coh.get_data(), vertices=stc.vertices, tmin=1e-3 * tmin, - tstep=1e-3 * tstep, subject='sample') + coh.get_data(), + vertices=stc.vertices, + tmin=1e-3 * tmin, + tstep=1e-3 * tstep, + subject="sample", +) # Now we can visualize the coherence using the plot method. -brain = coh_stc.plot('sample', 'inflated', 'both', - time_label='Coherence %0.1f Hz', - subjects_dir=subjects_dir, - clim=dict(kind='value', lims=(0.25, 0.4, 0.65))) -brain.show_view('lateral') +brain = coh_stc.plot( + "sample", + "inflated", + "both", + time_label="Coherence %0.1f Hz", + subjects_dir=subjects_dir, + clim=dict(kind="value", lims=(0.25, 0.4, 0.65)), +) +brain.show_view("lateral") diff --git a/examples/mne_inverse_connectivity_spectrum.py b/examples/mne_inverse_connectivity_spectrum.py index a25636bf..7321f032 100644 --- a/examples/mne_inverse_connectivity_spectrum.py +++ b/examples/mne_inverse_connectivity_spectrum.py @@ -21,10 +21,10 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -fname_inv = data_path / 'MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw.fif' -fname_event = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' +subjects_dir = data_path / "subjects" +fname_inv = data_path / "MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = data_path / "MEG/sample/sample_audvis_filt-0-40_raw.fif" +fname_event = data_path / "MEG/sample/sample_audvis_filt-0-40_raw-eve.fif" # Load data inverse_operator = read_inverse_operator(fname_inv) @@ -32,46 +32,64 @@ events = mne.read_events(fname_event) # Add a bad channel -raw.info['bads'] += ['MEG 2443'] +raw.info["bads"] += ["MEG 2443"] # Pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, exclude="bads" +) # Define epochs for left-auditory condition event_id, tmin, tmax = 1, -0.2, 0.5 -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, - eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # Compute inverse solution and for each epoch. By using "return_generator=True" # stcs will be a generator object instead of a list. snr = 1.0 # use lower SNR for single epochs -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, - pick_ori="normal", return_generator=True) +stcs = apply_inverse_epochs( + epochs, inverse_operator, lambda2, method, pick_ori="normal", return_generator=True +) # Read some labels -names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh'] -labels = [mne.read_label(data_path / f'MEG/sample/labels/{name}.label') - for name in names] +names = ["Aud-lh", "Aud-rh", "Vis-lh", "Vis-rh"] +labels = [ + mne.read_label(data_path / f"MEG/sample/labels/{name}.label") for name in names +] # Average the source estimates within each label using sign-flips to reduce # signal cancellations, also here we return a generator -src = inverse_operator['src'] -label_ts = mne.extract_label_time_course(stcs, labels, src, mode='mean_flip', - return_generator=True) +src = inverse_operator["src"] +label_ts = mne.extract_label_time_course( + stcs, labels, src, mode="mean_flip", return_generator=True +) -fmin, fmax = 7.5, 40. -sfreq = raw.info['sfreq'] # the sampling frequency +fmin, fmax = 7.5, 40.0 +sfreq = raw.info["sfreq"] # the sampling frequency con = spectral_connectivity_epochs( - label_ts, method='wpli2_debiased', mode='multitaper', sfreq=sfreq, - fmin=fmin, fmax=fmax, mt_adaptive=True, n_jobs=1) + label_ts, + method="wpli2_debiased", + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + mt_adaptive=True, + n_jobs=1, +) freqs = con.freqs -n_rows, n_cols = con.get_data(output='dense').shape[:2] +n_rows, n_cols = con.get_data(output="dense").shape[:2] fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True) for i in range(n_rows): for j in range(i + 1): @@ -79,8 +97,8 @@ axes[i, j].set_axis_off() continue - axes[i, j].plot(freqs, con.get_data(output='dense')[i, j, :]) - axes[j, i].plot(freqs, con.get_data(output='dense')[i, j, :]) + axes[i, j].plot(freqs, con.get_data(output="dense")[i, j, :]) + axes[j, i].plot(freqs, con.get_data(output="dense")[i, j, :]) if j == 0: axes[i, j].set_ylabel(names[i]) @@ -92,7 +110,7 @@ # Show band limits for f in [8, 12, 18, 35]: - axes[i, j].axvline(f, color='k') - axes[j, i].axvline(f, color='k') + axes[i, j].axvline(f, color="k") + axes[j, i].axvline(f, color="k") plt.tight_layout() plt.show() diff --git a/examples/mne_inverse_envelope_correlation.py b/examples/mne_inverse_envelope_correlation.py index b05ef625..17454683 100644 --- a/examples/mne_inverse_envelope_correlation.py +++ b/examples/mne_inverse_envelope_correlation.py @@ -50,23 +50,24 @@ from mne.preprocessing import compute_proj_ecg, compute_proj_eog data_path = mne.datasets.brainstorm.bst_resting.data_path() -subjects_dir = op.join(data_path, 'subjects') -subject = 'bst_resting' -trans = op.join(data_path, 'MEG', 'bst_resting', 'bst_resting-trans.fif') -src = op.join(subjects_dir, subject, 'bem', subject + '-oct-6-src.fif') -bem = op.join(subjects_dir, subject, 'bem', subject + '-5120-bem-sol.fif') -raw_fname = op.join(data_path, 'MEG', 'bst_resting', - 'subj002_spontaneous_20111102_01_AUX.ds') +subjects_dir = op.join(data_path, "subjects") +subject = "bst_resting" +trans = op.join(data_path, "MEG", "bst_resting", "bst_resting-trans.fif") +src = op.join(subjects_dir, subject, "bem", subject + "-oct-6-src.fif") +bem = op.join(subjects_dir, subject, "bem", subject + "-5120-bem-sol.fif") +raw_fname = op.join( + data_path, "MEG", "bst_resting", "subj002_spontaneous_20111102_01_AUX.ds" +) ############################################################################## # Here we do some things in the name of speed, such as crop (which will # hurt SNR) and downsample. Then we compute SSP projectors and apply them. -raw = mne.io.read_raw_ctf(raw_fname, verbose='error') +raw = mne.io.read_raw_ctf(raw_fname, verbose="error") raw.crop(0, 60).pick_types(meg=True, eeg=False).load_data().resample(80) raw.apply_gradient_compensation(3) projs_ecg, _ = compute_proj_ecg(raw, n_grad=1, n_mag=2) -projs_eog, _ = compute_proj_eog(raw, n_grad=1, n_mag=2, ch_name='MLT31-4407') +projs_eog, _ = compute_proj_eog(raw, n_grad=1, n_mag=2, ch_name="MLT31-4407") raw.add_proj(projs_ecg + projs_eog) raw.apply_proj() raw.filter(0.1, None) # this helps with symmetric orthogonalization later @@ -85,12 +86,13 @@ ############################################################################## # Now we create epochs and prepare to band-pass filter them. -duration = 10. +duration = 10.0 events = mne.make_fixed_length_events(raw, duration=duration) -tmax = duration - 1. / raw.info['sfreq'] -epochs = mne.Epochs(raw, events=events, tmin=0, tmax=tmax, - baseline=None, reject=dict(mag=20e-13)) -sfreq = epochs.info['sfreq'] +tmax = duration - 1.0 / raw.info["sfreq"] +epochs = mne.Epochs( + raw, events=events, tmin=0, tmax=tmax, baseline=None, reject=dict(mag=20e-13) +) +sfreq = epochs.info["sfreq"] del raw, projs_ecg, projs_eog # %% @@ -99,12 +101,13 @@ # sphinx_gallery_thumbnail_number = 2 -labels = mne.read_labels_from_annot(subject, 'aparc_sub', - subjects_dir=subjects_dir) -stcs = apply_inverse_epochs(epochs, inv, lambda2=1. / 9., pick_ori='normal', - return_generator=True) +labels = mne.read_labels_from_annot(subject, "aparc_sub", subjects_dir=subjects_dir) +stcs = apply_inverse_epochs( + epochs, inv, lambda2=1.0 / 9.0, pick_ori="normal", return_generator=True +) label_ts = mne.extract_label_time_course( - stcs, labels, inv['src'], return_generator=False) + stcs, labels, inv["src"], return_generator=False +) del stcs @@ -114,34 +117,40 @@ def bp_gen(label_ts): yield mne.filter.filter_data(ts, sfreq, 14, 30) -corr_obj = envelope_correlation( - bp_gen(label_ts), orthogonalize='pairwise') +corr_obj = envelope_correlation(bp_gen(label_ts), orthogonalize="pairwise") corr = corr_obj.combine() -corr = corr.get_data(output='dense')[:, :, 0] +corr = corr.get_data(output="dense")[:, :, 0] def plot_corr(corr, title): fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True) - ax.imshow(corr, cmap='viridis', clim=np.percentile(corr, [5, 95])) + ax.imshow(corr, cmap="viridis", clim=np.percentile(corr, [5, 95])) fig.suptitle(title) -plot_corr(corr, 'Pairwise') +plot_corr(corr, "Pairwise") def plot_degree(corr, title): threshold_prop = 0.15 # percentage of strongest edges to keep in the graph degree = mne_connectivity.degree(corr, threshold_prop=threshold_prop) stc = mne.labels_to_stc(labels, degree) - stc = stc.in_label(mne.Label(inv['src'][0]['vertno'], hemi='lh') + - mne.Label(inv['src'][1]['vertno'], hemi='rh')) + stc = stc.in_label( + mne.Label(inv["src"][0]["vertno"], hemi="lh") + + mne.Label(inv["src"][1]["vertno"], hemi="rh") + ) return stc.plot( - clim=dict(kind='percent', lims=[75, 85, 95]), colormap='gnuplot', - subjects_dir=subjects_dir, views='dorsal', hemi='both', - smoothing_steps=25, time_label=title) + clim=dict(kind="percent", lims=[75, 85, 95]), + colormap="gnuplot", + subjects_dir=subjects_dir, + views="dorsal", + hemi="both", + smoothing_steps=25, + time_label=title, + ) -brain = plot_degree(corr, 'Beta (pairwise, aparc_sub)') +brain = plot_degree(corr, "Beta (pairwise, aparc_sub)") # %% # Do symmetric-orthogonalized envelope correlation @@ -151,26 +160,28 @@ def plot_degree(corr, title): # relative to one another. ``'aparc_sub'`` has over 400 labels, so here we # use ``'aparc.a2009s'``, which has fewer than 200. -labels = mne.read_labels_from_annot(subject, 'aparc.a2009s', - subjects_dir=subjects_dir) -stcs = apply_inverse_epochs(epochs, inv, lambda2=1. / 9., pick_ori='normal', - return_generator=True) +labels = mne.read_labels_from_annot(subject, "aparc.a2009s", subjects_dir=subjects_dir) +stcs = apply_inverse_epochs( + epochs, inv, lambda2=1.0 / 9.0, pick_ori="normal", return_generator=True +) label_ts = mne.extract_label_time_course( - stcs, labels, inv['src'], return_generator=True) + stcs, labels, inv["src"], return_generator=True +) del stcs, epochs label_ts_orth = mne_connectivity.envelope.symmetric_orth(label_ts) corr_obj = envelope_correlation( # already orthogonalized earlier - bp_gen(label_ts_orth), orthogonalize=False) + bp_gen(label_ts_orth), orthogonalize=False +) # average over epochs, take absolute value, and plot corr = corr_obj.combine() -corr = corr.get_data(output='dense')[:, :, 0] -corr.flat[::corr.shape[0] + 1] = 0 # zero out the diagonal +corr = corr.get_data(output="dense")[:, :, 0] +corr.flat[:: corr.shape[0] + 1] = 0 # zero out the diagonal corr = np.abs(corr) -plot_corr(corr, 'Symmetric') -plot_degree(corr, 'Beta (symmetric, aparc.a2009s)') +plot_corr(corr, "Symmetric") +plot_degree(corr, "Beta (symmetric, aparc.a2009s)") # %% # References # ---------- diff --git a/examples/mne_inverse_envelope_correlation_volume.py b/examples/mne_inverse_envelope_correlation_volume.py index 7f81d141..87093e05 100644 --- a/examples/mne_inverse_envelope_correlation_volume.py +++ b/examples/mne_inverse_envelope_correlation_volume.py @@ -23,23 +23,24 @@ from mne_connectivity import envelope_correlation data_path = mne.datasets.brainstorm.bst_resting.data_path() -subjects_dir = op.join(data_path, 'subjects') -subject = 'bst_resting' -trans = op.join(data_path, 'MEG', 'bst_resting', 'bst_resting-trans.fif') -bem = op.join(subjects_dir, subject, 'bem', subject + '-5120-bem-sol.fif') -raw_fname = op.join(data_path, 'MEG', 'bst_resting', - 'subj002_spontaneous_20111102_01_AUX.ds') -crop_to = 60. +subjects_dir = op.join(data_path, "subjects") +subject = "bst_resting" +trans = op.join(data_path, "MEG", "bst_resting", "bst_resting-trans.fif") +bem = op.join(subjects_dir, subject, "bem", subject + "-5120-bem-sol.fif") +raw_fname = op.join( + data_path, "MEG", "bst_resting", "subj002_spontaneous_20111102_01_AUX.ds" +) +crop_to = 60.0 ############################################################################## # Here we do some things in the name of speed, such as crop (which will # hurt SNR) and downsample. Then we compute SSP projectors and apply them. -raw = mne.io.read_raw_ctf(raw_fname, verbose='error') +raw = mne.io.read_raw_ctf(raw_fname, verbose="error") raw.crop(0, crop_to).pick_types(meg=True, eeg=False).load_data().resample(80) raw.apply_gradient_compensation(3) projs_ecg, _ = compute_proj_ecg(raw, n_grad=1, n_mag=2) -projs_eog, _ = compute_proj_eog(raw, n_grad=1, n_mag=2, ch_name='MLT31-4407') +projs_eog, _ = compute_proj_eog(raw, n_grad=1, n_mag=2, ch_name="MLT31-4407") raw.add_proj(projs_ecg + projs_eog) raw.apply_proj() cov = mne.compute_raw_covariance(raw) # compute before band-pass of interest @@ -48,9 +49,16 @@ # Now we band-pass filter our data and create epochs. raw.filter(14, 30) -events = mne.make_fixed_length_events(raw, duration=5.) -epochs = mne.Epochs(raw, events=events, tmin=0, tmax=5., - baseline=None, reject=dict(mag=8e-13), preload=True) +events = mne.make_fixed_length_events(raw, duration=5.0) +epochs = mne.Epochs( + raw, + events=events, + tmin=0, + tmax=5.0, + baseline=None, + reject=dict(mag=8e-13), + preload=True, +) data_cov = mne.compute_covariance(epochs) del raw, projs_ecg, projs_eog @@ -60,12 +68,14 @@ # This source space is really far too coarse, but we do this for speed # considerations here -pos = 15. # 1.5 cm is very broad, done here for speed! -src = mne.setup_volume_source_space('bst_resting', pos, bem=bem, - subjects_dir=subjects_dir, verbose=True) +pos = 15.0 # 1.5 cm is very broad, done here for speed! +src = mne.setup_volume_source_space( + "bst_resting", pos, bem=bem, subjects_dir=subjects_dir, verbose=True +) fwd = mne.make_forward_solution(epochs.info, trans, src, bem) -filters = make_lcmv(epochs.info, fwd, data_cov, 0.05, cov, - pick_ori='max-power', weight_norm='nai') +filters = make_lcmv( + epochs.info, fwd, data_cov, 0.05, cov, pick_ori="max-power", weight_norm="nai" +) del fwd, data_cov, cov ############################################################################## @@ -85,10 +95,14 @@ # ------------------------------ degree = mne_connectivity.degree(corr, 0.15) -stc = mne.VolSourceEstimate(degree, [src[0]['vertno']], 0, 1, 'bst_resting') +stc = mne.VolSourceEstimate(degree, [src[0]["vertno"]], 0, 1, "bst_resting") brain = stc.plot( - src, clim=dict(kind='percent', lims=[75, 85, 95]), colormap='gnuplot', - subjects_dir=subjects_dir, mode='glass_brain') + src, + clim=dict(kind="percent", lims=[75, 85, 95]), + colormap="gnuplot", + subjects_dir=subjects_dir, + mode="glass_brain", +) ############################################################################## # References diff --git a/examples/mne_inverse_label_connectivity.py b/examples/mne_inverse_label_connectivity.py index 00704e8f..ef24b87a 100644 --- a/examples/mne_inverse_label_connectivity.py +++ b/examples/mne_inverse_label_connectivity.py @@ -34,10 +34,10 @@ # the sample MEG data provided with MNE. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -fname_inv = data_path / 'MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw.fif' -fname_event = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' +subjects_dir = data_path / "subjects" +fname_inv = data_path / "MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = data_path / "MEG/sample/sample_audvis_filt-0-40_raw.fif" +fname_event = data_path / "MEG/sample/sample_audvis_filt-0-40_raw-eve.fif" # Load data inverse_operator = read_inverse_operator(fname_inv) @@ -45,17 +45,25 @@ events = mne.read_events(fname_event) # Add a bad channel -raw.info['bads'] += ['MEG 2443'] +raw.info["bads"] += ["MEG 2443"] # Pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, exclude="bads" +) # Define epochs for left-auditory condition event_id, tmin, tmax = 1, -0.2, 0.5 -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, - eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) ############################################################################### # Compute inverse solutions and their connectivity @@ -82,35 +90,44 @@ # Compute inverse solution and for each epoch. By using "return_generator=True" # stcs will be a generator object instead of a list. snr = 1.0 # use lower SNR for single epochs -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, - pick_ori="normal", return_generator=True) +stcs = apply_inverse_epochs( + epochs, inverse_operator, lambda2, method, pick_ori="normal", return_generator=True +) # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi -labels = mne.read_labels_from_annot('sample', parc='aparc', - subjects_dir=subjects_dir) +labels = mne.read_labels_from_annot("sample", parc="aparc", subjects_dir=subjects_dir) label_colors = [label.color for label in labels] # Average the source estimates within each label using sign-flips to reduce # signal cancellations, also here we return a generator -src = inverse_operator['src'] +src = inverse_operator["src"] label_ts = mne.extract_label_time_course( - stcs, labels, src, mode='mean_flip', return_generator=True) + stcs, labels, src, mode="mean_flip", return_generator=True +) -fmin = 8. -fmax = 13. -sfreq = raw.info['sfreq'] # the sampling frequency -con_methods = ['pli', 'wpli2_debiased', 'ciplv'] +fmin = 8.0 +fmax = 13.0 +sfreq = raw.info["sfreq"] # the sampling frequency +con_methods = ["pli", "wpli2_debiased", "ciplv"] con = spectral_connectivity_epochs( - label_ts, method=con_methods, mode='multitaper', sfreq=sfreq, fmin=fmin, - fmax=fmax, faverage=True, mt_adaptive=True, n_jobs=1) + label_ts, + method=con_methods, + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + mt_adaptive=True, + n_jobs=1, +) # con is a 3D array, get the connectivity for the first (and only) freq. band # for each method con_res = dict() for method, c in zip(con_methods, con): - con_res[method] = c.get_data(output='dense')[:, :, 0] + con_res[method] = c.get_data(output="dense")[:, :, 0] ############################################################################### # Make a connectivity plot @@ -121,7 +138,7 @@ # First, we reorder the labels based on their location in the left hemi label_names = [label.name for label in labels] -lh_labels = [name for name in label_names if name.endswith('lh')] +lh_labels = [name for name in label_names if name.endswith("lh")] # Get the y-location of the label label_ypos = list() @@ -134,24 +151,29 @@ lh_labels = [label for (yp, label) in sorted(zip(label_ypos, lh_labels))] # For the right hemi -rh_labels = [label[:-2] + 'rh' for label in lh_labels] +rh_labels = [label[:-2] + "rh" for label in lh_labels] # Save the plot order and create a circular layout node_order = list() node_order.extend(lh_labels[::-1]) # reverse the order node_order.extend(rh_labels) -node_angles = circular_layout(label_names, node_order, start_pos=90, - group_boundaries=[0, len(label_names) / 2]) +node_angles = circular_layout( + label_names, node_order, start_pos=90, group_boundaries=[0, len(label_names) / 2] +) # Plot the graph using node colors from the FreeSurfer parcellation. We only # show the 300 strongest connections. -fig, ax = plt.subplots(figsize=(8, 8), facecolor='black', - subplot_kw=dict(polar=True)) -plot_connectivity_circle(con_res['pli'], label_names, n_lines=300, - node_angles=node_angles, node_colors=label_colors, - title='All-to-All Connectivity left-Auditory ' - 'Condition (PLI)', ax=ax) +fig, ax = plt.subplots(figsize=(8, 8), facecolor="black", subplot_kw=dict(polar=True)) +plot_connectivity_circle( + con_res["pli"], + label_names, + n_lines=300, + node_angles=node_angles, + node_colors=label_colors, + title="All-to-All Connectivity left-Auditory " "Condition (PLI)", + ax=ax, +) fig.tight_layout() ############################################################################### @@ -161,14 +183,22 @@ # We can also assign these connectivity plots to axes in a figure. Below we'll # show the connectivity plot using two different connectivity methods. -fig, axes = plt.subplots(1, 3, figsize=(8, 4), facecolor='black', - subplot_kw=dict(polar=True)) -no_names = [''] * len(label_names) +fig, axes = plt.subplots( + 1, 3, figsize=(8, 4), facecolor="black", subplot_kw=dict(polar=True) +) +no_names = [""] * len(label_names) for ax, method in zip(axes, con_methods): - plot_connectivity_circle(con_res[method], no_names, n_lines=300, - node_angles=node_angles, node_colors=label_colors, - title=method, padding=0, fontsize_colorbar=6, - ax=ax) + plot_connectivity_circle( + con_res[method], + no_names, + n_lines=300, + node_angles=node_angles, + node_colors=label_colors, + title=method, + padding=0, + fontsize_colorbar=6, + ax=ax, + ) ############################################################################### diff --git a/examples/mne_inverse_psi_visual.py b/examples/mne_inverse_psi_visual.py index 8043039f..d69fa8fb 100644 --- a/examples/mne_inverse_psi_visual.py +++ b/examples/mne_inverse_psi_visual.py @@ -29,11 +29,11 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -fname_inv = data_path / 'MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw.fif' -fname_event = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' -fname_label = data_path / 'MEG/sample/labels/Vis-lh.label' +subjects_dir = data_path / "subjects" +fname_inv = data_path / "MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = data_path / "MEG/sample/sample_audvis_filt-0-40_raw.fif" +fname_event = data_path / "MEG/sample/sample_audvis_filt-0-40_raw-eve.fif" +fname_label = data_path / "MEG/sample/labels/Vis-lh.label" event_id, tmin, tmax = 4, -0.2, 0.5 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) @@ -44,71 +44,92 @@ events = mne.read_events(fname_event) # pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, exclude="bads" +) # Read epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, - eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # Compute inverse solution and for each epoch. Note that since we are passing # the output to both extract_label_time_course and the phase_slope_index # functions, we have to use "return_generator=False", since it is only possible # to iterate over generators once. snr = 1.0 # use lower SNR for single epochs -lambda2 = 1.0 / snr ** 2 -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, - pick_ori="normal", return_generator=True) +lambda2 = 1.0 / snr**2 +stcs = apply_inverse_epochs( + epochs, inverse_operator, lambda2, method, pick_ori="normal", return_generator=True +) # Now, we generate seed time series by averaging the activity in the left # visual corex label = mne.read_label(fname_label) -src = inverse_operator['src'] # the source space used -seed_ts = mne.extract_label_time_course(stcs, label, src, mode='mean_flip', - verbose='error') +src = inverse_operator["src"] # the source space used +seed_ts = mne.extract_label_time_course( + stcs, label, src, mode="mean_flip", verbose="error" +) # Combine the seed time course with the source estimates. There will be a total # of 7500 signals: # index 0: time course extracted from label # index 1..7499: dSPM source space time courses -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, - pick_ori="normal", return_generator=True) +stcs = apply_inverse_epochs( + epochs, inverse_operator, lambda2, method, pick_ori="normal", return_generator=True +) comb_ts = list(zip(seed_ts, stcs)) # Construct indices to estimate connectivity between the label time course # and all source space time courses -vertices = [src[i]['vertno'] for i in range(2)] +vertices = [src[i]["vertno"] for i in range(2)] n_signals_tot = 1 + len(vertices[0]) + len(vertices[1]) indices = seed_target_indices([0], np.arange(1, n_signals_tot)) # Compute the PSI in the frequency range 10Hz-20Hz. We exclude the baseline # period from the connectivity estimation. -fmin = 10. -fmax = 20. -tmin_con = 0. -sfreq = epochs.info['sfreq'] # the sampling frequency +fmin = 10.0 +fmax = 20.0 +tmin_con = 0.0 +sfreq = epochs.info["sfreq"] # the sampling frequency psi = phase_slope_index( - comb_ts, mode='multitaper', indices=indices, sfreq=sfreq, - fmin=fmin, fmax=fmax, tmin=tmin_con) + comb_ts, + mode="multitaper", + indices=indices, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + tmin=tmin_con, +) # Generate a SourceEstimate with the PSI. This is simple since we used a single # seed (inspect the indices variable to see how the PSI scores are arranged in # the output) psi_stc = mne.SourceEstimate( - psi.get_data(), vertices=vertices, tmin=0, tstep=1, subject='sample') + psi.get_data(), vertices=vertices, tmin=0, tstep=1, subject="sample" +) # Now we can visualize the PSI using the :meth:`~mne.SourceEstimate.plot` # method. We use a custom colormap to show signed values v_max = np.max(np.abs(psi.get_data())) -brain = psi_stc.plot(surface='inflated', hemi='lh', - time_label='Phase Slope Index (PSI)', - subjects_dir=subjects_dir, - clim=dict(kind='percent', pos_lims=(95, 97.5, 100))) -brain.show_view('medial') -brain.add_label(str(fname_label), color='green', alpha=0.7) +brain = psi_stc.plot( + surface="inflated", + hemi="lh", + time_label="Phase Slope Index (PSI)", + subjects_dir=subjects_dir, + clim=dict(kind="percent", pos_lims=(95, 97.5, 100)), +) +brain.show_view("medial") +brain.add_label(str(fname_label), color="green", alpha=0.7) ############################################################################### # References diff --git a/examples/sensor_connectivity.py b/examples/sensor_connectivity.py index f49ea27c..934181a9 100644 --- a/examples/sensor_connectivity.py +++ b/examples/sensor_connectivity.py @@ -24,38 +24,52 @@ ############################################################################### # Set parameters data_path = sample.data_path() -raw_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_filt-0-40_raw.fif') -event_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_filt-0-40_raw-eve.fif') +raw_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif") +event_fname = op.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif") # Setup for reading the raw data raw = mne.io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) # Add a bad channel -raw.info['bads'] += ['MEG 2443'] +raw.info["bads"] += ["MEG 2443"] # Pick MEG gradiometers -picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads" +) # Create epochs for the visual condition event_id, tmin, tmax = 3, -0.2, 1.5 # need a long enough epoch for 5 cycles -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), +) # Compute connectivity for band containing the evoked response. # We exclude the baseline period: -fmin, fmax = 4., 9. -sfreq = raw.info['sfreq'] # the sampling frequency +fmin, fmax = 4.0, 9.0 +sfreq = raw.info["sfreq"] # the sampling frequency tmin = 0.0 # exclude the baseline period -epochs.load_data().pick_types(meg='grad') # just keep MEG and no EOG now +epochs.load_data().pick_types(meg="grad") # just keep MEG and no EOG now con = spectral_connectivity_epochs( - epochs, method='pli', mode='multitaper', sfreq=sfreq, fmin=fmin, fmax=fmax, - faverage=True, tmin=tmin, mt_adaptive=False, n_jobs=1) + epochs, + method="pli", + mode="multitaper", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + faverage=True, + tmin=tmin, + mt_adaptive=False, + n_jobs=1, +) # Now, visualize the connectivity in 3D: -plot_sensors_connectivity( - epochs.info, - con.get_data(output='dense')[:, :, 0]) +plot_sensors_connectivity(epochs.info, con.get_data(output="dense")[:, :, 0]) diff --git a/pyproject.toml b/pyproject.toml index f0c5b798..70aad107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,15 +90,16 @@ gui = [ 'vtk', ] style = [ + "pre-commit", 'black', 'codespell', 'isort', 'pydocstyle', 'pydocstyle[toml]', + 'rstcheck', 'ruff', 'toml-sort', 'yamllint', - "pre-commit", ] test = [ 'joblib', @@ -176,6 +177,51 @@ addopts = '--durations 20 --junit-xml=junit-results.xml -v --color=yes' junit_family = 'xunit2' minversion = '6.0' +[tool.rstcheck] +ignore_directives = [ + "autoclass", + "autofunction", + "automodule", + "autosummary", + "bibliography", + "cssclass", + "currentmodule", + "dropdown", + "footbibliography", + "glossary", + "graphviz", + "grid", + "highlight", + "minigallery", + "rst-class", + "tab-set", + "tabularcolumns", + "toctree", +] +ignore_messages = "^.*(Unknown target name|Undefined substitution referenced)[^`]*$" +ignore_roles = [ + "attr", + "class", + "doc", + "eq", + "exc", + "file", + "footcite", + "footcite:t", + "func", + "gh", + "kbd", + "meth", + "mod", + "newcontrib", + "pr", + "py:mod", + "ref", + "samp", + "term", +] +report_level = "WARNING" + [tool.ruff] extend-exclude = [ 'benchmarks', @@ -201,3 +247,8 @@ include = ['mne_connectivity*'] all = true ignore_case = true trailing_comma_inline_array = true + +[tool.codespell] +ignore-words = ".codespellignore" +builtin = "clear,rare,informal,names,usage" +skip = "doc/references.bib" \ No newline at end of file