[python-package] remove 'fobj' in favor of passing custom objective function in params (fixes #3244) (#5052)

* feat: support custom metrics in params

* feat: support objective in params

* test: custom objective and metric

* fix: imports are incorrectly sorted

* feat: convert eval metrics str and set to list

* feat: convert single callable eval_metric to list

* test: single callable objective in params

Signed-off-by: Miguel Trejo <[email protected]>

* feat: callable fobj in basic cv function

Signed-off-by: Miguel Trejo <[email protected]>

* test: cv support objective callable

Signed-off-by: Miguel Trejo <[email protected]>

* fix: assert in cv_res

Signed-off-by: Miguel Trejo <[email protected]>

* docs: objective callable in params

Signed-off-by: Miguel Trejo <[email protected]>

* recover test_boost_from_average_with_single_leaf_trees

Signed-off-by: Miguel Trejo <[email protected]>

* linters fail

Signed-off-by: Miguel Trejo <[email protected]>

* remove metrics helper functions

Signed-off-by: Miguel Trejo <[email protected]>

* feat: choose objective through _choose_param_values

Signed-off-by: Miguel Trejo <[email protected]>

* test: test objective through _choose_param_values

Signed-off-by: Miguel Trejo <[email protected]>

* test: test objective is callable in train

Signed-off-by: Miguel Trejo <[email protected]>

* test: parametrize choose_param_value with objective aliases

Signed-off-by: Miguel Trejo <[email protected]>

* test: cv booster metric is none

Signed-off-by: Miguel Trejo <[email protected]>

* fix: if string and callable choose callable

Signed-off-by: Miguel Trejo <[email protected]>

* test train uses custom objective metrics

Signed-off-by: Miguel Trejo <[email protected]>

* test: cv uses custom objective metrics

Signed-off-by: Miguel Trejo <[email protected]>

* refactor: remove fobj parameter in train and cv

Signed-off-by: Miguel Trejo <[email protected]>

* refactor: objective through params in sklearn API

Signed-off-by: Miguel Trejo <[email protected]>

* custom objective function in advanced_example

Signed-off-by: Miguel Trejo <[email protected]>

* fix whitespaces lint

* objective is none not a particular case for predict method

Signed-off-by: Miguel Trejo <[email protected]>

* replace scipy.expit with custom implementation

Signed-off-by: Miguel Trejo <[email protected]>

* test: set num_boost_round value to 20

Signed-off-by: Miguel Trejo <[email protected]>

* fix: custom objective default_value is none

Signed-off-by: Miguel Trejo <[email protected]>

* refactor: remove self._fobj

Signed-off-by: Miguel Trejo <[email protected]>

* custom_objective default value is None

Signed-off-by: Miguel Trejo <[email protected]>

* refactor: variables name reference dummy_obj

Signed-off-by: Miguel Trejo <[email protected]>

* linter errors

* fix: process objective parameter when calling predict

Signed-off-by: Miguel Trejo <[email protected]>

* linter errors

* fix: objective is None during predict call

Signed-off-by: Miguel Trejo <[email protected]>
TremaMiguel authored Apr 22, 2022
1 parent fc0c8fd commit 416ecd5
Showing 7 changed files with 270 additions and 115 deletions.
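At a glance, the change moves the custom objective from the `fobj` keyword argument of `lgb.train()` / `lgb.cv()` into the `params` dict. A minimal before/after sketch (the dataset and the `logistic_obj` name are illustrative, not part of the diff):

```python
import numpy as np
import lightgbm as lgb

def logistic_obj(preds, train_data):
    # Custom binary log-loss objective: gradient and Hessian w.r.t. the raw margin.
    labels = train_data.get_label()
    probs = 1.0 / (1.0 + np.exp(-preds))
    return probs - labels, probs * (1.0 - probs)

rng = np.random.default_rng(0)
train_set = lgb.Dataset(rng.random((200, 5)), label=rng.integers(0, 2, 200))

# Before this commit:
#   gbm = lgb.train({'verbose': -1}, train_set, fobj=logistic_obj)
# After this commit, the callable travels through params:
gbm = lgb.train({'objective': logistic_obj, 'verbose': -1}, train_set, num_boost_round=10)
```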
15 changes: 11 additions & 4 deletions examples/python-guide/advanced_example.py
@@ -1,4 +1,5 @@
# coding: utf-8
import copy
import json
import pickle
from pathlib import Path
@@ -159,11 +160,14 @@ def binary_error(preds, train_data):
return 'error', np.mean(labels != (preds > 0.5)), False


gbm = lgb.train(params,
# Pass custom objective function through params
params_custom_obj = copy.deepcopy(params)
params_custom_obj['objective'] = loglikelihood

gbm = lgb.train(params_custom_obj,
lgb_train,
num_boost_round=10,
init_model=gbm,
fobj=loglikelihood,
feval=binary_error,
valid_sets=lgb_eval)

@@ -183,11 +187,14 @@ def accuracy(preds, train_data):
return 'accuracy', np.mean(labels == (preds > 0.5)), True


gbm = lgb.train(params,
# Pass custom objective function through params
params_custom_obj = copy.deepcopy(params)
params_custom_obj['objective'] = loglikelihood

gbm = lgb.train(params_custom_obj,
lgb_train,
num_boost_round=10,
init_model=gbm,
fobj=loglikelihood,
feval=[binary_error, accuracy],
valid_sets=lgb_eval)

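The `loglikelihood` callable assigned to `params_custom_obj['objective']` above is defined earlier in advanced_example.py and is not shown in these hunks. For reference, a sketch of its usual shape, matching the (preds, train_data) -> (grad, hess) signature documented below; the exact body is assumed rather than taken from the diff:

```python
def loglikelihood(preds, train_data):
    # Binary log-likelihood objective applied to raw margins:
    # grad = sigmoid(preds) - labels, hess = sigmoid(preds) * (1 - sigmoid(preds))
    labels = train_data.get_label()
    preds = 1.0 / (1.0 + np.exp(-preds))
    grad = preds - labels
    hess = preds * (1.0 - preds)
    return grad, hess
```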
6 changes: 3 additions & 3 deletions python-package/lightgbm/basic.py
@@ -3185,7 +3185,7 @@ def eval(self, data, name, feval=None):
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
If ``fobj`` is specified, predicted values are returned before any transformation,
If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
A ``Dataset`` to evaluate.
@@ -3231,7 +3231,7 @@ def eval_train(self, feval=None):
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
If ``fobj`` is specified, predicted values are returned before any transformation,
If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
The training dataset.
@@ -3262,7 +3262,7 @@ def eval_valid(self, feval=None):
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
If ``fobj`` is specified, predicted values are returned before any transformation,
If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
The validation dataset.
128 changes: 67 additions & 61 deletions python-package/lightgbm/engine.py
@@ -12,10 +12,6 @@
from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning
from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold

_LGBM_CustomObjectiveFunction = Callable[
[np.ndarray, Dataset],
Tuple[np.ndarray, np.ndarray]
]
_LGBM_CustomMetricFunction = Callable[
[np.ndarray, Dataset],
Tuple[str, float, bool]
@@ -28,7 +24,6 @@ def train(
num_boost_round: int = 100,
valid_sets: Optional[List[Dataset]] = None,
valid_names: Optional[List[str]] = None,
fobj: Optional[_LGBM_CustomObjectiveFunction] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[List[str], str] = 'auto',
@@ -41,7 +36,8 @@
Parameters
----------
params : dict
Parameters for training.
Parameters for training. Values passed through ``params`` take precedence over those
supplied via arguments.
train_set : Dataset
Data to be trained on.
num_boost_round : int, optional (default=100)
@@ -50,27 +46,6 @@
List of data to be evaluated on during training.
valid_names : list of str, or None, optional (default=None)
Names of ``valid_sets``.
fobj : callable or None, optional (default=None)
Customized objective function.
Should accept two parameters: preds, train_data,
and return (grad, hess).
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
train_data : Dataset
The training dataset.
grad : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of preds for each sample point.
hess : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.
feval : callable, list of callable, or None, optional (default=None)
Customized evaluation function.
Each evaluation function should accept two parameters: preds, eval_data,
@@ -79,7 +54,7 @@
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
If ``fobj`` is specified, predicted values are returned before any transformation,
If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
A ``Dataset`` to evaluate.
@@ -118,17 +93,43 @@
List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information.
Note
----
A custom objective function can be provided for the ``objective`` parameter.
It should accept two parameters: preds, train_data and return (grad, hess).
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
train_data : Dataset
The training dataset.
grad : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of preds for each sample point.
hess : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.
Returns
-------
booster : Booster
The trained Booster model.
"""
# create predictor first
params = copy.deepcopy(params)
if fobj is not None:
for obj_alias in _ConfigAliases.get("objective"):
params.pop(obj_alias, None)
params['objective'] = 'none'
params = _choose_param_value(
main_param_name='objective',
params=params,
default_value=None
)
fobj = None
if callable(params["objective"]):
fobj = params["objective"]
params["objective"] = 'none'
for alias in _ConfigAliases.get("num_iterations"):
if alias in params:
num_boost_round = params.pop(alias)
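The block above is the new dispatch: `_choose_param_value` collapses any objective alias onto the canonical `'objective'` key, and if the resolved value is callable, `train()` keeps it as a local `fobj` and sets `params["objective"] = 'none'` so no built-in objective runs. A small sketch of the alias resolution, using the internal helper (internal API, subject to change; `custom_obj` is illustrative):

```python
import lightgbm as lgb

def custom_obj(preds, train_data):
    ...  # would return (grad, hess); never called in this sketch

# 'application' is a documented alias of 'objective'; the helper moves the
# value under the canonical key and drops the alias from params.
params = lgb.basic._choose_param_value(
    main_param_name='objective',
    params={'application': custom_obj},
    default_value=None,
)
assert params == {'objective': custom_obj}
```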
@@ -374,7 +375,7 @@ def _agg_cv_result(raw_results):

def cv(params, train_set, num_boost_round=100,
folds=None, nfold=5, stratified=True, shuffle=True,
metrics=None, fobj=None, feval=None, init_model=None,
metrics=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
fpreproc=None, seed=0, callbacks=None, eval_train_metric=False,
return_cvbooster=False):
@@ -383,7 +384,8 @@ def cv(params, train_set, num_boost_round=100,
Parameters
----------
params : dict
Parameters for Booster.
Parameters for training. Values passed through ``params`` take precedence over those
supplied via arguments.
train_set : Dataset
Data to be trained on.
num_boost_round : int, optional (default=100)
@@ -403,27 +405,6 @@
metrics : str, list of str, or None, optional (default=None)
Evaluation metrics to be monitored while CV.
If not None, the metric in ``params`` will be overridden.
fobj : callable or None, optional (default=None)
Customized objective function.
Should accept two parameters: preds, train_data,
and return (grad, hess).
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
train_data : Dataset
The training dataset.
grad : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of preds for each sample point.
hess : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.
feval : callable, list of callable, or None, optional (default=None)
Customized evaluation function.
Each evaluation function should accept two parameters: preds, eval_data,
Expand All @@ -432,7 +413,7 @@ def cv(params, train_set, num_boost_round=100,
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
If ``fobj`` is specified, predicted values are returned before any transformation,
If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
A ``Dataset`` to evaluate.
@@ -474,6 +455,27 @@
return_cvbooster : bool, optional (default=False)
Whether to return Booster models trained on each fold through ``CVBooster``.
Note
----
A custom objective function can be provided for the ``objective`` parameter.
It should accept two parameters: preds, train_data and return (grad, hess).
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
train_data : Dataset
The training dataset.
grad : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of preds for each sample point.
hess : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.
Returns
-------
eval_hist : dict
@@ -486,12 +488,16 @@
"""
if not isinstance(train_set, Dataset):
raise TypeError("Training only accepts Dataset object")

params = copy.deepcopy(params)
if fobj is not None:
for obj_alias in _ConfigAliases.get("objective"):
params.pop(obj_alias, None)
params['objective'] = 'none'
params = _choose_param_value(
main_param_name='objective',
params=params,
default_value=None
)
fobj = None
if callable(params["objective"]):
fobj = params["objective"]
params["objective"] = 'none'
for alias in _ConfigAliases.get("num_iterations"):
if alias in params:
_log_warning(f"Found '{alias}' in params. Will use it instead of 'num_boost_round' argument")
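`cv()` now follows the same protocol as `train()`: the callable objective is read from `params`, extracted into a local `fobj`, and `params['objective']` is reset to `'none'`. A minimal usage sketch (data and objective are illustrative):

```python
import numpy as np
import lightgbm as lgb

def logistic_obj(preds, train_data):
    labels = train_data.get_label()
    probs = 1.0 / (1.0 + np.exp(-preds))
    return probs - labels, probs * (1.0 - probs)

rng = np.random.default_rng(0)
train_set = lgb.Dataset(rng.random((200, 5)), label=rng.integers(0, 2, 200))

# The removed `fobj=` keyword would now raise TypeError; the callable goes in params.
cv_results = lgb.cv({'objective': logistic_obj, 'metric': 'auc', 'verbose': -1},
                    train_set, num_boost_round=10, nfold=3)
```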
8 changes: 3 additions & 5 deletions python-package/lightgbm/sklearn.py
@@ -596,11 +596,10 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
raise ValueError("Unknown LGBMModel type.")
if callable(self._objective):
if stage == "fit":
self._fobj = _ObjectiveFunctionWrapper(self._objective)
params['objective'] = 'None' # objective = nullptr for unknown objective
params['objective'] = _ObjectiveFunctionWrapper(self._objective)
else:
params['objective'] = 'None'
else:
if stage == "fit":
self._fobj = None
params['objective'] = self._objective

params.pop('importance_type', None)
@@ -756,7 +755,6 @@ def _get_meta_data(collection, name, i):
num_boost_round=self.n_estimators,
valid_sets=valid_sets,
valid_names=eval_names,
fobj=self._fobj,
feval=eval_metrics_callable,
init_model=init_model,
feature_name=feature_name,
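In the scikit-learn API the change is internal: the `_ObjectiveFunctionWrapper` around a callable `objective` is now placed directly into `params['objective']` instead of being held in `self._fobj` and forwarded through the removed `fobj` argument. User-facing usage is unchanged; a sketch (sklearn-style objectives take `(y_true, y_pred)`):

```python
import numpy as np
from lightgbm import LGBMClassifier

def logistic_obj(y_true, y_pred):
    # sklearn-API signature: (y_true, y_pred) -> (grad, hess)
    probs = 1.0 / (1.0 + np.exp(-y_pred))
    return probs - y_true, probs * (1.0 - probs)

rng = np.random.default_rng(0)
X, y = rng.random((200, 5)), rng.integers(0, 2, 200)

clf = LGBMClassifier(objective=logistic_obj, n_estimators=10).fit(X, y)
raw = clf.predict(X, raw_score=True)  # custom objective => predictions are raw margins
```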
32 changes: 31 additions & 1 deletion tests/python_package_test/test_basic.py
@@ -14,7 +14,7 @@
import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series

from .utils import load_breast_cancer
from .utils import dummy_obj, load_breast_cancer, mse_obj


def test_basic(tmp_path):
Expand Down Expand Up @@ -513,6 +513,36 @@ def test_choose_param_value():
assert original_params == expected_params


@pytest.mark.parametrize("objective_alias", lgb.basic._ConfigAliases.get("objective"))
def test_choose_param_value_objective(objective_alias):
# If callable is found in objective
params = {objective_alias: dummy_obj}
params = lgb.basic._choose_param_value(
main_param_name="objective",
params=params,
default_value=None
)
assert params['objective'] == dummy_obj

# Value in params should be preferred to the default_value passed from keyword arguments
params = {objective_alias: dummy_obj}
params = lgb.basic._choose_param_value(
main_param_name="objective",
params=params,
default_value=mse_obj
)
assert params['objective'] == dummy_obj

# None of objective or its aliases in params, but default_value is callable.
params = {}
params = lgb.basic._choose_param_value(
main_param_name="objective",
params=params,
default_value=mse_obj
)
assert params['objective'] == mse_obj


@pytest.mark.parametrize('collection', ['1d_np', '2d_np', 'pd_float', 'pd_str', '1d_list', '2d_list'])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_list_to_1d_numpy(collection, dtype):
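`dummy_obj` and `mse_obj` are imported from `tests/python_package_test/utils.py`, which is not part of this diff. A plausible sketch of those helpers (bodies assumed, not taken from the commit):

```python
import numpy as np

def dummy_obj(preds, train_data):
    # Trivial objective: constant gradient and Hessian; used only to check
    # that a callable flows through params and objective aliases correctly.
    return np.ones(len(preds)), np.ones(len(preds))

def mse_obj(preds, train_data):
    # Gradient and Hessian of mean squared error w.r.t. predictions.
    grad = preds - train_data.get_label()
    hess = np.ones(len(grad))
    return grad, hess
```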
(Diffs for the remaining 2 changed files are not shown.)
