From 1d59a0456669d4f4d087da88f9cc0aa8b963342a Mon Sep 17 00:00:00 2001 From: momijiame Date: Mon, 3 Aug 2020 01:03:32 +0900 Subject: [PATCH] [python] add return_cvbooster flag to cv func and publish _CVBooster (#283,#2105,#1445) (#3204) * [python] add return_cvbooster flag to cv function and rename _CVBooster to make public (#283,#2105) * [python] Reduce expected metric of unit testing * [docs] add the CVBooster to the documentation * [python] reflect the review comments - Add some clarifications to the documentation - Rename CVBooster.append to make private - Decrease iteration rounds of testing to save CI time - Use CVBooster as root member of lgb * [python] add more checks in testing for cv Co-authored-by: Nikita Titov * [python] add docstring for instance attributes of CVBooster Co-authored-by: Nikita Titov * [python] fix docstring Co-authored-by: Nikita Titov Co-authored-by: Nikita Titov --- docs/Python-API.rst | 1 + python-package/lightgbm/__init__.py | 4 +-- python-package/lightgbm/engine.py | 40 ++++++++++++++++----- tests/python_package_test/test_engine.py | 44 ++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 10 deletions(-) diff --git a/docs/Python-API.rst b/docs/Python-API.rst index de6b1ec6f2b9..60389216bb2b 100644 --- a/docs/Python-API.rst +++ b/docs/Python-API.rst @@ -11,6 +11,7 @@ Data Structure API Dataset Booster + CVBooster Training API ------------ diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index 390a6994a7a2..584c23d7c8c5 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -8,7 +8,7 @@ from .basic import Booster, Dataset from .callback import (early_stopping, print_evaluation, record_evaluation, reset_parameter) -from .engine import cv, train +from .engine import cv, train, CVBooster import os @@ -29,7 +29,7 @@ with open(os.path.join(dir_path, 'VERSION.txt')) as version_file: __version__ = version_file.read().strip() -__all__ = ['Dataset', 'Booster', +__all__ = ['Dataset', 'Booster', 'CVBooster', 'train', 'cv', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index d80cb7a47cf3..6306a4c00f80 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -276,19 +276,35 @@ def train(params, train_set, num_boost_round=100, return booster -class _CVBooster(object): - """Auxiliary data struct to hold all boosters of CV.""" +class CVBooster(object): + """CVBooster in LightGBM. + + Auxiliary data structure to hold and redirect all boosters of ``cv`` function. + This class has the same methods as Booster class. + All method calls are actually performed for underlying Boosters and then all returned results are returned in a list. + + Attributes + ---------- + boosters : list of Booster + The list of underlying fitted models. + best_iteration : int + The best iteration of fitted model. + """ def __init__(self): + """Initialize the CVBooster. + + Generally, no need to instantiate manually. + """ self.boosters = [] self.best_iteration = -1 - def append(self, booster): - """Add a booster to _CVBooster.""" + def _append(self, booster): + """Add a booster to CVBooster.""" self.boosters.append(booster) def __getattr__(self, name): - """Redirect methods call of _CVBooster.""" + """Redirect methods call of CVBooster.""" def handler_function(*args, **kwargs): """Call methods with each booster, and concatenate their results.""" ret = [] @@ -341,7 +357,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi train_id = [np.concatenate([test_id[i] for i in range_(nfold) if k != i]) for k in range_(nfold)] folds = zip_(train_id, test_id) - ret = _CVBooster() + ret = CVBooster() for train_idx, test_idx in folds: train_set = full_data.subset(sorted(train_idx)) valid_set = full_data.subset(sorted(test_idx)) @@ -354,7 +370,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi if eval_train_metric: cvbooster.add_valid(train_set, 'train') cvbooster.add_valid(valid_set, 'valid') - ret.append(cvbooster) + ret._append(cvbooster) return ret @@ -380,7 +396,8 @@ def cv(params, train_set, num_boost_round=100, feature_name='auto', categorical_feature='auto', early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, - callbacks=None, eval_train_metric=False): + callbacks=None, eval_train_metric=False, + return_cvbooster=False): """Perform the cross-validation with given paramaters. Parameters @@ -486,6 +503,8 @@ def cv(params, train_set, num_boost_round=100, eval_train_metric : bool, optional (default=False) Whether to display the train metric in progress. The score of the metric is calculated again after each training step, so there is some impact on performance. + return_cvbooster : bool, optional (default=False) + Whether to return Booster models trained on each fold through ``CVBooster``. Returns ------- @@ -495,6 +514,7 @@ def cv(params, train_set, num_boost_round=100, {'metric1-mean': [values], 'metric1-stdv': [values], 'metric2-mean': [values], 'metric2-stdv': [values], ...}. + If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key. """ if not isinstance(train_set, Dataset): raise TypeError("Training only accepts Dataset object") @@ -586,4 +606,8 @@ def cv(params, train_set, num_boost_round=100, for k in results: results[k] = results[k][:cvfolds.best_iteration] break + + if return_cvbooster: + results['cvbooster'] = cvfolds + return dict(results) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index a26bd7a09449..2cfdf67fe94c 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -735,6 +735,50 @@ def test_cv(self): verbose_eval=False) np.testing.assert_allclose(cv_res_lambda['ndcg@3-mean'], cv_res_lambda_obj['ndcg@3-mean']) + def test_cvbooster(self): + X, y = load_breast_cancer(True) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + params = { + 'objective': 'binary', + 'metric': 'binary_logloss', + 'verbose': -1, + } + lgb_train = lgb.Dataset(X_train, y_train) + # with early stopping + cv_res = lgb.cv(params, lgb_train, + num_boost_round=25, + early_stopping_rounds=5, + verbose_eval=False, + nfold=3, + return_cvbooster=True) + self.assertIn('cvbooster', cv_res) + cvb = cv_res['cvbooster'] + self.assertIsInstance(cvb, lgb.CVBooster) + self.assertIsInstance(cvb.boosters, list) + self.assertEqual(len(cvb.boosters), 3) + self.assertTrue(all(isinstance(bst, lgb.Booster) for bst in cvb.boosters)) + self.assertGreater(cvb.best_iteration, 0) + # predict by each fold booster + preds = cvb.predict(X_test, num_iteration=cvb.best_iteration) + self.assertIsInstance(preds, list) + self.assertEqual(len(preds), 3) + # fold averaging + avg_pred = np.mean(preds, axis=0) + ret = log_loss(y_test, avg_pred) + self.assertLess(ret, 0.13) + # without early stopping + cv_res = lgb.cv(params, lgb_train, + num_boost_round=20, + verbose_eval=False, + nfold=3, + return_cvbooster=True) + cvb = cv_res['cvbooster'] + self.assertEqual(cvb.best_iteration, -1) + preds = cvb.predict(X_test) + avg_pred = np.mean(preds, axis=0) + ret = log_loss(y_test, avg_pred) + self.assertLess(ret, 0.15) + def test_feature_name(self): X_train, y_train = load_boston(True) params = {'verbose': -1}