Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Multiple improvements to spectral_connectivity_time: ciPLV, and efficient computation of multiple metrics #115

Merged
merged 34 commits into from
Jan 10, 2023

Conversation

ruuskas
Copy link
Contributor

@ruuskas ruuskas commented Nov 25, 2022

PR Description

This PR addresses improvements mentioned in #104. I'm reposting this with the Git log now cleaned up.

  • Add ciPLV metric.
  • Compute all connectivity measures at once, avoiding the repeated computation of the cross-spectral density.
  • Reinstitute the option to specify frequencies of interest also when using the multitaper method. The reasoning is that users should have control over the sparsity of the frequency grid.
  • Add the option to use parts of each epoch as padding. This allows for the mitigation of edge effects in the time-frequency decomposition. The edges are removed after time-frequency decomposition, but before connectivity computation.
  • Improve the function description to better describe the objective.

Merge checklist

Maintainer, please confirm the following before merging:

  • All comments resolved
  • This is not your own PR
  • All CIs are happy
  • PR title starts with [MRG]
  • whats_new.rst is updated
  • PR description includes phrase "closes <#issue-number>"

Add the corrected imaginary Phase-Locking-Value into the list of
available connectivity metrics.
All connectivity measures are now computed with only a single
computation of pairwise cross spectrum.
In some scenarios, users might want to specify the frequencies for
time-frequency decomposition also when using multitapering. These
changes allow users to specify the 'freqs' parameter to override the
automatically determined frequencies.
This adds the option to use the edges of the signal at each epoch as
padding. The purpose of this is to avoid edge effects generated by the
time-frequency transformation methods.
Sym is not a parameter of dpss_windows. (But is one of the underlying
scipy.signal.dpss)
This change will skip the rendering of the connectivity computation
progress bar if the logging level is not DEBUG. This is in line with
MNE-Python, where progress bars are not shown at INFO or higher logging
levels. Rendering the progress bar regardless of logging levels has the
potential to cause unnecessary clutter in users' log files.
Add a better description of the method + style nitpicks.
@ruuskas ruuskas changed the title Spectral time [ENH] Multiple improvements to spectral_connectivity_time Nov 25, 2022
@ruuskas
Copy link
Contributor Author

ruuskas commented Nov 25, 2022

It looks like the CIs are not happy about the docs and loading the Matplotlib Qt backend fails with Python 3.7.

I can't check what's wrong with the docs it seems. Building locally, it complains about 'secondary_sidebar_items': ['page-toc'] added here.

@adam2392 adam2392 self-requested a review November 25, 2022 17:40
if padding:
pad_idx = int(np.floor(padding * sfreq / decim))
out = out[..., pad_idx:-pad_idx]
weights = weights[..., pad_idx:-pad_idx]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind also adding a test for padding?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

Comment on lines 509 to 540
if m == 'coh':
s_xx = psd[w_x]
s_yy = psd[w_y]
coh = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \
np.sqrt(s_xx.mean(axis=-1, keepdims=True) *
s_yy.mean(axis=-1, keepdims=True))
out.append(coh)

if m == 'plv':
dphi_mean = dphi.mean(axis=-1, keepdims=True)
plv = np.abs(dphi_mean)
out.append(plv)

if m == 'ciplv':
rplv = np.abs(np.mean(np.real(dphi), axis=-1, keepdims=True))
iplv = np.abs(np.mean(np.imag(dphi), axis=-1, keepdims=True))
ciplv = iplv / (np.sqrt(1 - rplv ** 2))
out.append(ciplv)

if m == 'pli':
pli = np.abs(np.mean(np.sign(np.imag(s_xy)),
axis=-1, keepdims=True))
out.append(pli)

if m == 'wpli':
con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True))
con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True)
wpli = con_num / con_den
out.append(wpli)

if m == 'cs':
out.append(s_xy)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about keeping the basic pairwise "computation" functions as separate functions, so that way they can be self-documented in their docstring? E.g. keeping all the old functions?

And then to add say a new spectral-connectivity function, one would have to: i) add a new private function e.g. def _newpli(x, y) and then ii) update _parallel_con to add this option.

Might be easier to read. But also not 100% necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an option for sure. However, it should be done such that we avoid computing the cross-spectral density multiple times. This is slow due to the potentially large size of the arrays according to some profiling I did.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will take a look to see if this is possible then.

I wonder if it's possible to keep even the same "spectral-conn" base functions across 'epochs' and 'time' computation. I.e. there are the classes inside spectral_conn_epochs.py which also compute certain spectral connectivity. They seem to all operate on a pairwise nature, so perhaps we can even consolidate into one place, so then adding one connectivity function adds it for both 'epochs' and 'time'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid it wouldn't be practical to use the classes from spectral_connectivity_epochs. Those estimate connectivity as mean over epochs with the assumption that the connectivity measure is stationary within an Epoch. An estimator's accumulate method is therefore called only once per Epoch. Cross-spectral density is computed in a single time window spanning the whole Epoch (with the exception of cwt_morlet mode).

In spectral_connectivity_time the pairwise computation involves summing over the time axis and multitaper actually means convolution with a DPSS-tapered sinusoid (similar to Gaussian tapered sinusoid for morlet. Summing over the time axis using something like the accumulate function in epochs would necessitate a for-loop over time. There's also the possibility of fetching connectivity over time separately for each epoch, which of course doesn't exist in epochs.

In my opinion, accommodating both within the same classes would probably lead to major refactoring of spectral_connectivity_epochs and potentially make the code less readable.

Sorry if the explanation is too basic, I'm mostly writing it down to keep it straight for myself.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about doing it this way?

def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs,
                  verbose, total, faverage, weights):
    """Compute spectral connectivity in parallel.

    Input signal w is of shape (n_chans, n_tapers, n_freqs, n_times)."""

    if 'coh' in method:
        # psd
        if weights is not None:
            psd = weights * w
            psd = psd * np.conj(psd)
            psd = psd.real.sum(axis=1)
            psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0)
        else:
            psd = w.real ** 2 + w.imag ** 2
            psd = np.squeeze(psd, axis=1)

        # smooth
        psd = _smooth_spectra(psd, kernel)

    def pairwise_con(w_x, w_y):
        # csd
        if weights is not None:
            s_xy = np.sum(weights * w[w_x] * np.conj(weights * w[w_y]), axis=0)
            s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0)
        else:
            s_xy = w[w_x] * np.conj(w[w_y])
            s_xy = np.squeeze(s_xy, axis=0)
        s_xy = _smooth_spectra(s_xy, kernel)
        out = []
        conn_func = {'plv': _plv, 'ciplv': _ciplv, 'pli': _pli, 'wpli': _wpli,
                     'coh': _coh, 'cs': _cs}
        for m in method:
            if m == 'coh':
                s_xx = psd[w_x]
                s_yy = psd[w_y]
                out.append(conn_func[m](s_xx, s_yy, s_xy))
            else:
                out.append(conn_func[m](s_xy))

        for i, _ in enumerate(out):
            # mean inside frequency sliding window (if needed)
            if isinstance(foi_idx, np.ndarray) and faverage:
                out[i] = _foi_average(out[i], foi_idx)
            # squeeze time dimension
            out[i] = out[i].squeeze(axis=-1)

        return out

    # only show progress if verbosity level is DEBUG
    if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10:
        total = None

    # define the function to compute in parallel
    parallel, p_fun, n_jobs = parallel_func(
        pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total)

    return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx))


def _plv(s_xy):
    s_xy = s_xy / np.abs(s_xy)
    plv = np.abs(s_xy.mean(axis=-1, keepdims=True))
    return plv


def _ciplv(s_xy):
    s_xy = s_xy / np.abs(s_xy)
    rplv = np.abs(np.mean(np.real(s_xy), axis=-1, keepdims=True))
    iplv = np.abs(np.mean(np.imag(s_xy), axis=-1, keepdims=True))
    ciplv = iplv / (np.sqrt(1 - rplv ** 2))
    return ciplv


def _pli(s_xy):
    pli = np.abs(np.mean(np.sign(np.imag(s_xy)),
                         axis=-1, keepdims=True))
    return pli


def _wpli(s_xy):
    con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True))
    con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True)
    wpli = con_num / con_den
    return wpli


def _coh(s_xx, s_yy, s_xy):
    con_num = np.abs(s_xy.mean(axis=-1, keepdims=True))
    con_den = np.sqrt(s_xx.mean(axis=-1, keepdims=True) *
                      s_yy.mean(axis=-1, keepdims=True))
    coh = con_num / con_den
    return coh


def _cs(s_xy):
    return s_xy.mean(axis=-1, keepdims=True)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that looks better and more similar to the epochs file.

Idr but can we put the pairwise function outside? Nested functions imo are harder to read. But if that breaks the paralleliation via joblib, then nvm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try if the pairwise function works outside _parallel_con. Fully agree that nested functions are ugly.

Comment on lines 605 to 609
def _compute_freqs(n_times, sfreq, freqs, mode):
from scipy.fft import rfftfreq
# get frequencies of interest for the different modes
if freqs is not None:
if any(freqs > (sfreq / 2.)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is defined already in spectral.epochs.py. Can we use that one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function in epochs.py would always return rfftfreq(n_times, 1./sfreq) if multitaper is used. This results in a very dense frequency grid, making computation too slow to be practical for wide frequency bands. Therefore, I would like to allow freqs to be optionally set by the user.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add freqs into the existing function and then use it?

Having one private function makes the code more robust, easier to maintain, and less likely to introduce runtime bugs in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically yes, but then we would have to prevent the use of cwt_freqs together with multitaper in spectral_connectivity_epochs. In that case freqs must be equal to rfftfreq(n_times, 1./sfreq) as that's the shape of the output of _csd_from_mt and _psd_from_mt.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is fine. So, we would check if freqs is not None and mode == 'multitaper' and raise an error.

The function is essentially doing the same thing regardless of which one it's being called from.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be the case only in spectral_connectivity_epochs (where multitaper frequencies must be defined by rfftfreq for the method to work). In spectral_connectivity_time there isn't really much reason to even use rfftfreq to obtain frequencies as it uses a different (time-resolved) multitaper implementation. If we require freqs anyway with morlet, it might make sense to go back to requiring them for multitaper as well given that, as far as I understand, in the time-resolved tfr implementation the difference between morlet and multitaper is only in the type of wavelet used (+ multiple tapers with multitaper).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a solution could be to add the connectivity paradigm (epochs vs time) to the parameters of _compute_freqs. Although the reasoning in my previous comment would still favor requiring freqs in Multitaper mode as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I think requiring freqs could be fine. I'm not 100% sure I'm following wrt all the different possible combinations of parameters here, so I can take a look after you push up the proposed change?

@adam2392
Copy link
Member

It looks like the CIs are not happy about the docs and loading the Matplotlib Qt backend fails with Python 3.7.

I can't check what's wrong with the docs it seems. Building locally, it complains about 'secondary_sidebar_items': ['page-toc'] added here.

Should be fixed now on main.

@adam2392 adam2392 changed the title [ENH] Multiple improvements to spectral_connectivity_time [ENH] Multiple improvements to spectral_connectivity_time: ciPLV, and efficient computation of multiple metrics Dec 1, 2022
Individual functions for each connectivity methods.
This change makes joblib happy to do multithreading.
@adam2392
Copy link
Member

Feel free to ping me whenever you need me to take a look again!

@adam2392
Copy link
Member

adam2392 commented Jan 3, 2023

Happy new years @ruuskas ! Just wanted to check in on this. It would be great to get this in and make a new release w/ all your cool improvements!

@ruuskas
Copy link
Contributor Author

ruuskas commented Jan 4, 2023

Happy new year @adam2392 ! I have been focusing on other things so this has been lagging. I'll try to get the necessary changes done this week (maybe today if there are no blockers).

The user is required to specify the wavelet central frequencies in both
multitaper and cwt_morlet tfr mode. The reasoning is that the underlying
 tfr implementations are very similar. This is in contrast to
 spectral_connectivity_epochs, where multitaper assumes that the
 spectrum is stationary and therefore no wavelets are used.
@ruuskas
Copy link
Contributor Author

ruuskas commented Jan 6, 2023

Hi @adam2392 ! This should now be ready for review again. The freqs parameter is now required and I added a test for padding. It was a good idea since there was a bug hiding indeed.

Sphinx is not happy as there seems to be an issue with indentation in the docstring, which I couldn't figure out without changing config.

@drammock
Copy link
Member

drammock commented Jan 6, 2023

Sphinx is not happy as there seems to be an issue with indentation in the docstring, which I couldn't figure out without changing config.

The sphinx error:

sphinx.errors.ThemeError: An error happened in rendering the page api.
Reason: UndefinedError("'logo' is undefined")

is due to a change in sphinx 6.0 and is currently being addressed upstream in the website theme. Temporary workaround is to pin sphinx to <6.0

@ruuskas
Copy link
Contributor Author

ruuskas commented Jan 6, 2023

The Sphinx error I get locally is different. I can copy it verbatim tomorrow, but essentially "unexpected indentation" in the docstring. Indentation looks fine where it points to though.

Copy link
Member

@adam2392 adam2392 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Some minor changes and clean up and this should be good to go soon.

mne_connectivity/spectral/time.py Outdated Show resolved Hide resolved

Parameters
----------
data : array_like, shape (n_epochs, n_signals, n_times) | Epochs
The data from which to compute connectivity.
freqs : array_like
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be int, or float right? If so, what is the expected behavior? It just computes the connectivity at that frequency?

mne_connectivity/spectral/time.py Outdated Show resolved Hide resolved
mne_connectivity/spectral/time.py Show resolved Hide resolved
mne_connectivity/spectral/tests/test_spectral.py Outdated Show resolved Hide resolved
mne_connectivity/spectral/tests/test_spectral.py Outdated Show resolved Hide resolved
@adam2392
Copy link
Member

adam2392 commented Jan 6, 2023

@ruuskas I fixed the indentation error and other sphinx issue. You have to git pull to sync changes.

@ruuskas
Copy link
Contributor Author

ruuskas commented Jan 9, 2023

Hi @adam2392! I addressed all the things you pointed out and pushed changes. WDYT?

Co-authored-by: Adam Li <[email protected]>
Comment on lines +513 to +514
def _pairwise_con(w, psd, x, y, method, kernel, foi_idx,
faverage, weights):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even tho this is a private function, do you mind adding a basic docstring mainly for internal documentation purposes?

Copy link
Member

@adam2392 adam2392 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor nitpick, and then this LGTM for merging!

@adam2392 adam2392 merged commit ed40aef into mne-tools:main Jan 10, 2023
@adam2392
Copy link
Member

Thanks @ruuskas! Great addition and contribution.

@ruuskas
Copy link
Contributor Author

ruuskas commented Jan 10, 2023

Thanks @adam2392!

tsbinns pushed a commit to tsbinns/mne-connectivity that referenced this pull request Dec 15, 2023
… efficient computation of multiple metrics (mne-tools#115)

* Add ciPLV: Add the corrected imaginary Phase-Locking-Value into the list of
available connectivity metrics.

* Speed up computation: All connectivity measures are now computed with only a single
computation of pairwise cross spectrum.

* Add the option to specify freqs in all modes: In some scenarios, users might want to specify the frequencies for
time-frequency decomposition also when using multitapering. These
changes allow users to specify the 'freqs' parameter to override the
automatically determined frequencies.

* BUG: Average over CSD instead of connectivity

* Add option to use part of signal as padding: This adds the option to use the edges of the signal at each epoch as
padding. The purpose of this is to avoid edge effects generated by the
time-frequency transformation methods.

* Fix test bug, use 'freqs' instead of 'cwt_freqs'

* Fix bug with dpss windows: Sym is not a parameter of dpss_windows. (But is one of the underlying
scipy.signal.dpss)

* Only show progress bar if verbosity level is DEBUG: This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with
MNE-Python, where progress bars are not shown at INFO or higher logging
levels. Rendering the progress bar regardless of logging levels has the
potential to cause unnecessary clutter in users' log files.

* Require freqs in all tfr modes

The user is required to specify the wavelet central frequencies in both
multitaper and cwt_morlet tfr mode. The reasoning is that the underlying
 tfr implementations are very similar. This is in contrast to
 spectral_connectivity_epochs, where multitaper assumes that the
 spectrum is stationary and therefore no wavelets are used.

* Require mne>=1.3

Signed-off-by: Adam Li <[email protected]>
Co-authored-by: Adam Li <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants