From a9e09e3f8cd36703d29a1900c58a5b62fba821b4 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Fri, 30 Aug 2024 14:55:41 +0200 Subject: [PATCH 1/6] Remove final model fit requirement for inspector --- julearn/api.py | 11 +++++++---- julearn/inspect/tests/test_inspector.py | 23 ++++++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/julearn/api.py b/julearn/api.py index 130b57e0a..17fe483a1 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -194,11 +194,11 @@ def run_cross_validation( # noqa: C901 ) if return_inspector: if return_estimator is None: - logger.info("Inspector requested: setting return_estimator='all'") return_estimator = "all" - if return_estimator != "all": + if return_estimator not in ["all", "cv"]: raise_error( - "return_inspector=True requires return_estimator to be `all`." + "return_inspector=True requires return_estimator to be `all` " + "or `cv`" ) X_types = {} if X_types is None else X_types @@ -441,6 +441,9 @@ def run_cross_validation( # noqa: C901 groups=df_groups, cv=cv_outer, ) - out = scores_df, pipeline, inspector + if isinstance(out, tuple): + out = (*out, inspector) + else: + out = out, inspector return out diff --git a/julearn/inspect/tests/test_inspector.py b/julearn/inspect/tests/test_inspector.py index 8643cee1d..6f069324b 100644 --- a/julearn/inspect/tests/test_inspector.py +++ b/julearn/inspect/tests/test_inspector.py @@ -54,7 +54,9 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: """ X = list(df_iris.iloc[:, :-1].columns) - scores, pipe, inspect = run_cross_validation( + + # All estimators + out = run_cross_validation( X=X, y="species", data=df_iris, @@ -63,6 +65,7 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: return_inspector=True, problem_type="classification", ) + scores, pipe, inspect = out assert pipe == inspect.model._model # type: ignore for (_, score), inspect_fold in zip( scores.iterrows(), # type: ignore @@ -70,6 +73,24 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: ): assert score["estimator"] == inspect_fold.model._model + del pipe + # only CV estimators + out = run_cross_validation( + X=X, + y="species", + data=df_iris, + model="svm", + return_estimator="cv", + return_inspector=True, + problem_type="classification", + ) + scores, inspect = out + for (_, score), inspect_fold in zip( + scores.iterrows(), # type: ignore + inspect.folds, # type: ignore + ): + assert score["estimator"] == inspect_fold.model._model + def test_normal_usage_with_search(df_iris: "pd.DataFrame") -> None: """Test inspector with search. From 89474bd83b19ce3ea2688132bd9668dabc43604e Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Sat, 31 Aug 2024 11:02:07 +0200 Subject: [PATCH 2/6] Add doc --- docs/changes/newsfragments/270.enh | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/changes/newsfragments/270.enh diff --git a/docs/changes/newsfragments/270.enh b/docs/changes/newsfragments/270.enh new file mode 100644 index 000000000..7cdf2a09f --- /dev/null +++ b/docs/changes/newsfragments/270.enh @@ -0,0 +1 @@ +Remove final model fit requirement for inspector to be returned by `run_cross_validation` by `Fede Raimondo`_. \ No newline at end of file From 28d2fd02a1af7928ce0f6d8d9120678aae9599c0 Mon Sep 17 00:00:00 2001 From: Fede Date: Tue, 3 Sep 2024 15:41:23 +0200 Subject: [PATCH 3/6] Fix change log --- docs/changes/newsfragments/270.enh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changes/newsfragments/270.enh b/docs/changes/newsfragments/270.enh index 7cdf2a09f..974664b2b 100644 --- a/docs/changes/newsfragments/270.enh +++ b/docs/changes/newsfragments/270.enh @@ -1 +1 @@ -Remove final model fit requirement for inspector to be returned by `run_cross_validation` by `Fede Raimondo`_. \ No newline at end of file +Remove final model fit requirement for inspector to be returned by :func:`run_cross_validation` by `Fede Raimondo`_. \ No newline at end of file From 2d878d2c3cec8cb946772a213e823cfca80cfc08 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Wed, 4 Sep 2024 10:43:32 +0200 Subject: [PATCH 4/6] fix documentation --- docs/changes/newsfragments/270.enh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changes/newsfragments/270.enh b/docs/changes/newsfragments/270.enh index 974664b2b..63cd68890 100644 --- a/docs/changes/newsfragments/270.enh +++ b/docs/changes/newsfragments/270.enh @@ -1 +1 @@ -Remove final model fit requirement for inspector to be returned by :func:`run_cross_validation` by `Fede Raimondo`_. \ No newline at end of file +Remove final model fit requirement for inspector to be returned by :func:`.run_cross_validation` by `Fede Raimondo`_. \ No newline at end of file From b2c5e60a1d8c7c776173c00226abe5a0c360df0a Mon Sep 17 00:00:00 2001 From: Fede Date: Wed, 4 Sep 2024 12:19:36 +0200 Subject: [PATCH 5/6] Fix documentation on OptunaSearchCV --- docs/getting_started.rst | 2 +- examples/99_docs/run_hyperparameters_docs.py | 4 ++-- julearn/api.py | 2 +- julearn/pipeline/pipeline_creator.py | 2 +- pyproject.toml | 4 ++-- tox.ini | 6 +++--- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index aa7aa851c..47bd86601 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -96,5 +96,5 @@ The following optional dependencies are available: module is not compatible with newer Python versions and it is unmaintained. * ``skopt``: Using the ``"bayes"`` searcher (:class:`~skopt.BayesSearchCV`) requires the `scikit-optimize`_ package. -* ``optuna``: Using the ``"optuna"`` searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`) requires the `Optuna`_ and `optuna_integration`_ packages. +* ``optuna``: Using the ``"optuna"`` searcher (:class:`~optuna_integration.OptunaSearchCV`) requires the `Optuna`_ and `optuna_integration`_ packages. * ``all``: Install all optional functional dependencies (except ``deslib``). diff --git a/examples/99_docs/run_hyperparameters_docs.py b/examples/99_docs/run_hyperparameters_docs.py index ce476262a..04bdf6a13 100644 --- a/examples/99_docs/run_hyperparameters_docs.py +++ b/examples/99_docs/run_hyperparameters_docs.py @@ -255,7 +255,7 @@ # Other searchers that ``julearn`` provides are the # :class:`~sklearn.model_selection.RandomizedSearchCV`, # :class:`~skopt.BayesSearchCV` and -# :class:`~optuna_integration.sklearn.OptunaSearchCV`. +# :class:`~optuna_integration.OptunaSearchCV`. # # The randomized searcher # (:class:`~sklearn.model_selection.RandomizedSearchCV`) is similar to the @@ -275,7 +275,7 @@ # :class:`~skopt.BayesSearchCV` documentation, including how to specify # the prior distributions of the hyperparameters. # -# The Optuna searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`) +# The Optuna searcher (:class:`~optuna_integration.OptunaSearchCV`) # uses the Optuna library to find the best hyperparameter set. Optuna is a # hyperparameter optimization framework that has several algorithms to find # the best hyperparameter set. For more information, see the diff --git a/julearn/api.py b/julearn/api.py index 17fe483a1..c86086d36 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -142,7 +142,7 @@ def run_cross_validation( # noqa: C901 :class:`~sklearn.model_selection.RandomizedSearchCV` * ``"bayes"`` : :class:`~skopt.BayesSearchCV` * ``"optuna"`` : - :class:`~optuna_integration.sklearn.OptunaSearchCV` + :class:`~optuna_integration.OptunaSearchCV` * user-registered searcher name : see :func:`~julearn.model_selection.register_searcher` * ``scikit-learn``-compatible searcher diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index 60d0be052..9cc9e2fd4 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -944,7 +944,7 @@ def _prepare_hyperparameter_tuning( :class:`~sklearn.model_selection.RandomizedSearchCV` * ``"bayes"`` : :class:`~skopt.BayesSearchCV` * ``"optuna"`` : - :class:`~optuna_integration.sklearn.OptunaSearchCV` + :class:`~optuna_integration.OptunaSearchCV` * user-registered searcher name : see :func:`~julearn.model_selection.register_searcher` * ``scikit-learn``-compatible searcher diff --git a/pyproject.toml b/pyproject.toml index 9615834bf..bcb3b73bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ docs = [ "towncrier<24", "scikit-optimize>=0.10.0,<0.11", "optuna>=3.6.0,<3.7", - "optuna_integration>=3.6.0,<3.7", + "optuna_integration>=3.6.0,<4.1", ] deslib = ["deslib>=0.3.5,<0.4"] viz = [ @@ -72,7 +72,7 @@ viz = [ skopt = ["scikit-optimize>=0.10.0,<0.11"] optuna = [ "optuna>=3.6.0,<3.7", - "optuna_integration>=3.6.0,<3.7", + "optuna_integration>=3.6.0,<4.1", ] # Add all optional functional dependencies (skip deslib until its fixed) # This does not include dev/docs building dependencies diff --git a/tox.ini b/tox.ini index 09bb9c0aa..89030b7b4 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,7 @@ deps = seaborn scikit-optimize>=0.10.0,<0.11 optuna>=3.6.0,<3.7 - optuna_integration>=3.6.0,<3.7 + optuna_integration>=3.6.0,<4.1 commands = pytest {toxinidir}/julearn @@ -45,7 +45,7 @@ deps = param scikit-optimize>=0.10.0,<0.11 optuna>=3.6.0,<3.7 - optuna_integration>=3.6.0,<3.7 + optuna_integration>=3.6.0,<4.1 commands = pytest -vv {toxinidir}/julearn @@ -69,7 +69,7 @@ deps = param scikit-optimize>=0.10.0,<0.11 optuna>=3.6.0,<3.7 - optuna_integration>=3.6.0,<3.7 + optuna_integration>=3.6.0,<4.1 commands = pytest --cov={envsitepackagesdir}/julearn --cov=./julearn --cov-report=xml --cov-report=term -vv From 5e1e40793b25a95e39ba23f24a1acd0806b19b88 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Wed, 4 Sep 2024 13:01:44 +0200 Subject: [PATCH 6/6] Update old what's new txt --- docs/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/whats_new.rst b/docs/whats_new.rst index d1137d51d..dd36ea0c1 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -56,7 +56,7 @@ Enhancements Features ^^^^^^^^ -- Add :class:`~optuna_integration.sklearn.OptunaSearchCV` to the list of +- Add :class:`~optuna_integration.OptunaSearchCV` to the list of available searchers as ``optuna`` by `Fede Raimondo`_ (:gh:`262`)