Commit

Merge branch 'Adding-ERPCore' of https://github.com/tahatt13/moabb into Adding-ERPCore
tahatt13 committed Aug 29, 2024
2 parents dd79a26 + b44e27b commit 07492af
Showing 16 changed files with 1,300 additions and 568 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
@@ -66,10 +66,10 @@ jobs:

- name: Install dependencies
if: (steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --with docs --extras deeplearning
run: poetry install --no-interaction --no-root --with docs --extras deeplearning --extras optuna

- name: Install library
run: poetry install --no-interaction --with docs --extras deeplearning
run: poetry install --no-interaction --with docs --extras deeplearning --extras optuna

- name: Build docs
run: |
6 changes: 3 additions & 3 deletions .github/workflows/test-devel.yml
@@ -62,15 +62,15 @@ jobs:
if: |
(runner.os != 'Windows') &&
(steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --extras deeplearning
run: poetry install --no-interaction --no-root --extras deeplearning --extras optuna

- name: Install library (Linux/OSX)
if: ${{ runner.os != 'Windows' }}
run: poetry install --no-interaction --extras deeplearning
run: poetry install --no-interaction --extras deeplearning --extras optuna

- name: Install library (Windows)
if: ${{ runner.os == 'Windows' }}
run: poetry install --no-interaction
run: poetry install --no-interaction --extras optuna

- name: Run tests
run: |
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
@@ -62,15 +62,15 @@ jobs:
if: |
(runner.os != 'Windows') &&
(steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --extras deeplearning
run: poetry install --no-interaction --no-root --extras deeplearning --extras optuna

- name: Install library (Linux/OSX)
if: ${{ runner.os != 'Windows' }}
run: poetry install --no-interaction --extras deeplearning
run: poetry install --no-interaction --extras deeplearning --extras optuna

- name: Install library (Windows)
if: ${{ runner.os == 'Windows' }}
run: poetry install --no-interaction
run: poetry install --no-interaction --extras optuna

- name: Run tests
run: |
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
@@ -15,7 +15,7 @@ exclude: ".*svg"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-yaml
- id: check-json
@@ -35,14 +35,14 @@ repos:


- repo: https://github.com/psf/black
rev: 24.3.0
rev: 24.4.2
hooks:
- id: black
language_version: python3
args: [ --line-length=90, --target-version=py38 ]

- repo: https://github.com/asottile/blacken-docs
rev: 1.16.0
rev: 1.18.0
hooks:
- id: blacken-docs
additional_dependencies: [black==23.3.0]
@@ -54,7 +54,7 @@ repos:
- id: isort

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
rev: 7.1.0
hooks:
- id: flake8
additional_dependencies: [
@@ -69,17 +69,17 @@
exclude: ^docs/ | ^setup\.py$ |

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.5
rev: v0.5.0
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix, --ignore, E501 ]

- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
args:
- --ignore-words-list=additionals,alle,alot,bund,currenty,datas,farenheit,falsy,fo,haa,hass,iif,incomfort,ines,ist,nam,nd,pres,pullrequests,resset,rime,ser,serie,te,technik,ue,unsecure,withing,zar,crate
- --ignore-words-list=assertIn,additionals,alle,alot,bund,currenty,datas,farenheit,falsy,fo,haa,hass,iif,incomfort,ines,ist,nam,nd,pres,pullrequests,resset,rime,ser,serie,te,technik,ue,unsecure,withing,zar,crate
- --skip="./.*,*.csv,*.json,*.ambr"
- --quiet-level=2
exclude_types: [ csv, json, svg, pdf ]
4 changes: 2 additions & 2 deletions CITATION.cff
@@ -24,8 +24,8 @@ authors:
- family-names: "Bjareholt"
given-names: "Erik"
orcid: "https://orcid.org/0000-0003-1350-9677"
- family-names: "Quentin"
given-names: "Barthelemy"
- family-names: "Barthelemy"
given-names: "Quentin"
orcid: "https://orcid.org/0000-0002-7059-6028"
- family-names: "Schirrmeister"
given-names: "Robin Tibor"
4 changes: 2 additions & 2 deletions README.md
@@ -291,13 +291,13 @@ If you use MOABB in your experiments, please cite this library when
publishing a paper to increase the visibility of open science initiatives:

```
Aristimunha, B., Carrara, I., Guetschel, P., Sedlar, S., Rodrigues, P., Sosulski, J., Narayanan, D., Bjareholt, E., Quentin, B., Schirrmeister, R. T.,Kalunga, E., Darmet, L., Gregoire, C., Abdul Hussain, A., Gatti, R., Goncharenko, V., Thielen, J., Moreau, T., Roy, Y., Jayaram, V., Barachant,A., & Chevallier, S.
Aristimunha, B., Carrara, I., Guetschel, P., Sedlar, S., Rodrigues, P., Sosulski, J., Narayanan, D., Bjareholt, E., Barthelemy, Q., Schirrmeister, R. T.,Kalunga, E., Darmet, L., Gregoire, C., Abdul Hussain, A., Gatti, R., Goncharenko, V., Thielen, J., Moreau, T., Roy, Y., Jayaram, V., Barachant,A., & Chevallier, S.
Mother of all BCI Benchmarks (MOABB), 2023. DOI: 10.5281/zenodo.10034223.
```
and here is the Bibtex version:
```bibtex
@software{Aristimunha_Mother_of_all_2023,
author = {Aristimunha, Bruno and Carrara, Igor and Guetschel, Pierre and Sedlar, Sara and Rodrigues, Pedro and Sosulski, Jan and Narayanan, Divyesh and Bjareholt, Erik and Quentin, Barthelemy and Schirrmeister, Robin Tibor and Kalunga, Emmanuel and Darmet, Ludovic and Gregoire, Cattan and Abdul Hussain, Ali and Gatti, Ramiro and Goncharenko, Vladislav and Thielen, Jordy and Moreau, Thomas and Roy, Yannick and Jayaram, Vinay and Barachant, Alexandre and Chevallier, Sylvain},
author = {Aristimunha, Bruno and Carrara, Igor and Guetschel, Pierre and Sedlar, Sara and Rodrigues, Pedro and Sosulski, Jan and Narayanan, Divyesh and Bjareholt, Erik and Barthelemy, Quentin and Schirrmeister, Robin Tibor and Kalunga, Emmanuel and Darmet, Ludovic and Gregoire, Cattan and Abdul Hussain, Ali and Gatti, Ramiro and Goncharenko, Vladislav and Thielen, Jordy and Moreau, Thomas and Roy, Yannick and Jayaram, Vinay and Barachant, Alexandre and Chevallier, Sylvain},
doi = {10.5281/zenodo.10034223},
title = {{Mother of all BCI Benchmarks}},
url = {https://github.com/NeuroTechX/moabb},
9 changes: 8 additions & 1 deletion docs/source/whats_new.rst
@@ -17,17 +17,24 @@ Develop branch

Enhancements
~~~~~~~~~~~~
- Add possibility to use OptunaGridSearch (:gh:`630` by `Igor Carrara`_)
- Add scripts to upload results on PapersWithCode (:gh:`561` by `Pierre Guetschel`_)
- Centralize dataset summary tables in CSV files (:gh:`635` by `Pierre Guetschel`_)
- Add new dataset :class:`moabb.datasets.Liu2024` (:gh:`619` by `Taha Habib`_)
- Add ERP CORE datasets :class:`moabb.datasets.ErpCore2021` (:gh:`627` by `Taha Habib`_)
- Increase the hook versions in the pre-commit config (:gh:`631` by the pre-commit bot)



Bugs
~~~~
- Fix caching in the workflows (:gh:`632` by `Pierre Guetschel`_)

API changes
~~~~~~~~~~~
- None
- Include optuna as a soft dependency in the benchmark function and in the evaluation base class (:gh:`630` by `Igor Carrara`_)



Version - 1.1.0 (Stable - PyPi)
---------------------------------
8 changes: 8 additions & 0 deletions moabb/benchmark.py
@@ -45,6 +45,7 @@ def benchmark( # noqa: C901
exclude_datasets=None,
n_splits=None,
cache_config=None,
optuna=False,
):
"""Run benchmarks for selected pipelines and datasets.
@@ -102,6 +103,7 @@ def benchmark( # noqa: C901
and exclude_datasets are specified, raise an error.
exclude_datasets: list of str or Dataset object
Datasets to exclude from the benchmark run
optuna: bool, default=False
Enable Optuna for the hyperparameter search
Returns
-------
@@ -110,7 +112,11 @@ def benchmark( # noqa: C901
Notes
-----
.. versionadded:: 1.1.1
Includes the possibility to use Optuna for hyperparameter search.
.. versionadded:: 0.5.0
Create the function to run the benchmark
"""
# set logs
if evaluations is None:
@@ -182,6 +188,7 @@ def benchmark( # noqa: C901
return_epochs=True,
n_splits=n_splits,
cache_config=cache_config,
optuna=optuna,
)
paradigm_results = context.process(
pipelines=ppl_with_epochs, param_grid=param_grid
@@ -202,6 +209,7 @@ def benchmark( # noqa: C901
overwrite=overwrite,
n_splits=n_splits,
cache_config=cache_config,
optuna=optuna,
)
paradigm_results = context.process(
pipelines=ppl_with_array, param_grid=param_grid
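For context, a minimal sketch of how this new flag could be used from the public `benchmark()` entry point. The keyword arguments mirror those exercised in `moabb/tests/benchmark.py`; the `./pipelines` directory and the `LeftRightImagery` paradigm name are illustrative placeholders, not part of this commit.

```python
# Hedged sketch (not part of this commit): enabling the Optuna-backed search
# through the benchmark() function extended above.
from moabb import benchmark

results = benchmark(
    pipelines="./pipelines",          # placeholder: a folder of pipeline definition files
    evaluations=["WithinSession"],
    paradigms=["LeftRightImagery"],   # illustrative paradigm name
    overwrite=True,
    optuna=True,                      # requires the optional optuna extras
)
print(results.head())
```

With `optuna=False` (the default), the call behaves exactly as before and falls back to the plain grid search.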
47 changes: 45 additions & 2 deletions moabb/evaluations/base.py
@@ -1,17 +1,32 @@
import logging
from abc import ABC, abstractmethod
from warnings import warn

import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

from moabb.analysis import Results
from moabb.datasets.base import BaseDataset
from moabb.evaluations.utils import _convert_sklearn_params_to_optuna
from moabb.paradigms.base import BaseParadigm


log = logging.getLogger(__name__)

# Make optuna a soft dependency
try:
from optuna.integration import OptunaSearchCV

optuna_available = True
except ImportError:
optuna_available = False

if optuna_available:
search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}
else:
search_methods = {"grid": GridSearchCV}


class BaseEvaluation(ABC):
"""Base class that defines necessary operations for an evaluation.
@@ -53,11 +68,19 @@ class BaseEvaluation(ABC):
Save model after training, for each fold of cross-validation if needed
cache_config: bool, default=None
Configuration for caching of datasets. See :class:`moabb.datasets.base.CacheConfig` for details.
optuna: bool, default=False
If enabled, the GridSearch is replaced by an Optuna-based randomized search with a cut-off time (15 minutes by default).
This option supports parameter lists whose entries are of type None, bool, int, float or str.
time_out: int, default=60*15
Cut-off time for the Optuna search, expressed in seconds; the default value is 15 minutes.
Only used when optuna is set to True.
Notes
-----
.. versionadded:: 1.1.0
n_splits, save_model, cache_config parameters.
.. versionadded:: 1.1.1
optuna, time_out parameters.
"""

def __init__(
@@ -77,6 +100,8 @@ def __init__(
n_splits=None,
save_model=False,
cache_config=None,
optuna=False,
time_out=60 * 15,
):
self.random_state = random_state
self.n_jobs = n_jobs
@@ -88,6 +113,16 @@ def __init__(
self.n_splits = n_splits
self.save_model = save_model
self.cache_config = cache_config
self.optuna = optuna
self.time_out = time_out

if self.optuna and not optuna_available:
raise ImportError("Optuna is not available. Please install it first.")
if (self.time_out != 60 * 15) and not self.optuna:
warn(
"time_out parameter is only used when optuna is enabled. "
"Ignoring time_out parameter."
)
# check paradigm
if not isinstance(paradigm, BaseParadigm):
raise (ValueError("paradigm must be an Paradigm instance"))
@@ -261,19 +296,27 @@ def is_valid(self, dataset):
"""

def _grid_search(self, param_grid, name, grid_clf, inner_cv):
extra_params = {}
if param_grid is not None:
if name in param_grid:
search = GridSearchCV(
if self.optuna:
search = search_methods["optuna"]
param_grid[name] = _convert_sklearn_params_to_optuna(param_grid[name])
extra_params["timeout"] = self.time_out
else:
search = search_methods["grid"]

search = search(
grid_clf,
param_grid[name],
refit=True,
cv=inner_cv,
n_jobs=self.n_jobs,
scoring=self.paradigm.scoring,
return_train_score=True,
**extra_params,
)
return search

else:
return grid_clf

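As a rough illustration of the new constructor arguments, here is a sketch of how an evaluation could opt into the Optuna search. `WithinSessionEvaluation`, `LeftRightImagery`, and `BNCI2014_001` are existing MOABB classes, but the pipeline, dataset choice, and the shape of `param_grid` are assumptions made for the example, not taken from this commit.

```python
# Hedged sketch (not part of this commit): passing optuna/time_out to an evaluation.
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from moabb.datasets import BNCI2014_001           # illustrative dataset choice
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

pipelines = {"csp+lda": make_pipeline(CSP(n_components=8), LDA())}
# sklearn-style lists of candidates; with optuna=True these are converted
# to categorical distributions by _convert_sklearn_params_to_optuna
param_grid = {"csp+lda": {"csp__n_components": [4, 6, 8]}}

evaluation = WithinSessionEvaluation(
    paradigm=LeftRightImagery(),
    datasets=[BNCI2014_001()],
    optuna=True,       # swap the inner GridSearchCV for OptunaSearchCV
    time_out=5 * 60,   # stop each inner search after 5 minutes
)
results = evaluation.process(pipelines, param_grid=param_grid)
```

If optuna is requested but not installed, the constructor raises the ImportError shown above, and a non-default `time_out` without `optuna=True` only triggers a warning.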
41 changes: 41 additions & 0 deletions moabb/evaluations/utils.py
@@ -8,6 +8,14 @@
from sklearn.pipeline import Pipeline


try:
from optuna.distributions import CategoricalDistribution

optuna_available = True
except ImportError:
optuna_available = False


def _check_if_is_keras_model(model):
"""Check if the model is a Keras model.
@@ -212,3 +220,36 @@ def create_save_path(
return str(path_save)
else:
print("No hdf5_path provided, models will not be saved.")


def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict:
"""
Convert a scikit-learn parameter grid to the Optuna format. Each list of
candidate values is wrapped in a categorical distribution; scalar values
are passed through unchanged.
Parameters
----------
param_grid:
Dictionary with the parameters to be converted.
Returns
-------
optuna_params: dict
Dictionary with the parameters converted to Optuna format.
"""
if not optuna_available:
raise ImportError(
"Optuna is not available. Please install it optuna " "and optuna-integration."
)
else:
optuna_params = {}
for key, value in param_grid.items():
try:
if isinstance(value, list):
optuna_params[key] = CategoricalDistribution(value)
else:
optuna_params[key] = value
except Exception as e:
raise ValueError(f"Conversion failed for parameter {key}: {e}")
return optuna_params
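To make the conversion concrete, here is a small sketch of the expected behaviour of this helper; the parameter names (`csp__n_components`, `lda__solver`) are illustrative, while the function and the `CategoricalDistribution` class come from the diff above.

```python
# Hedged sketch (not part of this commit): list-valued entries become
# CategoricalDistribution objects, scalar entries pass through unchanged.
from optuna.distributions import CategoricalDistribution

from moabb.evaluations.utils import _convert_sklearn_params_to_optuna

sk_grid = {"csp__n_components": [4, 6, 8], "lda__solver": ["svd", "lsqr"]}
optuna_grid = _convert_sklearn_params_to_optuna(sk_grid)

assert isinstance(optuna_grid["csp__n_components"], CategoricalDistribution)
print(optuna_grid)
```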
10 changes: 10 additions & 0 deletions moabb/tests/benchmark.py
@@ -72,6 +72,16 @@ def test_include_exclude(self):
overwrite=True,
)

def test_optuna(self):
res = benchmark(
pipelines=str(self.pp_dir),
evaluations=["WithinSession"],
paradigms=["FakeImageryParadigm"],
overwrite=True,
optuna=True,
)
self.assertEqual(len(res), 40)


if __name__ == "__main__":
unittest.main()