diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index bc52bda1e..046bd427c 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -260,10 +260,25 @@ def _evaluate_subject( if isinstance(X, BaseEpochs): scorer = get_scorer(self.paradigm.scoring) acc = list() - X_ = X[ix] - y_ = y[ix] if self.mne_labels else y_cv for cv_ind, (train, test) in enumerate(cv.split(X_, y_)): cvclf = clone(grid_clf) + + valid_split = 0.2 + for s_cvclf in cvclf.steps: + if "validation_split" in s_cvclf[1]._get_param_names(): + valid_split = s_cvclf[1].validation_split + valid_split = 1.0 - np.floor( + len(train) * (1.0 - valid_split) + ) / len(train) + + cv_tv = StratifiedKFold( + int(np.round(1.0 / valid_split)), + shuffle=True, + random_state=self.random_state, + ) + (train_t, train_v) = next(cv_tv.split(X_[train], y_[train])) + train = np.concatenate((train[train_t], train[train_v]), axis=0) + cvclf.fit(X_[train], y_[train]) acc.append(scorer(cvclf, X_[test], y_[test]))