Skip to content

Commit

Permalink
FIX: Argh
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jan 17, 2025
1 parent 88f7b63 commit 1918de3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/decoding/linear_model_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

# Extract and plot spatial filters and spatial patterns
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
# We fitted the linear model onto Z-scored data. To make the filters
# We fit the linear model on Z-scored data. To make the filters
# interpretable, we must reverse this normalization step
coef = scaler.inverse_transform([coef])[0]

Expand Down
18 changes: 14 additions & 4 deletions mne/decoding/tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

from sklearn.decomposition import PCA
from sklearn.kernel_ridge import KernelRidge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks

from mne import Epochs, create_info, io, pick_types, read_events
from mne import Epochs, EpochsArray, create_info, io, pick_types, read_events
from mne.decoding import (
FilterEstimator,
LinearModel,
PSDEstimator,
Scaler,
TemporalFilter,
Expand Down Expand Up @@ -218,9 +221,16 @@ def test_vectorizer():
assert_equal(vect.fit_transform(data[1:]).shape, (149, 108))

# check if raised errors are working correctly
vect.fit(np.random.rand(105, 12, 3))
pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1))
pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12))
X = np.random.default_rng(0).standard_normal((105, 12, 3))
y = np.arange(X.shape[0]) % 2
pytest.raises(ValueError, vect.transform, X[..., np.newaxis])
pytest.raises(ValueError, vect.inverse_transform, X[:, :-1])

# And that pipelines work properly
X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg"))
vect.fit(X_arr)
clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel())
clf.fit(X_arr, y)


def test_unsupervised_spatial_filter():
Expand Down
7 changes: 7 additions & 0 deletions mne/decoding/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
pick_info,
)
from ..cov import _check_scalings_user
from ..epochs import BaseEpochs
from ..filter import filter_data
from ..time_frequency import psd_array_multitaper
from ..utils import _check_option, _validate_type, fill_doc
Expand All @@ -34,6 +35,12 @@ def _check_data(
multi_output=False,
check_n_features=True,
):
# Sklearn calls asarray under the hood which works, but elsewhere they check for
# __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...)
# rather than what they expect (shape (...)). So we explicitly get the NumPy
# array to make everyone happy.
if isinstance(epochs_data, BaseEpochs):
epochs_data = epochs_data.get_data(copy=False)
kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True)
if hasattr(self, "n_features_in_") and check_n_features:
if y is None:
Expand Down

0 comments on commit 1918de3

Please sign in to comment.