NOT MERGE including optuna #606

Closed
wants to merge 13 commits
86 changes: 84 additions & 2 deletions moabb/evaluations/evaluations.py
@@ -4,6 +4,7 @@
from typing import Optional, Union

import numpy as np
import optuna
from mne.epochs import BaseEpochs
from sklearn.base import clone
from sklearn.metrics import get_scorer
@@ -13,13 +14,20 @@
StratifiedKFold,
StratifiedShuffleSplit,
cross_validate,
train_test_split,
)
from sklearn.model_selection._validation import _fit_and_score, _score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from moabb.evaluations.base import BaseEvaluation
from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list
from moabb.evaluations.utils import (
create_deep_model,
create_save_path,
save_model_cv,
save_model_list,
)


try:
@@ -35,6 +43,36 @@
Vector = Union[list, tuple, np.ndarray]


def objective(trial, X, y, clf, scorer, epochs, random_state):
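    """Optuna objective: sample hyperparameters, retrain the Keras model on an
    inner train/validation split, and return the validation score to maximize."""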
learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
weight_decay = trial.suggest_float("weight_decay", 1e-10, 1e-3, log=True)
drop_rate = trial.suggest_float("drop_rate", 0.3, 0.9)

pre_process_steps, model = create_deep_model(
clf, learning_rate, weight_decay, drop_rate, epochs=epochs, manual_validation=True
)
# These are sample indices (not training epochs); splitting indices lets us
# slice the raw epochs array X below.
sample_indices = list(range(len(y)))
try:
    idx_X_train, idx_X_val, y_train, y_val = train_test_split(
        sample_indices, y, test_size=0.2, stratify=y, random_state=random_state
    )
except ValueError:
    # Stratification fails when a class has too few samples;
    # fall back to an unstratified split.
    idx_X_train, idx_X_val, y_train, y_val = train_test_split(
        sample_indices, y, test_size=0.2, random_state=random_state
    )

X_train = X[idx_X_train]
X_val = X[idx_X_val]
# Fit the pre-processing on the training split and transform the validation
# split by hand: Keras receives validation_data as-is, bypassing the pipeline.
pre_process_steps.fit(X=X_train, y=y_train)
X_val_pre_processed = pre_process_steps.transform(X_val)

model = Pipeline([("preprocessing", pre_process_steps), ("deep", model)])
# Train on the training split only; fitting on all of X would leak the
# validation samples that are scored below.
model.fit(X_train, y_train, deep__validation_data=(X_val_pre_processed, y_val))

return scorer(model, X_val, y_val)


class WithinSessionEvaluation(BaseEvaluation):
"""Performance evaluation within session (k-fold cross-validation)

@@ -94,10 +132,14 @@ def __init__(
self,
n_perms: Optional[Union[int, Vector]] = None,
data_size: Optional[dict] = None,
optuna_n_trials: int = 25,
optuna_timeout: int = 60 * 10,
**kwargs,
):
self.data_size = data_size
self.n_perms = n_perms
self.optuna_n_trials = optuna_n_trials
self.optuna_timeout = optuna_timeout
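# Optuna budget: optuna_n_trials caps the number of sampled configurations
# per fold; optuna_timeout is the per-study time limit in seconds
# (the default 60 * 10 is ten minutes).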
self.calculate_learning_curve = self.data_size is not None
if self.calculate_learning_curve:
# Check correct n_perms parameter
@@ -220,8 +262,48 @@ def _evaluate(
y_ = y[ix] if self.mne_labels else y_cv
for cv_ind, (train, test) in enumerate(cv.split(X_, y_)):
cvclf = clone(grid_clf)
cvclf.fit(X_[train], y_[train])
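# Read the training-epoch budget from the wrapped Keras estimator.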
n_epochs = cvclf[-1].epochs

study = optuna.create_study(
direction="maximize",
study_name=f"{name}_{subject}_{session}_{cv_ind}",
)
study.optimize(
lambda trial: objective(
trial,
X=X_[train],
y=y_[train],
clf=cvclf,
scorer=scorer,
epochs=n_epochs,
random_state=self.random_state,
),
n_trials=self.optuna_n_trials,
timeout=self.optuna_timeout,  # seconds (default: 10 minutes)
show_progress_bar=True,
n_jobs=1,
gc_after_trial=True,
)
best_params = study.best_params
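# Rebuild the model with the best hyperparameters found and retrain it
# on the full training split.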

pre_process_steps, model = create_deep_model(
clf=cvclf,
**best_params,
epochs=n_epochs,
)
cvclf = Pipeline(
[("preprocessing", pre_process_steps), ("deep", model)]
)

cvclf = cvclf.fit(X_[train], y_[train])
acc.append(scorer(cvclf, X_[test], y_[test]))
if hasattr(cvclf, "_final_estimator"):
save_history_name = (
f"{name}_{cv_ind}_fold_{session}_session"
)
history = cvclf._final_estimator.history_

np.savez(f"{save_history_name}.npz", history)
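# The curves can be reloaded later, e.g.
# np.load(f"{save_history_name}.npz", allow_pickle=True)["arr_0"].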

if self.hdf5_path is not None and self.save_model:
save_model_cv(
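A minimal sketch of how the new Optuna knobs would be exercised (the paradigm, dataset, and pipeline objects below are placeholders, not part of this diff; the pipelines must end in one of the moabb Keras models for the Optuna path to apply):

    from moabb.evaluations import WithinSessionEvaluation
    from moabb.paradigms import LeftRightImagery

    paradigm = LeftRightImagery()
    evaluation = WithinSessionEvaluation(
        paradigm=paradigm,
        datasets=paradigm.datasets[:1],  # any dataset compatible with the paradigm
        optuna_n_trials=25,  # hyperparameter configurations to sample per fold
        optuna_timeout=600,  # per-study budget in seconds
    )
    results = evaluation.process(pipelines)  # pipelines: dict name -> sklearn Pipeline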
53 changes: 53 additions & 0 deletions moabb/evaluations/utils.py
@@ -4,9 +4,29 @@
from pickle import HIGHEST_PROTOCOL, dump
from typing import Sequence

import tensorflow as tf
from numpy import argmax
from sklearn.pipeline import Pipeline

from moabb.pipelines.deep_learning import (
KerasDeepConvNet,
KerasEEGITNet,
KerasEEGNet_8_2,
KerasEEGNeX,
KerasEEGTCNet,
KerasShallowConvNet,
)


models_class = {
"KerasShallowConvNet": KerasShallowConvNet,
"KerasDeepConvNet": KerasDeepConvNet,
"KerasEEGNet_8_2": KerasEEGNet_8_2,
"KerasEEGITNet": KerasEEGITNet,
"KerasEEGTCNet": KerasEEGTCNet,
"KerasEEGNeX": KerasEEGNeX,
}
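# The registry above lets create_deep_model rebuild a wrapped Keras model
# from its class name with freshly sampled hyperparameters.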


def _check_if_is_keras_model(model):
"""Check if the model is a Keras model.
@@ -212,3 +232,36 @@ def create_save_path(
return str(path_save)
else:
print("No hdf5_path provided, models will not be saved.")


def create_deep_model(
clf, learning_rate, weight_decay, drop_rate, epochs, manual_validation=False
):
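    """Re-instantiate the final Keras step of ``clf`` with new hyperparameters.

    Returns the pre-processing pipeline (every step except the last) and a
    fresh, unfitted Keras classifier configured with the sampled learning
    rate, weight decay and dropout rate.
    """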
keras_clf = clf[-1]
steps = list(clf.steps)

if manual_validation:
    # Validation data is supplied explicitly, so disable the internal split.
    validation_split = 0.0
else:
    validation_split = keras_clf.validation_split

optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate, weight_decay=weight_decay
)
new_keras_clf = models_class[keras_clf.__class__.__name__](
    loss="sparse_categorical_crossentropy",
    optimizer=optimizer,
    drop_rate=drop_rate,
    epochs=epochs,
    verbose=keras_clf.verbose,
    # The remaining parameters are copied from the original model.
    callbacks=keras_clf.callbacks,
    random_state=keras_clf.random_state,
    batch_size=keras_clf.batch_size,
    validation_split=validation_split,
    shuffle=keras_clf.shuffle,
)
# Everything before the final estimator becomes the pre-processing pipeline;
# the new Keras step is returned separately.
pre_process_steps = Pipeline(steps[:-1])
return pre_process_steps, new_keras_clf
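A hedged sketch of calling create_deep_model directly (the input pipeline is a placeholder; any sklearn Pipeline whose final step is one of the wrapped models registered above works):

    # clf: an sklearn Pipeline whose final step is e.g. KerasEEGNet_8_2
    pre, net = create_deep_model(
        clf=clf,
        learning_rate=1e-3,
        weight_decay=1e-5,
        drop_rate=0.5,
        epochs=200,
        manual_validation=True,  # caller supplies validation_data to fit
    )
    tuned = Pipeline([("preprocessing", pre), ("deep", net)])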