Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optuna GridSearch #630

Merged
merged 31 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7119ec5
Optuna
carraraig Jun 27, 2024
c09cb7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
d4cb28e
Optuna - Categorical Distibution
carraraig Jun 27, 2024
1767d6a
Merge remote-tracking branch 'origin/Optuna' into Optuna
carraraig Jun 27, 2024
80940d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
bbdf225
Add optuna to dependency and regenerate poetry
carraraig Jun 27, 2024
0ce8943
Merge remote-tracking branch 'origin/Optuna' into Optuna
carraraig Jun 27, 2024
a99d6b8
Add optuna to dependency and regenerate poetry
carraraig Jun 27, 2024
d61e43a
Add optuna to dependency and regenerate poetry
carraraig Jun 27, 2024
0ad7c12
Add test on within Session
carraraig Jun 27, 2024
25e73c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
d38803d
Add test on within Session
carraraig Jun 28, 2024
bfa3c38
Merge remote-tracking branch 'origin/Optuna' into Optuna
carraraig Jun 28, 2024
c04205a
Add test on within Session
carraraig Jun 28, 2024
d1ba14b
ehn: common function with dict
bruAristimunha Jun 28, 2024
6efdaa4
ehn: moving function to util
bruAristimunha Jun 28, 2024
ae6be13
fix: correcting the what news file.
bruAristimunha Jul 1, 2024
c3a7daf
Merge branch 'develop' into Optuna
bruAristimunha Jul 15, 2024
6f572f5
Add test benchmark and raise an issue if the conversion didn't worked…
carraraig Jul 15, 2024
c107a2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
b160cf6
FIX: italian to eng
bruAristimunha Jul 15, 2024
dc09e6c
EHN: making optuna optional
bruAristimunha Jul 15, 2024
de72e46
FIX: fixing the workflow files
bruAristimunha Jul 15, 2024
c2810dc
FIX: changing the optuna file
bruAristimunha Jul 15, 2024
19aac40
Merge branch 'develop' into Optuna
bruAristimunha Jul 15, 2024
2224ceb
FIX: including optuna for the windows
bruAristimunha Jul 15, 2024
73d18db
Merge remote-tracking branch 'carraraig/Optuna' into Optuna
bruAristimunha Jul 15, 2024
f1a2262
FIX: fix the doc generation
bruAristimunha Jul 15, 2024
1399706
FIX: make sure to not include a warning in all the executions
bruAristimunha Jul 15, 2024
392cb3e
FIX: fixing the workflow file
bruAristimunha Jul 15, 2024
b3ec191
FIX: fixing the doc
bruAristimunha Jul 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Develop branch

Enhancements
~~~~~~~~~~~~
- None
- Add possibility to use RandomizedGridSearch (:gh:`630` by `Igor Carrara`_)

Bugs
~~~~
Expand Down
21 changes: 19 additions & 2 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
from abc import ABC, abstractmethod

import pandas as pd
from optuna.integration import OptunaSearchCV
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

from moabb.analysis import Results
from moabb.datasets.base import BaseDataset
from moabb.evaluations.utils import _convert_sklearn_params_to_optuna
from moabb.paradigms.base import BaseParadigm


# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Hyper-parameter search classes selectable by key:
# "grid" -> GridSearchCV, "optuna" -> OptunaSearchCV.
search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}


class BaseEvaluation(ABC):
"""Base class that defines necessary operations for an evaluation.
Expand Down Expand Up @@ -53,6 +57,9 @@ class BaseEvaluation(ABC):
Save model after training, for each fold of cross-validation if needed
cache_config: bool, default=None
Configuration for caching of datasets. See :class:`moabb.datasets.base.CacheConfig` for details.
optuna: bool, default=False
If optuna is enabled, the GridSearch is replaced by a RandomizedGridSearch with a 15-minute cut-off time.
This option is compatible with lists of entries of type None, bool, int, float and string.

Notes
-----
Expand All @@ -77,6 +84,7 @@ def __init__(
n_splits=None,
save_model=False,
cache_config=None,
optuna=False,
):
self.random_state = random_state
self.n_jobs = n_jobs
Expand All @@ -88,6 +96,7 @@ def __init__(
self.n_splits = n_splits
self.save_model = save_model
self.cache_config = cache_config
self.optuna = optuna
# check paradigm
if not isinstance(paradigm, BaseParadigm):
raise (ValueError("paradigm must be an Paradigm instance"))
Expand Down Expand Up @@ -261,19 +270,27 @@ def is_valid(self, dataset):
"""

def _grid_search(self, param_grid, name, grid_clf, inner_cv):
extra_params = {}
if param_grid is not None:
if name in param_grid:
search = GridSearchCV(
if self.optuna:
search = search_methods["optuna"]
param_grid[name] = _convert_sklearn_params_to_optuna(param_grid[name])
extra_params["timeout"] = 60 * 15 # 15 minutes
bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
else:
search = search_methods["grid"]

search = search(
grid_clf,
param_grid[name],
refit=True,
cv=inner_cv,
n_jobs=self.n_jobs,
scoring=self.paradigm.scoring,
return_train_score=True,
**extra_params,
)
return search

else:
return grid_clf

Expand Down
22 changes: 22 additions & 0 deletions moabb/evaluations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pickle import HIGHEST_PROTOCOL, dump
from typing import Sequence

import optuna
from numpy import argmax
from sklearn.pipeline import Pipeline

Expand Down Expand Up @@ -212,3 +213,24 @@ def create_save_path(
return str(path_save)
else:
print("No hdf5_path provided, models will not be saved.")


def _convert_sklearn_params_to_optuna(param_grid):
"""
Function to convert the parameter in Optuna format. This function will create a integer distribution of values
between the max and minimum value of the parameter.
Parameters
----------
param_grid

Returns
-------

"""
optuna_params = {}
for key, value in param_grid.items():
if isinstance(value, list):
optuna_params[key] = optuna.distributions.CategoricalDistribution(value)
else:
optuna_params[key] = value
return optuna_params
bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
25 changes: 25 additions & 0 deletions moabb/tests/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def setUp(self):
datasets=[dataset],
hdf5_path="res_test",
save_model=True,
optuna=False,
)

def test_mne_labels(self):
Expand Down Expand Up @@ -138,6 +139,30 @@ def test_eval_grid_search(self):
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_eval_grid_search_optuna(self):
    """Run the evaluation with a parameter grid while the optuna flag is on."""
    # Hyper-parameter candidates for the pipeline registered under "C".
    param_grid = {"C": {"csp__metric": ["euclid", "riemann"]}}
    first_process_pipeline = self.eval.paradigm.make_process_pipelines(dataset)[0]

    # Switch the evaluation to the optuna-based search for this run only.
    self.eval.optuna = True

    results = list(
        self.eval.evaluate(
            dataset,
            pipelines,
            param_grid=param_grid,
            process_pipeline=first_process_pipeline,
        )
    )

    # Restore the default search method.
    self.eval.optuna = False

    # Expect 4 results: 2 sessions x 2 subjects.
    self.assertEqual(len(results), 4)
    # Each result holds 9 keys when `_carbonfootprint` is truthy, otherwise 8.
    expected_columns = 9 if _carbonfootprint else 8
    self.assertEqual(len(results[0].keys()), expected_columns)

def test_within_session_evaluation_save_model(self):
res_test_path = "./res_test"

Expand Down
Loading