From e0a698a8a96a4654b1926faec1da148320d97515 Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Sat, 13 Jul 2024 23:24:43 +0900 Subject: [PATCH 1/2] replace accuracy with auc in the best score attribute --- lightgbm/lightgbm_integration.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lightgbm/lightgbm_integration.py b/lightgbm/lightgbm_integration.py index 09586593..f75e9d9f 100644 --- a/lightgbm/lightgbm_integration.py +++ b/lightgbm/lightgbm_integration.py @@ -10,12 +10,10 @@ """ -import numpy as np import optuna import lightgbm as lgb import sklearn.datasets -import sklearn.metrics from sklearn.model_selection import train_test_split @@ -45,10 +43,7 @@ def objective(trial): pruning_callback = optuna.integration.LightGBMPruningCallback(trial, "auc") gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback]) - preds = gbm.predict(valid_x) - pred_labels = np.rint(preds) - accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels) - return accuracy + return gbm.best_score["valid_0"]["auc"] if __name__ == "__main__": From a231306456fcc454e837d8a5119ce67b7e4b0244 Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Wed, 24 Jul 2024 12:39:43 +0900 Subject: [PATCH 2/2] Use numpy function to compute AUC score --- lightgbm/lightgbm_integration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightgbm/lightgbm_integration.py b/lightgbm/lightgbm_integration.py index f75e9d9f..886b616a 100644 --- a/lightgbm/lightgbm_integration.py +++ b/lightgbm/lightgbm_integration.py @@ -14,6 +14,7 @@ import lightgbm as lgb import sklearn.datasets +import sklearn.metrics from sklearn.model_selection import train_test_split @@ -43,7 +44,8 @@ def objective(trial): pruning_callback = optuna.integration.LightGBMPruningCallback(trial, "auc") gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback]) - return gbm.best_score["valid_0"]["auc"] + preds = gbm.predict(valid_x) + return sklearn.metrics.roc_auc_score(valid_y, preds) if __name__ == "__main__":