From 18c9e320e1977ae7c87aa592dea28825d65cef04 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 01:13:30 +0200 Subject: [PATCH 01/17] fix: correct CCA behavior --- docs/source/whats_new.rst | 2 +- moabb/pipelines/classification.py | 85 ++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 16f3d6bfb..4299e4203 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -21,7 +21,7 @@ Enhancements Bugs ~~~~ -- None +- Correct :class:`moabb.pipelines.classification.SSVEP_CCA` behavior (:gh:`XXX`by `Sylvain Chevallier`_) API changes ~~~~~~~~~~~ diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index d9afff3fa..d044ce43f 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -1,11 +1,13 @@ import numpy as np import scipy.linalg as linalg from joblib import Parallel, delayed +from mne import BaseEpochs from pyriemann.estimation import Covariances from pyriemann.utils.covariance import covariances from pyriemann.utils.mean import mean_covariance from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.cross_decomposition import CCA +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_is_fitted from .utils import filterbank @@ -44,57 +46,102 @@ class SSVEP_CCA(BaseEstimator, ClassifierMixin): https://doi.org/10.1088/1741-2560/6/4/046002 """ - def __init__(self, interval, freqs, n_harmonics=3): + def __init__(self, n_harmonics=3): self.Yf = dict() self.cca = CCA(n_components=1) - self.interval = interval - self.slen = interval[1] - interval[0] - self.freqs = freqs + # self.interval = interval + # self.slen = interval[1] - interval[0] + # self.freqs = freqs self.n_harmonics = n_harmonics self.classes_ = [] - self.one_hot = {} - for i, k in enumerate(freqs.keys()): - self.classes_.append(i) - self.one_hot[k] = i + self.one_hot_ = {} + self.le_ = self.slen_ = self.freqs_ = None + # self.one_hot = {} + # for i, k in enumerate(freqs.keys()): + # self.classes_.append(i) + # self.one_hot[k] = i def fit(self, X, y, sample_weight=None): """Compute reference sinusoid signal. These sinusoid are generated for each frequency in the dataset - """ - n_times = X.shape[2] - for f in self.freqs: + Parameters + ---------- + X : MNE Epochs + The training data as MNE Epochs object. + y : unused, only for compatibility with scikit-learn + + Returns + ------- + self: SSVEP_CCA object + Instance of classifier. + """ + if not isinstance(X, BaseEpochs): + raise ValueError("X should be an MNE Epochs object.") + + self.slen_ = X.times[-1] - X.times[0] + n_times = len(X.times) + self.freqs_ = list(X.event_id.keys()) + self.le_ = LabelEncoder().fit(self.freqs_) + # self.le_.fit(self.freqs_) + self.classes_ = self.le_.transform(self.freqs_) + for i, k in zip(self.freqs_, self.le_.transform(self.freqs_)): + self.one_hot_[i] = k + + for f in self.freqs_: if f.replace(".", "", 1).isnumeric(): freq = float(f) yf = [] for h in range(1, self.n_harmonics + 1): yf.append( - np.sin(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times)) + np.sin(2 * np.pi * freq * h * np.linspace(0, self.slen_, n_times)) ) yf.append( - np.cos(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times)) + np.cos(2 * np.pi * freq * h * np.linspace(0, self.slen_, n_times)) ) self.Yf[f] = np.array(yf) return self def predict(self, X): - """Predict is made by taking the maximum correlation coefficient.""" + """Predict is made by taking the maximum correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + y : list of int + Predicted labels. + """ y = [] for x in X: corr_f = {} - for f in self.freqs: + for f in self.freqs_: if f.replace(".", "", 1).isnumeric(): S_x, S_y = self.cca.fit_transform(x.T, self.Yf[f].T) corr_f[f] = np.corrcoef(S_x.T, S_y.T)[0, 1] - y.append(self.one_hot[max(corr_f, key=corr_f.get)]) + y.append(self.one_hot_[max(corr_f, key=corr_f.get)]) return y def predict_proba(self, X): - """Probability could be computed from the correlation coefficient.""" - P = np.zeros(shape=(len(X), len(self.freqs))) + """Probability could be computed from the correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + proba : ndarray of shape (n_trials, n_classes) + probability of each class for each trial. + """ + P = np.zeros(shape=(len(X), len(self.freqs_))) for i, x in enumerate(X): - for j, f in enumerate(self.freqs): + for j, f in enumerate(self.freqs_): if f.replace(".", "", 1).isnumeric(): S_x, S_y = self.cca.fit_transform(x.T, self.Yf[f].T) P[i, j] = np.corrcoef(S_x.T, S_y.T)[0, 1] From cfcc6586bc1bed91ca2cc110b7c488785bd5ba7f Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 01:55:04 +0200 Subject: [PATCH 02/17] fix: correct yml CCA pipeline, use epochs in benchmarks for this ppl --- moabb/benchmark.py | 11 ++++++++++- pipelines/CCA-SSVEP.yml | 2 -- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/moabb/benchmark.py b/moabb/benchmark.py index f2493cfa3..a64688988 100644 --- a/moabb/benchmark.py +++ b/moabb/benchmark.py @@ -31,6 +31,15 @@ log = logging.getLogger(__name__) +def _ppl_needs_epochs(pn): + """Check if the pipeline needs MNE epochs as input.""" + ppl_with_epochs = ["braindecode", "Keras", "SSVEP CCA"] + if any(s in pn for s in ppl_with_epochs): + return True + else: + return False + + def benchmark( # noqa: C901 pipelines="./pipelines/", evaluations=None, @@ -165,7 +174,7 @@ def benchmark( # noqa: C901 ppl_with_epochs, ppl_with_array = {}, {} for pn, pv in prdgms[paradigm].items(): - if "braindecode" in pn or "Keras" in pn: + if _ppl_needs_epochs(pn): ppl_with_epochs[pn] = pv else: ppl_with_array[pn] = pv diff --git a/pipelines/CCA-SSVEP.yml b/pipelines/CCA-SSVEP.yml index bae157483..87ed8f9a1 100644 --- a/pipelines/CCA-SSVEP.yml +++ b/pipelines/CCA-SSVEP.yml @@ -8,5 +8,3 @@ pipeline: from: moabb.pipelines.classification parameters: n_harmonics: 3 - interval: [2, 4] - freqs: {"13": 2, "17":3, "21":4} From 4da8d437f09472aefc2d30dac8b180aefd5844ea Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 01:57:41 +0200 Subject: [PATCH 03/17] enh: update docstrings --- moabb/pipelines/classification.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index d044ce43f..d079b4d0d 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -23,18 +23,15 @@ class SSVEP_CCA(BaseEstimator, ClassifierMixin): Parameters ---------- - interval : list of length 2 - List of form [tmin, tmax]. With tmin and tmax as defined in the SSVEP - paradigm :meth:`moabb.paradigms.SSVEP` - - freqs : dict with n_classes keys - Frequencies corresponding to the SSVEP stimulation frequencies. - They are used to identify SSVEP classes presents in the data. - - n_harmonics: int + n_harmonics: int, default=3 Number of stimulation frequency's harmonics to be used in the generation of the CCA reference signal. + Attributes + ---------- + classes_: list of int + List of unique classes present in the training data. + References ---------- @@ -44,22 +41,20 @@ class SSVEP_CCA(BaseEstimator, ClassifierMixin): canonical correlation analysis method. Journal of neural engineering, 6(4), 046002. https://doi.org/10.1088/1741-2560/6/4/046002 + + Notes + ----- + .. versionadded:: 1.1.0 + Use MNE Epochs object as input data instead of numpy array. """ def __init__(self, n_harmonics=3): self.Yf = dict() self.cca = CCA(n_components=1) - # self.interval = interval - # self.slen = interval[1] - interval[0] - # self.freqs = freqs self.n_harmonics = n_harmonics self.classes_ = [] self.one_hot_ = {} self.le_ = self.slen_ = self.freqs_ = None - # self.one_hot = {} - # for i, k in enumerate(freqs.keys()): - # self.classes_.append(i) - # self.one_hot[k] = i def fit(self, X, y, sample_weight=None): """Compute reference sinusoid signal. @@ -149,8 +144,7 @@ def predict_proba(self, X): class SSVEP_TRCA(BaseEstimator, ClassifierMixin): - """Classifier based on the Task-Related Component Analysis method [1]_ for - SSVEP. + """Task-Related Component Analysis method [1]_ for SSVEP. Parameters ---------- @@ -394,6 +388,7 @@ def fit(self, X, y): # Get shape of X and labels n_trials, n_channels, n_samples = X.shape + # self.sfreq_ = X.info["sfreq"] self.sfreq = int(n_samples / self.slen) self.sfreq = self.sfreq / self.downsample From 2c0db527195ff4ae84648d519095d980fbdc1bad Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 11:09:31 +0200 Subject: [PATCH 04/17] fix: use epochs for trca --- moabb/pipelines/classification.py | 133 +++++++++++++++++------------- 1 file changed, 75 insertions(+), 58 deletions(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index d079b4d0d..3f61cadf2 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -79,7 +79,6 @@ def fit(self, X, y, sample_weight=None): n_times = len(X.times) self.freqs_ = list(X.event_id.keys()) self.le_ = LabelEncoder().fit(self.freqs_) - # self.le_.fit(self.freqs_) self.classes_ = self.le_.transform(self.freqs_) for i, k in zip(self.freqs_, self.le_.transform(self.freqs_)): self.one_hot_[i] = k @@ -224,23 +223,24 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): def __init__( self, - interval, - freqs, - downsample=1, is_ensemble=True, method="original", estimator="scm", ): - self.freqs = freqs - self.peaks = np.array([float(f) for f in freqs.keys()]) - self.n_fbands = len(self.peaks) - self.downsample = downsample - self.interval = interval - self.slen = interval[1] - interval[0] + # self.freqs = freqs + # self.peaks = np.array([float(f) for f in freqs.keys()]) + # self.n_fbands = len(self.peaks) + # self.downsample = downsample + # self.interval = interval + # self.slen = interval[1] - interval[0] self.is_ensemble = is_ensemble - self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] + # self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] self.estimator = estimator self.method = method + self.fb_coefs, self.one_hot_ = [], {} + self.sfreq_, self.freqs_, self.peaks_, self.n_fbands = None, None, None, None + self.le_, self.classes_, self.n_classes = None, None, None + self.templates_, self.weights_ = None, None def _Q_S_estim(self, data): # Check if X is a single trial (test data) or not @@ -383,34 +383,43 @@ def fit(self, X, y): self: CCA object Instance of classifier. """ - # Downsample data - X = X[:, :, :: self.downsample] - - # Get shape of X and labels - n_trials, n_channels, n_samples = X.shape - # self.sfreq_ = X.info["sfreq"] - - self.sfreq = int(n_samples / self.slen) - self.sfreq = self.sfreq / self.downsample + if not isinstance(X, BaseEpochs): + raise ValueError("X should be an MNE Epochs object.") - self.classes_ = np.unique(y) + # n_trials = len(X) + n_channels, n_samples = X.info["nchan"], len(X.times) + self.sfreq_ = X.info["sfreq"] + self.freqs_ = list(X.event_id.keys()) + self.peaks_ = np.array([float(f) for f in self.freqs_]) + self.n_fbands = len(self.peaks_) + self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] + self.le_ = LabelEncoder().fit(self.freqs_) + self.classes_ = self.le_.transform(self.freqs_) self.n_classes = len(self.classes_) + for i, k in zip(self.freqs_, self.classes_): + self.one_hot_[i] = k # Initialize the final arrays self.templates_ = np.zeros((self.n_classes, self.n_fbands, n_channels, n_samples)) self.weights_ = np.zeros((self.n_fbands, self.n_classes, n_channels)) - for class_idx in self.classes_: - X_cal = X[y == class_idx] # Select data with a specific label + # for class_idx in self.classes_: + for freq, k in self.one_hot_.items(): + # X_cal = X[y == class_idx] # Select data with a specific label + X_cal = X[freq] # Select data with a specific label # Filterbank approach for band_n in range(self.n_fbands): # Filter the data and compute TRCA - X_filter = filterbank(X_cal, self.sfreq, band_n, self.peaks) + X_filter = filterbank( + X_cal.get_data(copy=False), self.sfreq_, band_n, self.peaks_ + ) w_best, _ = self._compute_trca(X_filter) # Get template by averaging trials and take the best filter for this band - self.templates_[class_idx, band_n, :, :] = np.mean(X_filter, axis=0) - self.weights_[band_n, class_idx, :] = w_best + # self.templates_[class_idx, band_n, :, :] = np.mean(X_filter, axis=0) + # self.weights_[band_n, class_idx, :] = w_best + self.templates_[k, band_n, :, :] = np.mean(X_filter, axis=0) + self.weights_[band_n, k, :] = w_best return self @@ -440,34 +449,38 @@ def predict(self, X): # Check is fit had been called check_is_fitted(self) - # Check if X is a single trial or not - if X.ndim == 2: - X = X[np.newaxis, ...] - - # Downsample data - X = X[:, :, :: self.downsample] + # # Check if X is a single trial or not + # if X.ndim == 2: + # X = X[np.newaxis, ...] + # + # # Downsample data + # X = X[:, :, :: self.downsample] # Get test data shape - n_trials, _, _ = X.shape + # n_trials = len(X) + # n_trials, _, _ = X.shape # Initialize pred array y_pred = [] - for trial_n in range(n_trials): + # for trial_n in range(n_trials): + for X_test in X: # Pick trial - X_test = X[trial_n, :, :] + # X_test = X[trial_n, :, :] # Initialize correlations array corr_array = np.zeros((self.n_fbands, self.n_classes)) # Filter the data in the corresponding band for band_n in range(self.n_fbands): - X_filter = filterbank(X_test, self.sfreq, band_n, self.peaks) + X_filter = filterbank(X_test, self.sfreq_, band_n, self.peaks_) # Compute correlation with all the templates and bands - for class_idx in range(self.n_classes): + # for class_idx in range(self.n_classes): + for freq, k in self.one_hot_.items(): # Get the corresponding template - template = np.squeeze(self.templates_[class_idx, band_n, :, :]) + # template = np.squeeze(self.templates_[class_idx, band_n, :, :]) + template = np.squeeze(self.templates_[k, band_n, :, :]) if self.is_ensemble: w = np.squeeze( @@ -475,7 +488,8 @@ def predict(self, X): ).T # (n_classes, n_channel) else: w = np.squeeze( - self.weights_[band_n, class_idx, :] + # self.weights_[band_n, class_idx, :] + self.weights_[band_n, k, :] ).T # (n_channel,) # Compute 2D correlation of spatially filtered testdata with ref @@ -483,14 +497,14 @@ def predict(self, X): np.dot(X_filter.T, w).flatten(), np.dot(template.T, w).flatten(), ) - corr_array[band_n, class_idx] = r[0, 1] + corr_array[band_n, k] = r[0, 1] # Fusion for the filterbank analysis rho = np.dot(self.fb_coefs, corr_array) - # Select the maximal value and append to preddictions + # Select the maximal value and append to predictions tau = np.argmax(rho) - y_pred.append(tau) + y_pred.append(self.one_hot_[self.freqs_[tau]]) return y_pred @@ -521,49 +535,52 @@ def predict_proba(self, X): check_is_fitted(self) # Check if X is a single trial or not - if X.ndim == 2: - X = X[np.newaxis, ...] - - # Downsample data - X = X[:, :, :: self.downsample] + # if X.ndim == 2: + # X = X[np.newaxis, ...] + # + # # Downsample data + # X = X[:, :, :: self.downsample] # Get test data shape - n_trials, _, _ = X.shape + # n_trials, _, _ = X.shape + n_trials = len(X) # Initialize pred array - y_pred = np.zeros((n_trials, len(self.peaks))) + y_pred = np.zeros((n_trials, self.n_classes)) - for trial_n in range(n_trials): + # for trial_n in range(n_trials): + for trial_n, X_test in enumerate(X): # Pick trial - X_test = X[trial_n, :, :] + # X_test = X[trial_n, :, :] # Initialize correlations array corr_array = np.zeros((self.n_fbands, self.n_classes)) # Filter the data in the corresponding band for band_n in range(self.n_fbands): - X_filter = filterbank(X_test, self.sfreq, band_n, self.peaks) + X_filter = filterbank(X_test, self.sfreq_, band_n, self.peaks_) # Compute correlation with all the templates and bands - for class_idx in range(self.n_classes): + # for class_idx in range(self.n_classes): + for freq, k in self.one_hot_.items(): # Get the corresponding template - template = np.squeeze(self.templates_[class_idx, band_n, :, :]) + # template = np.squeeze(self.templates_[class_idx, band_n, :, :]) + template = np.squeeze(self.templates_[k, band_n, :, :]) if self.is_ensemble: w = np.squeeze( self.weights_[band_n, :, :] ).T # (n_class, n_channel) else: - w = np.squeeze( - self.weights_[band_n, class_idx, :] - ).T # (n_channel,) + w = np.squeeze(self.weights_[band_n, k, :]).T # (n_channel,) # Compute 2D correlation of spatially filtered testdata with ref r = np.corrcoef( np.dot(X_filter.T, w).flatten(), np.dot(template.T, w).flatten(), ) - corr_array[band_n, class_idx] = r[0, 1] + # corr_array[band_n, class_idx] = r[0, 1] + corr_array[band_n, k] = r[0, 1] normalized_coefs = self.fb_coefs / (np.sum(self.fb_coefs)) # Fusion for the filterbank analysis From efb693b16ae117d3648b20050516b120193ee127 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 17:33:12 +0200 Subject: [PATCH 05/17] fix: correct TRCA labels --- moabb/pipelines/classification.py | 19 ++++++++++++------- moabb/pipelines/utils.py | 22 ++++++++++++++++------ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index 3f61cadf2..aa8d2e6b0 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -223,6 +223,7 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): def __init__( self, + n_fbands=5, is_ensemble=True, method="original", estimator="scm", @@ -234,11 +235,14 @@ def __init__( # self.interval = interval # self.slen = interval[1] - interval[0] self.is_ensemble = is_ensemble - # self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] self.estimator = estimator self.method = method - self.fb_coefs, self.one_hot_ = [], {} - self.sfreq_, self.freqs_, self.peaks_, self.n_fbands = None, None, None, None + self.n_fbands = n_fbands + # self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] + self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] + # self.fb_coefs = [] + self.one_hot_, self.one_inv_ = {}, {} + self.sfreq_, self.freqs_, self.peaks_ = None, None, None self.le_, self.classes_, self.n_classes = None, None, None self.templates_, self.weights_ = None, None @@ -391,13 +395,14 @@ def fit(self, X, y): self.sfreq_ = X.info["sfreq"] self.freqs_ = list(X.event_id.keys()) self.peaks_ = np.array([float(f) for f in self.freqs_]) - self.n_fbands = len(self.peaks_) + # self.n_fbands = len(self.peaks_) self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] self.le_ = LabelEncoder().fit(self.freqs_) self.classes_ = self.le_.transform(self.freqs_) self.n_classes = len(self.classes_) for i, k in zip(self.freqs_, self.classes_): self.one_hot_[i] = k + self.one_inv_[k] = i # Initialize the final arrays self.templates_ = np.zeros((self.n_classes, self.n_fbands, n_channels, n_samples)) @@ -500,11 +505,11 @@ def predict(self, X): corr_array[band_n, k] = r[0, 1] # Fusion for the filterbank analysis - rho = np.dot(self.fb_coefs, corr_array) + self.rho = np.dot(self.fb_coefs, corr_array) # Select the maximal value and append to predictions - tau = np.argmax(rho) - y_pred.append(self.one_hot_[self.freqs_[tau]]) + self.tau = np.argmax(self.rho) + y_pred.append(self.one_hot_[self.one_inv_[self.tau]]) return y_pred diff --git a/moabb/pipelines/utils.py b/moabb/pipelines/utils.py index ecb2b41fc..75f573550 100644 --- a/moabb/pipelines/utils.py +++ b/moabb/pipelines/utils.py @@ -316,27 +316,37 @@ def filterbank(X, sfreq, idx_fb, peaks): sfreq = sfreq / 2 peaks = np.sort(peaks) + min_freq = np.min(peaks) max_freq = np.max(peaks) if max_freq < 40: - top = 40 + top = 100 else: - top = 60 + top = 115 # Check for Nyquist if top >= sfreq: top = sfreq - 10 # Lowcut frequencies for the pass band (depends on the frequencies of SSVEP) # No more than 3dB loss in the passband - passband = [peaks[i] - 1 for i in range(len(peaks))] + diff = max_freq - min_freq + passband = [min_freq - 2 + x * diff for x in range(7)] + # passband = [peaks[i] - 1 for i in range(len(peaks))] # At least 40db attenuation in the stopband - stopband = [peaks[i] - 2 for i in range(len(peaks))] + if min_freq - 4 > 0: + stopband = [ + min_freq - 4 + x * (diff - 2) if x < 3 else min_freq - 4 + x * diff + for x in range(7) + ] + else: + stopband = [2 + x * (diff - 2) if x < 3 else 2 + x * diff for x in range(7)] + # stopband = [peaks[i] - 2 for i in range(len(peaks))] Wp = [passband[idx_fb] / sfreq, top / sfreq] - Ws = [stopband[idx_fb] / sfreq, (top + 20) / sfreq] + Ws = [stopband[idx_fb] / sfreq, (top + 7) / sfreq] - N, Wn = scp.cheb1ord(Wp, Ws, 3, 15) # Chebyshev type I filter order selection. + N, Wn = scp.cheb1ord(Wp, Ws, 3, 40) # Chebyshev type I filter order selection. B, A = scp.cheby1(N, 0.5, Wn, btype="bandpass") # Chebyshev type I filter design From 6f5e1135f010a14bf68b1ed18592b1af4f63fb62 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 17:48:35 +0200 Subject: [PATCH 06/17] enh: clean trca code, correct docstrings --- moabb/pipelines/classification.py | 82 ++++++------------------------- 1 file changed, 16 insertions(+), 66 deletions(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index aa8d2e6b0..257d03e8a 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -44,8 +44,8 @@ class SSVEP_CCA(BaseEstimator, ClassifierMixin): Notes ----- - .. versionadded:: 1.1.0 - Use MNE Epochs object as input data instead of numpy array. + .. versionchanged:: 1.1.0 + Use MNE Epochs object as input data instead of numpy array, fix label encoding. """ def __init__(self, n_harmonics=3): @@ -147,17 +147,9 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): Parameters ---------- - sfreq : float - Sampling frequency of the data to be analyzed. - - freqs : dict with n_classes keys - Frequencies corresponding to the SSVEP components. These are - necessary to design the filterbank bands. - - downsample: int, default=1 - Factor by which downsample the data. A downsample value of N will result - on a sampling frequency of (sfreq // N) by taking one sample every N of - the original data. In the original TRCA paper [1]_ data are at 250Hz. + n_fbands: int, default=5 + Number of sub-bands to divide the SSVEP frequencies, with filterbank + approach. is_ensemble: bool, default=False If True, predict on new data using the Ensemble-TRCA method described @@ -182,8 +174,6 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): is used. So method='original' and regul='scm' is similar to original implementation. - - Attributes ---------- fb_coefs : list of len (n_fbands) @@ -204,8 +194,11 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): Weight coefficients for the different electrodes which are used as spatial filters for the data. + freqs_: list of str + List of unique frequencies present in the training data. + Reference - ---------- + --------- .. [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao, and T.-P. Jung, "Enhancing detection of SSVEPs for a high-speed brain speller using @@ -219,6 +212,9 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): Notes ----- .. versionadded:: 0.4.4 + + .. versionchanged:: 1.1.1 + TRCA implementation works with MNE Epochs object, fix labels encoding issue. """ def __init__( @@ -228,19 +224,11 @@ def __init__( method="original", estimator="scm", ): - # self.freqs = freqs - # self.peaks = np.array([float(f) for f in freqs.keys()]) - # self.n_fbands = len(self.peaks) - # self.downsample = downsample - # self.interval = interval - # self.slen = interval[1] - interval[0] self.is_ensemble = is_ensemble self.estimator = estimator self.method = method self.n_fbands = n_fbands - # self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] - # self.fb_coefs = [] self.one_hot_, self.one_inv_ = {}, {} self.sfreq_, self.freqs_, self.peaks_ = None, None, None self.le_, self.classes_, self.n_classes = None, None, None @@ -374,7 +362,7 @@ def fit(self, X, y): Parameters ---------- - X : ndarray of shape (n_trials, n_channels, n_samples) + X : MNE Epochs Training data. Trials are grouped by class, divided in n_fbands bands by the filterbank approach and then used to calculate weight vectors and templates for each class and band. @@ -390,12 +378,10 @@ def fit(self, X, y): if not isinstance(X, BaseEpochs): raise ValueError("X should be an MNE Epochs object.") - # n_trials = len(X) n_channels, n_samples = X.info["nchan"], len(X.times) self.sfreq_ = X.info["sfreq"] self.freqs_ = list(X.event_id.keys()) self.peaks_ = np.array([float(f) for f in self.freqs_]) - # self.n_fbands = len(self.peaks_) self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] self.le_ = LabelEncoder().fit(self.freqs_) self.classes_ = self.le_.transform(self.freqs_) @@ -410,8 +396,8 @@ def fit(self, X, y): # for class_idx in self.classes_: for freq, k in self.one_hot_.items(): - # X_cal = X[y == class_idx] # Select data with a specific label X_cal = X[freq] # Select data with a specific label + # Filterbank approach for band_n in range(self.n_fbands): # Filter the data and compute TRCA @@ -421,8 +407,6 @@ def fit(self, X, y): w_best, _ = self._compute_trca(X_filter) # Get template by averaging trials and take the best filter for this band - # self.templates_[class_idx, band_n, :, :] = np.mean(X_filter, axis=0) - # self.weights_[band_n, class_idx, :] = w_best self.templates_[k, band_n, :, :] = np.mean(X_filter, axis=0) self.weights_[band_n, k, :] = w_best @@ -437,8 +421,8 @@ def predict(self, X): Parameters ---------- - X : ndarray of shape (n_trials, n_channels, n_samples) - Testing data. This will be divided in self.n_fbands using the filter- bank approach, + X : MNE Epochs + Testing data. This will be divided in self.n_fbands using the filterbank approach, then it will be transformed by the different spatial filters and compared to the previously fit templates according to the selected method for analysis (ensemble or not). Finally, correlation scores for all sub-bands of each class will be combined, @@ -454,25 +438,10 @@ def predict(self, X): # Check is fit had been called check_is_fitted(self) - # # Check if X is a single trial or not - # if X.ndim == 2: - # X = X[np.newaxis, ...] - # - # # Downsample data - # X = X[:, :, :: self.downsample] - - # Get test data shape - # n_trials = len(X) - # n_trials, _, _ = X.shape - # Initialize pred array y_pred = [] - # for trial_n in range(n_trials): for X_test in X: - # Pick trial - # X_test = X[trial_n, :, :] - # Initialize correlations array corr_array = np.zeros((self.n_fbands, self.n_classes)) @@ -481,10 +450,8 @@ def predict(self, X): X_filter = filterbank(X_test, self.sfreq_, band_n, self.peaks_) # Compute correlation with all the templates and bands - # for class_idx in range(self.n_classes): for freq, k in self.one_hot_.items(): # Get the corresponding template - # template = np.squeeze(self.templates_[class_idx, band_n, :, :]) template = np.squeeze(self.templates_[k, band_n, :, :]) if self.is_ensemble: @@ -538,26 +505,12 @@ def predict_proba(self, X): # Check is fit had been called check_is_fitted(self) - - # Check if X is a single trial or not - # if X.ndim == 2: - # X = X[np.newaxis, ...] - # - # # Downsample data - # X = X[:, :, :: self.downsample] - - # Get test data shape - # n_trials, _, _ = X.shape n_trials = len(X) # Initialize pred array y_pred = np.zeros((n_trials, self.n_classes)) - # for trial_n in range(n_trials): for trial_n, X_test in enumerate(X): - # Pick trial - # X_test = X[trial_n, :, :] - # Initialize correlations array corr_array = np.zeros((self.n_fbands, self.n_classes)) @@ -566,10 +519,8 @@ def predict_proba(self, X): X_filter = filterbank(X_test, self.sfreq_, band_n, self.peaks_) # Compute correlation with all the templates and bands - # for class_idx in range(self.n_classes): for freq, k in self.one_hot_.items(): # Get the corresponding template - # template = np.squeeze(self.templates_[class_idx, band_n, :, :]) template = np.squeeze(self.templates_[k, band_n, :, :]) if self.is_ensemble: @@ -584,7 +535,6 @@ def predict_proba(self, X): np.dot(X_filter.T, w).flatten(), np.dot(template.T, w).flatten(), ) - # corr_array[band_n, class_idx] = r[0, 1] corr_array[band_n, k] = r[0, 1] normalized_coefs = self.fb_coefs / (np.sum(self.fb_coefs)) From 511dfe106033b1ae87f53a8e99be11d9165f1a37 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 17:57:23 +0200 Subject: [PATCH 07/17] fix: correct TRCA pipeline --- docs/source/whats_new.rst | 2 +- pipelines/TRCA-SSVEP.yml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 4299e4203..49f6652d5 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -21,7 +21,7 @@ Enhancements Bugs ~~~~ -- Correct :class:`moabb.pipelines.classification.SSVEP_CCA` behavior (:gh:`XXX`by `Sylvain Chevallier`_) +- Correct :class:`moabb.pipelines.classification.SSVEP_CCA` and :class:`moabb.pipelines.classification.SSVEP_TRCA` behavior (:gh:`XXX`by `Sylvain Chevallier`_) API changes ~~~~~~~~~~~ diff --git a/pipelines/TRCA-SSVEP.yml b/pipelines/TRCA-SSVEP.yml index 5592b0798..eb55bb7fc 100644 --- a/pipelines/TRCA-SSVEP.yml +++ b/pipelines/TRCA-SSVEP.yml @@ -9,5 +9,7 @@ pipeline: - name: SSVEP_TRCA from: moabb.pipelines.classification parameters: - interval: [2, 4] + n_fbands: 5 + is_ensemble: True + method: "riemann" freqs: {"13":2, "17":3, "21":4} From 3096b364868408795fee421c6f4c3fb09026336e Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 17:59:49 +0200 Subject: [PATCH 08/17] fix: add trca as an epoch pipeline for benchmark --- moabb/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moabb/benchmark.py b/moabb/benchmark.py index a64688988..cf19c83f3 100644 --- a/moabb/benchmark.py +++ b/moabb/benchmark.py @@ -33,7 +33,7 @@ def _ppl_needs_epochs(pn): """Check if the pipeline needs MNE epochs as input.""" - ppl_with_epochs = ["braindecode", "Keras", "SSVEP CCA"] + ppl_with_epochs = ["braindecode", "Keras", "SSVEP CCA", "TRCA-SSVEP"] if any(s in pn for s in ppl_with_epochs): return True else: From 7c1292ccf57c15120d4bc4c6b8bc66e3ac3cd461 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Mon, 10 Jun 2024 18:06:15 +0200 Subject: [PATCH 09/17] fix: correct trca pipeline --- pipelines/TRCA-SSVEP.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/pipelines/TRCA-SSVEP.yml b/pipelines/TRCA-SSVEP.yml index eb55bb7fc..29a1871a3 100644 --- a/pipelines/TRCA-SSVEP.yml +++ b/pipelines/TRCA-SSVEP.yml @@ -12,4 +12,3 @@ pipeline: n_fbands: 5 is_ensemble: True method: "riemann" - freqs: {"13":2, "17":3, "21":4} From 8a3d263ec70961cd4c1f4c2e031bfb2c4a35aad6 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Tue, 11 Jun 2024 08:48:14 +0200 Subject: [PATCH 10/17] enh: add ssvep context for resample and no rest --- contexts/ssvep_kalunga_norest.yml | 6 ++++++ contexts/ssvep_resample.yml | 2 ++ 2 files changed, 8 insertions(+) create mode 100644 contexts/ssvep_kalunga_norest.yml create mode 100644 contexts/ssvep_resample.yml diff --git a/contexts/ssvep_kalunga_norest.yml b/contexts/ssvep_kalunga_norest.yml new file mode 100644 index 000000000..fbbb03b75 --- /dev/null +++ b/contexts/ssvep_kalunga_norest.yml @@ -0,0 +1,6 @@ +SSVEP: + events: + - "13" + - "17" + - "21" + n_classes: 3 diff --git a/contexts/ssvep_resample.yml b/contexts/ssvep_resample.yml new file mode 100644 index 000000000..5152b898b --- /dev/null +++ b/contexts/ssvep_resample.yml @@ -0,0 +1,2 @@ +SSVEP: + resample: 250.0 From e78e890a0fec8295d67c565970e688620cc4d31c Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Tue, 11 Jun 2024 10:46:03 +0200 Subject: [PATCH 11/17] fix: correct Mset-CCA pipelines labels, use MNE Epochs, update benchmark --- moabb/benchmark.py | 2 +- moabb/pipelines/classification.py | 86 ++++++++++++++++++++++++------- pipelines/MsetCCA-SSVEP.yml | 4 +- 3 files changed, 72 insertions(+), 20 deletions(-) diff --git a/moabb/benchmark.py b/moabb/benchmark.py index cf19c83f3..ddfd3b264 100644 --- a/moabb/benchmark.py +++ b/moabb/benchmark.py @@ -33,7 +33,7 @@ def _ppl_needs_epochs(pn): """Check if the pipeline needs MNE epochs as input.""" - ppl_with_epochs = ["braindecode", "Keras", "SSVEP CCA", "TRCA-SSVEP"] + ppl_with_epochs = ["braindecode", "Keras", "SSVEP CCA", "TRCA-SSVEP", "MsetCCA-SSVEP"] if any(s in pn for s in ppl_with_epochs): return True else: diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index 257d03e8a..abf7a59d1 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -389,6 +389,10 @@ def fit(self, X, y): for i, k in zip(self.freqs_, self.classes_): self.one_hot_[i] = k self.one_inv_[k] = i + if self.n_fbands > len(self.peaks_): + raise ValueError( + "Number of filterbank bands should be less or equal to the number of peaks." + ) # Initialize the final arrays self.templates_ = np.zeros((self.n_classes, self.n_fbands, n_channels, n_samples)) @@ -400,6 +404,7 @@ def fit(self, X, y): # Filterbank approach for band_n in range(self.n_fbands): + print(self.sfreq_, band_n, self.peaks_) # Filter the data and compute TRCA X_filter = filterbank( X_cal.get_data(copy=False), self.sfreq_, band_n, self.peaks_ @@ -576,15 +581,21 @@ class SSVEP_MsetCCA(BaseEstimator, ClassifierMixin): Parameters ---------- - freqs : dict with n_classes keys - Frequencies corresponding to the SSVEP stimulation frequencies. - They are used to identify SSVEP classes presents in the data. - - n_filters: int + n_filters: int, default=1 Number of multisets spatial filters used per sample data. It corresponds to the number of eigen vectors taken the solution of the MAXVAR objective function as formulated in Eq.5 in [1]_. + n_jobs: int, default=1 + Number of jobs to run whitening in parallel. + + Attributes + ---------- + classes_ : ndarray of shape (n_classes,) + Array with the class labels extracted at fit time. + + freqs_: list of str + List of unique frequencies present in the training data. References ---------- @@ -599,20 +610,38 @@ class SSVEP_MsetCCA(BaseEstimator, ClassifierMixin): .. versionadded:: 0.5.0 """ - def __init__(self, freqs, n_filters=1, n_jobs=1): + def __init__(self, n_filters=1, n_jobs=1): self.n_jobs = n_jobs self.n_filters = n_filters - self.freqs = freqs self.cca = CCA(n_components=1) + self.freqs_, self.le_, self.classes_ = None, None, None + self.one_hot_, self.Ym = {}, {} def fit(self, X, y, sample_weight=None): - """Compute the optimized reference signal at each stimulus - frequency.""" - self.classes_ = np.unique(y) - self.one_hot = {} - for i, k in enumerate(self.classes_): - self.one_hot[k] = i - n_trials, n_channels, n_times = X.shape + """Compute the optimized reference signal at each stimulus frequency. + + Parameters + ---------- + X : MNE Epochs + The training data as MNE Epochs object. + + y : np.ndarray of shape (n_trials,) + The target labels for each trial. + + Returns + ------- + self: SSVEP_MsetCCA object + Instance of classifier. + """ + if not isinstance(X, BaseEpochs): + raise ValueError("X should be an MNE Epochs object.") + + self.freqs_ = list(X.event_id.keys()) + self.le_ = LabelEncoder().fit(self.freqs_) + self.classes_ = self.le_.transform(self.freqs_) + for i, k in zip(self.freqs_, self.le_.transform(self.freqs_)): + self.one_hot_[i] = k + n_trials, n_channels, n_times = len(X), X.info["nchan"], len(X.times) # Whiten signal in parallel if self.n_jobs == 1: @@ -644,14 +673,24 @@ def fit(self, X, y, sample_weight=None): Z = W.transpose((0, 2, 1)) @ X_white # Get Ym - self.Ym = dict() for m_class in self.classes_: self.Ym[m_class] = Z[y == m_class].transpose(2, 0, 1).reshape(-1, n_times) return self def predict(self, X): - """Predict is made by taking the maximum correlation coefficient.""" + """Predict is made by taking the maximum correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + y : list of int + Predicted labels. + """ # Check is fit had been called check_is_fitted(self) @@ -662,11 +701,22 @@ def predict(self, X): for f in self.classes_: S_x, S_y = self.cca.fit_transform(x.T, self.Ym[f].T) corr_f[f] = np.corrcoef(S_x.T, S_y.T)[0, 1] - y.append(self.one_hot[max(corr_f, key=corr_f.get)]) + y.append(max(corr_f, key=corr_f.get)) return y def predict_proba(self, X): - """Probability could be computed from the correlation coefficient.""" + """Probability could be computed from the correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + P : ndarray of shape (n_trials, n_classes) + Probability of each class for each trial. + """ # Check is fit had been called check_is_fitted(self) diff --git a/pipelines/MsetCCA-SSVEP.yml b/pipelines/MsetCCA-SSVEP.yml index cc4c8dc45..c49db2bf7 100644 --- a/pipelines/MsetCCA-SSVEP.yml +++ b/pipelines/MsetCCA-SSVEP.yml @@ -1,4 +1,5 @@ name: MsetCCA-SSVEP + paradigms: - SSVEP @@ -9,4 +10,5 @@ pipeline: - name: SSVEP_MsetCCA from: moabb.pipelines.classification parameters: - freqs: {"13":2, "17":3, "21":4} + n_filters: 1 + n_jobs: 1 From ea0ed1fdb1a9066bbb91ef34ea79510be350d538 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Wed, 12 Jun 2024 01:40:01 +0200 Subject: [PATCH 12/17] fix: remove debug print --- moabb/pipelines/classification.py | 1 - 1 file changed, 1 deletion(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index abf7a59d1..db04fcebd 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -404,7 +404,6 @@ def fit(self, X, y): # Filterbank approach for band_n in range(self.n_fbands): - print(self.sfreq_, band_n, self.peaks_) # Filter the data and compute TRCA X_filter = filterbank( X_cal.get_data(copy=False), self.sfreq_, band_n, self.peaks_ From 8055961c0e4c2fc5b47008192e4a94a2435fe405 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Wed, 12 Jun 2024 01:46:30 +0200 Subject: [PATCH 13/17] fix: change error to warning --- moabb/pipelines/classification.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index db04fcebd..1413f7c29 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import scipy.linalg as linalg from joblib import Parallel, delayed @@ -13,6 +15,9 @@ from .utils import filterbank +log = logging.getLogger(__name__) + + class SSVEP_CCA(BaseEstimator, ClassifierMixin): """Classifier based on Canonical Correlation Analysis for SSVEP. @@ -390,9 +395,7 @@ def fit(self, X, y): self.one_hot_[i] = k self.one_inv_[k] = i if self.n_fbands > len(self.peaks_): - raise ValueError( - "Number of filterbank bands should be less or equal to the number of peaks." - ) + log.warning("Try with lower n_fbands if there is an error.") # Initialize the final arrays self.templates_ = np.zeros((self.n_classes, self.n_fbands, n_channels, n_samples)) From 7d90eee75c9aa9fc5ddcf5dda4ef48cd151a5839 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 13 Jun 2024 09:56:19 +0200 Subject: [PATCH 14/17] fix: correct TRCA method --- pipelines/TRCA-SSVEP.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/TRCA-SSVEP.yml b/pipelines/TRCA-SSVEP.yml index 29a1871a3..07938ac94 100644 --- a/pipelines/TRCA-SSVEP.yml +++ b/pipelines/TRCA-SSVEP.yml @@ -11,4 +11,4 @@ pipeline: parameters: n_fbands: 5 is_ensemble: True - method: "riemann" + method: "original" From cc92f618b0170437b6bfa3b572c096da5a212b98 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Fri, 14 Jun 2024 01:55:23 +0200 Subject: [PATCH 15/17] enh: correct MsetCCA tests, add TRCA and CCA tests --- moabb/pipelines/classification.py | 9 ++- moabb/tests/classification.py | 112 ++++++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index 1413f7c29..59133f95a 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -59,7 +59,7 @@ def __init__(self, n_harmonics=3): self.n_harmonics = n_harmonics self.classes_ = [] self.one_hot_ = {} - self.le_ = self.slen_ = self.freqs_ = None + self.le_, self.slen_, self.freqs_ = None, None, [] def fit(self, X, y, sample_weight=None): """Compute reference sinusoid signal. @@ -70,7 +70,10 @@ def fit(self, X, y, sample_weight=None): ---------- X : MNE Epochs The training data as MNE Epochs object. - y : unused, only for compatibility with scikit-learn + y : Unused, + Only for compatibility with scikit-learn + sample_weight : Unused, + Only for compatibility with scikit-learn Returns ------- @@ -616,7 +619,7 @@ def __init__(self, n_filters=1, n_jobs=1): self.n_jobs = n_jobs self.n_filters = n_filters self.cca = CCA(n_components=1) - self.freqs_, self.le_, self.classes_ = None, None, None + self.freqs_, self.le_, self.classes_ = [], None, None self.one_hot_, self.Ym = {}, {} def fit(self, X, y, sample_weight=None): diff --git a/moabb/tests/classification.py b/moabb/tests/classification.py index 6e39dc229..3fc2fb7e5 100644 --- a/moabb/tests/classification.py +++ b/moabb/tests/classification.py @@ -4,7 +4,105 @@ from moabb.datasets.fake import FakeDataset from moabb.paradigms import SSVEP -from moabb.pipelines import SSVEP_MsetCCA +from moabb.pipelines import SSVEP_CCA, SSVEP_TRCA, SSVEP_MsetCCA + + +class TestSSVEP_CCA(unittest.TestCase): + def setUp(self): + # Use moabb generated dataset for test + dataset = FakeDataset(n_sessions=1, n_runs=1, n_subjects=1, paradigm="ssvep") + paradigm = SSVEP(n_classes=3) + X, y, _ = paradigm.get_data(dataset) + self.freqs = paradigm.used_events(dataset) + self.n_harmonics = 3 + self.X = X + self.y = y + self.clf = SSVEP_CCA(n_harmonics=self.n_harmonics) + + def test_fit(self): + self.clf.fit(self.X, self.y) + self.assertTrue(hasattr(self.clf, "freqs_")) + self.assertTrue(hasattr(self.clf, "classes_")) + self.assertTrue(hasattr(self.clf, "le_")) + self.assertTrue(hasattr(self.clf, "one_hot_")) + self.assertTrue(hasattr(self.clf, "slen_")) + + def test_predict(self): + self.clf.fit(self.X, self.y) + y_pred = self.clf.predict(self.X) + self.assertEqual(len(y_pred), len(self.X)) + + def test_predict_proba(self): + self.clf.fit(self.X, self.y) + P = self.clf.predict_proba(self.X) + self.assertEqual(P.shape[0], len(self.X)) + self.assertEqual(P.shape[1], len(self.freqs_)) + + def test_fit_predict_is_fitted(self): + self.assertRaises(NotFittedError, self.clf.predict, self.X) + self.assertRaises(NotFittedError, self.clf.predict_proba, self.X) + self.clf.fit(self.X, self.y) + check_is_fitted( + self.clf, attributes=["classes_", "one_hot_", "slen_", "freqs_", "le_"] + ) + + +class TestSSVEP_TRCA(unittest.TestCase): + def setUp(self): + # Use moabb generated dataset for test + dataset = FakeDataset(n_sessions=1, n_runs=1, n_subjects=1, paradigm="ssvep") + self.n_classes = 3 + paradigm = SSVEP(n_classes=self.n_classes) + X, y, _ = paradigm.get_data(dataset) + self.freqs = paradigm.used_events(dataset) + self.n_fbands = 3 + self.X = X + self.y = y + + def test_fit(self): + for method in ["original", "riemann", "logeuclid"]: + for estimator in ["scm", "lwf", "oas"]: + self.clf = SSVEP_TRCA( + n_fbands=self.n_fbands, method=method, estimator=estimator + ) + self.clf.fit(self.X, self.y) + self.assertTrue(hasattr(self.clf, "freqs_")) + self.assertTrue(hasattr(self.clf, "peaks_")) + self.assertTrue(hasattr(self.clf, "classes_")) + self.assertTrue(hasattr(self.clf, "n_classes")) + self.assertTrue(hasattr(self.clf, "le_")) + self.assertTrue(hasattr(self.clf, "one_hot_")) + self.assertTrue(hasattr(self.clf, "one_inv_")) + self.assertTrue(hasattr(self.clf, "sfreq_")) + + def test_predict(self): + self.clf.fit(self.X, self.y) + y_pred = self.clf.predict(self.X) + self.assertEqual(len(y_pred), len(self.X)) + + def test_predict_proba(self): + self.clf.fit(self.X, self.y) + P = self.clf.predict_proba(self.X) + self.assertEqual(P.shape[0], len(self.X)) + self.assertEqual(P.shape[1], len(self.n_classes)) + + def test_fit_predict_is_fitted(self): + self.assertRaises(NotFittedError, self.clf.predict, self.X) + self.assertRaises(NotFittedError, self.clf.predict_proba, self.X) + self.clf.fit(self.X, self.y) + check_is_fitted( + self.clf, + attributes=[ + "classes_", + "n_classes", + "peaks_", + "one_hot_", + "one_inv_", + "freqs_", + "le_", + "sfreq_", + ], + ) class TestSSVEP_MsetCCA(unittest.TestCase): @@ -17,12 +115,14 @@ def setUp(self): self.n_filters = 2 self.X = X self.y = y - self.clf = SSVEP_MsetCCA(freqs=self.freqs, n_filters=self.n_filters) + self.clf = SSVEP_MsetCCA(n_filters=self.n_filters) def test_fit(self): self.clf.fit(self.X, self.y) + self.assertTrue(hasattr(self.clf, "freqs_")) self.assertTrue(hasattr(self.clf, "classes_")) - self.assertTrue(hasattr(self.clf, "one_hot")) + self.assertTrue(hasattr(self.clf, "le_")) + self.assertTrue(hasattr(self.clf, "one_hot_")) self.assertTrue(hasattr(self.clf, "Ym")) def test_predict(self): @@ -34,13 +134,15 @@ def test_predict_proba(self): self.clf.fit(self.X, self.y) P = self.clf.predict_proba(self.X) self.assertEqual(P.shape[0], len(self.X)) - self.assertEqual(P.shape[1], len(self.freqs)) + self.assertEqual(P.shape[1], len(self.classes_)) def test_fit_predict_is_fitted(self): self.assertRaises(NotFittedError, self.clf.predict, self.X) self.assertRaises(NotFittedError, self.clf.predict_proba, self.X) self.clf.fit(self.X, self.y) - check_is_fitted(self.clf, attributes=["classes_", "one_hot", "Ym"]) + check_is_fitted( + self.clf, attributes=["classes_", "one_hot_", "Ym", "freqs_", "le_"] + ) if __name__ == "__main__": From 195c10f66e7ff36ae6744e033b7afc42851fec96 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Fri, 14 Jun 2024 01:57:35 +0200 Subject: [PATCH 16/17] fix: correct test CCA pipeline --- moabb/tests/test_pipelines/SSVEP_CCA.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/moabb/tests/test_pipelines/SSVEP_CCA.yml b/moabb/tests/test_pipelines/SSVEP_CCA.yml index cdf712018..09736c41b 100644 --- a/moabb/tests/test_pipelines/SSVEP_CCA.yml +++ b/moabb/tests/test_pipelines/SSVEP_CCA.yml @@ -9,4 +9,3 @@ pipeline: parameters: n_harmonics: 3 interval: [1, 3] - freqs: {"13":0, "17":1} From 7faccff23c77fa9e5966c52c5c3927831cd8de76 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Fri, 14 Jun 2024 01:59:10 +0200 Subject: [PATCH 17/17] fix: correct test CCA pipeline --- moabb/tests/test_pipelines/SSVEP_CCA.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/moabb/tests/test_pipelines/SSVEP_CCA.yml b/moabb/tests/test_pipelines/SSVEP_CCA.yml index 09736c41b..9037fe763 100644 --- a/moabb/tests/test_pipelines/SSVEP_CCA.yml +++ b/moabb/tests/test_pipelines/SSVEP_CCA.yml @@ -8,4 +8,3 @@ pipeline: from: moabb.pipelines.classification parameters: n_harmonics: 3 - interval: [1, 3]