-
Notifications
You must be signed in to change notification settings - Fork 178
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e61deda
commit 96dbd01
Showing
2 changed files
with
56 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
""" | ||
Optuna example showcasing the new Optuna Terminator feature. | ||
In this example, we utilize the Optuna Terminator for hyperparameter | ||
optimization on a lightgbm using the wine dataset. | ||
The Terminator automatically stops the optimization process based | ||
on the potential for further improvement. | ||
To run this example: | ||
$ python lightgbm_terminator.py | ||
""" | ||
|
||
import optuna | ||
from optuna.terminator.callback import TerminatorCallback | ||
from optuna.terminator.erroreval import report_cross_validation_scores | ||
|
||
from sklearn.datasets import load_wine | ||
import lightgbm as lgb | ||
from sklearn.model_selection import cross_val_score | ||
from sklearn.model_selection import KFold | ||
|
||
|
||
def objective(trial): | ||
X, y = load_wine(return_X_y=True) | ||
|
||
params = { | ||
"objective": "multiclass", | ||
"num_class": 3, | ||
"verbosity": -1, | ||
"boosting_type": "gbdt", | ||
"lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True), | ||
"lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True), | ||
"num_leaves": trial.suggest_int("num_leaves", 2, 256), | ||
"feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0), | ||
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0), | ||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7), | ||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100), | ||
} | ||
|
||
clf = lgb.LGBMClassifier(**params) | ||
|
||
scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True)) | ||
report_cross_validation_scores(trial, scores) | ||
|
||
return scores.mean() | ||
|
||
|
||
if __name__ == "__main__": | ||
study = optuna.create_study(direction="maximize") | ||
study.optimize(objective, n_trials=50, callbacks=[TerminatorCallback()]) | ||
|
||
print(f"The number of trials: {len(study.trials)}") | ||
print(f"Best value: {study.best_value} (params: {study.best_params})") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ lightgbm>=3.3.0 | |
numpy | ||
optuna | ||
scikit-learn | ||
botorch |