Skip to content

Commit

Permalink
add: exapmle of terminator in lgbm
Browse files Browse the repository at this point in the history
  • Loading branch information
Keita-S593 committed Sep 30, 2023
1 parent e61deda commit 96dbd01
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
55 changes: 55 additions & 0 deletions lightgbm/lightgbm_terminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Optuna example showcasing the new Optuna Terminator feature.
In this example, we utilize the Optuna Terminator for hyperparameter
optimization on a lightgbm using the wine dataset.
The Terminator automatically stops the optimization process based
on the potential for further improvement.
To run this example:
$ python lightgbm_terminator.py
"""

import optuna
from optuna.terminator.callback import TerminatorCallback
from optuna.terminator.erroreval import report_cross_validation_scores

from sklearn.datasets import load_wine
import lightgbm as lgb
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold


def objective(trial):
X, y = load_wine(return_X_y=True)

params = {
"objective": "multiclass",
"num_class": 3,
"verbosity": -1,
"boosting_type": "gbdt",
"lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
"lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
"num_leaves": trial.suggest_int("num_leaves", 2, 256),
"feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
}

clf = lgb.LGBMClassifier(**params)

scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
report_cross_validation_scores(trial, scores)

return scores.mean()


if __name__ == "__main__":
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50, callbacks=[TerminatorCallback()])

print(f"The number of trials: {len(study.trials)}")
print(f"Best value: {study.best_value} (params: {study.best_params})")
1 change: 1 addition & 0 deletions lightgbm/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ lightgbm>=3.3.0
numpy
optuna
scikit-learn
botorch

0 comments on commit 96dbd01

Please sign in to comment.