[python-package] remove 'fobj' in favor of passing custom objective function in params #5052

Merged
Commits
a70e779
feat: support custom metrics in params
TremaMiguel Mar 3, 2022
9c41c6b
feat: support objective in params
TremaMiguel Mar 3, 2022
8ba0b72
test: custom objective and metric
TremaMiguel Mar 4, 2022
3e21d9f
Merge branch 'microsoft:master' into feat/custom_objective_metric_in_…
TremaMiguel Mar 4, 2022
055ab28
fix: imports are incorrectly sorted
TremaMiguel Mar 4, 2022
7979173
Merge branch 'feat/custom_objective_metric_in_params' of github.com:T…
TremaMiguel Mar 4, 2022
c858d61
feat: convert eval metrics str and set to list
TremaMiguel Mar 4, 2022
704d831
feat: convert single callable eval_metric to list
TremaMiguel Mar 10, 2022
86a1861
test: single callable objective in params
TremaMiguel Mar 12, 2022
8d2565b
feat: callable fobj in basic cv function
TremaMiguel Mar 15, 2022
56705fc
Merge branch 'feat/custom_objective_metric_in_params' of github.com:T…
TremaMiguel Mar 15, 2022
a0fb372
test: cv support objective callable
TremaMiguel Mar 15, 2022
bc8cd18
Merge branch 'microsoft:master' into feat/custom_objective_metric_in_…
TremaMiguel Mar 15, 2022
b5d514b
fix: assert in cv_res
TremaMiguel Mar 15, 2022
2e20ff5
docs: objective callable in params
TremaMiguel Mar 15, 2022
6411dee
recover test_boost_from_average_with_single_leaf_trees
TremaMiguel Mar 15, 2022
2066513
linters fail
TremaMiguel Mar 15, 2022
6c0b2d9
remove metrics helper functions
TremaMiguel Mar 19, 2022
d43c879
feat: choose objective through _choose_param_values
TremaMiguel Mar 19, 2022
d281cd5
test: test objective through _choose_param_values
TremaMiguel Mar 19, 2022
259eecb
test: test objective is callabe in train
TremaMiguel Mar 19, 2022
97d8ab7
test: parametrize choose_param_value with objective aliases
TremaMiguel Mar 19, 2022
4fddac1
test: cv booster metric is none
TremaMiguel Mar 19, 2022
bf1d347
Merge branch 'microsoft:master' into feat/custom_objective_metric_in_…
TremaMiguel Mar 19, 2022
2874390
fix: if string and callable choose callable
TremaMiguel Mar 19, 2022
6a09e5e
test train uses custom objective metrics
TremaMiguel Mar 21, 2022
457b455
test: cv uses custom objective metrics
TremaMiguel Mar 21, 2022
ccab320
refactor: remove fobj parameter in train and cv
TremaMiguel Mar 21, 2022
0730dfb
refactor: objective through params in sklearn API
TremaMiguel Mar 26, 2022
5b68182
Merge branch 'microsoft:master' into feat/custom_objective_metric_in_…
TremaMiguel Mar 26, 2022
9444018
custom objective function in advanced_example
TremaMiguel Mar 26, 2022
a483078
fix whitespackes lint
TremaMiguel Mar 26, 2022
3a22ab3
objective is none not a particular case for predict method
TremaMiguel Mar 27, 2022
b272740
replace scipy.expit with custom implementation
TremaMiguel Mar 28, 2022
3a69df9
test: set num_boost_round value to 20
TremaMiguel Apr 2, 2022
be96bcc
fix: custom objective default_value is none
TremaMiguel Apr 2, 2022
d10d36e
refactor: remove self._fobj
TremaMiguel Apr 2, 2022
5722b73
custom_objective default value is None
TremaMiguel Apr 2, 2022
ff4f465
Merge branch 'microsoft:master' into feat/custom_objective_metric_in_…
TremaMiguel Apr 3, 2022
8dc017a
refactor: variables name reference dummy_obj
TremaMiguel Apr 11, 2022
b67fc88
Merge branch 'master' into feat/custom_objective_metric_in_params
TremaMiguel Apr 11, 2022
8d4740a
linter errors
TremaMiguel Apr 11, 2022
29aff3f
fix: process objective parameter when calling predict
TremaMiguel Apr 11, 2022
c26c2a9
linter errors
TremaMiguel Apr 11, 2022
36ca839
fix: objective is None during predict call
TremaMiguel Apr 17, 2022
3aeea17
Merge branch 'master' into feat/custom_objective_metric_in_params
TremaMiguel Apr 17, 2022
15 changes: 11 additions & 4 deletions examples/python-guide/advanced_example.py
@@ -1,4 +1,5 @@
 # coding: utf-8
+import copy
 import json
 import pickle
 from pathlib import Path
@@ -159,11 +160,14 @@ def binary_error(preds, train_data):
     return 'error', np.mean(labels != (preds > 0.5)), False


-gbm = lgb.train(params,
+# Pass custom objective function through params
+params_custom_obj = copy.deepcopy(params)
+params_custom_obj['objective'] = loglikelihood
+
+gbm = lgb.train(params_custom_obj,
                 lgb_train,
                 num_boost_round=10,
                 init_model=gbm,
-                fobj=loglikelihood,
                 feval=binary_error,
                 valid_sets=lgb_eval)

@@ -183,11 +187,14 @@ def accuracy(preds, train_data):
     return 'accuracy', np.mean(labels == (preds > 0.5)), True


-gbm = lgb.train(params,
+# Pass custom objective function through params
+params_custom_obj = copy.deepcopy(params)
+params_custom_obj['objective'] = loglikelihood
+
+gbm = lgb.train(params_custom_obj,
                 lgb_train,
                 num_boost_round=10,
                 init_model=gbm,
-                fobj=loglikelihood,
                 feval=[binary_error, accuracy],
                 valid_sets=lgb_eval)
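
The `loglikelihood` objective referenced in the two hunks above is defined elsewhere in `advanced_example.py` and is not shown in this diff. As a rough illustration of the shape such a function takes, here is a minimal sketch of a binary log-loss objective placed into `params`. It uses plain Python lists and a hypothetical `FakeDataset` stand-in rather than NumPy arrays and `lgb.Dataset`, so it is not the code from the example file, and the `metric` entry is only illustrative.

```python
import copy
import math


class FakeDataset:
    """Hypothetical stand-in for lgb.Dataset that exposes only get_label()."""

    def __init__(self, labels):
        self._labels = list(labels)

    def get_label(self):
        return self._labels


def loglikelihood(preds, train_data):
    """Binary log-loss objective: returns (grad, hess) w.r.t. the raw margins."""
    labels = train_data.get_label()
    # Custom objectives receive raw margins, so map them to probabilities first.
    probs = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    grad = [p - y for p, y in zip(probs, labels)]
    hess = [p * (1.0 - p) for p in probs]
    return grad, hess


params = {"metric": "binary_logloss"}  # illustrative base params (assumption)
params_custom_obj = copy.deepcopy(params)
params_custom_obj["objective"] = loglikelihood  # takes the place of fobj=...

# Call the objective directly to inspect its output for two samples.
grad, hess = params_custom_obj["objective"]([0.0, 2.0], FakeDataset([1, 0]))
```

With LightGBM installed, `params_custom_obj` would be passed as the first argument to `lgb.train`, as the updated example in the diff does.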

6 changes: 3 additions & 3 deletions python-package/lightgbm/basic.py
@@ -3185,7 +3185,7 @@ def eval(self, data, name, feval=None):
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
-If ``fobj`` is specified, predicted values are returned before any transformation,
+If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
A ``Dataset`` to evaluate.
@@ -3231,7 +3231,7 @@ def eval_train(self, feval=None):
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
-If ``fobj`` is specified, predicted values are returned before any transformation,
+If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
The training dataset.
@@ -3262,7 +3262,7 @@ def eval_valid(self, feval=None):
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
-If ``fobj`` is specified, predicted values are returned before any transformation,
+If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
The validation dataset.
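
The docstrings above say that, with a custom objective, evaluation functions receive raw margins rather than probabilities. One consequence is that a metric which thresholds at 0.5 has to apply the sigmoid itself. The sketch below illustrates this with a hypothetical helper, `binary_error_raw`, which takes plain lists of scores and labels instead of the `(preds, eval_data)` pair that LightGBM passes to `feval`.

```python
import math


def binary_error_raw(preds, labels):
    """Error rate for raw-margin predictions (hypothetical helper on plain lists)."""
    # With a custom objective the scores are raw margins, so apply the sigmoid first.
    probs = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    wrong = sum(int((p > 0.5) != bool(y)) for p, y in zip(probs, labels))
    # Same tuple layout as a LightGBM metric: (name, value, is_higher_better).
    return "error", wrong / len(labels), False


result = binary_error_raw([-2.0, 0.3, 1.5, -0.1], [0, 0, 1, 1])
```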
128 changes: 67 additions & 61 deletions python-package/lightgbm/engine.py
@@ -12,10 +12,6 @@
 from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning
 from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold

-_LGBM_CustomObjectiveFunction = Callable[
-    [np.ndarray, Dataset],
-    Tuple[np.ndarray, np.ndarray]
-]
 _LGBM_CustomMetricFunction = Callable[
     [np.ndarray, Dataset],
     Tuple[str, float, bool]
@@ -28,7 +24,6 @@ def train(
     num_boost_round: int = 100,
     valid_sets: Optional[List[Dataset]] = None,
     valid_names: Optional[List[str]] = None,
-    fobj: Optional[_LGBM_CustomObjectiveFunction] = None,
     feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
     init_model: Optional[Union[str, Path, Booster]] = None,
     feature_name: Union[List[str], str] = 'auto',
@@ -41,7 +36,8 @@ def train(
Parameters
----------
params : dict
-Parameters for training.
+Parameters for training. Values passed through ``params`` take precedence over those
+supplied via arguments.
train_set : Dataset
Data to be trained on.
num_boost_round : int, optional (default=100)
@@ -50,27 +46,6 @@ def train(
List of data to be evaluated on during training.
valid_names : list of str, or None, optional (default=None)
Names of ``valid_sets``.
-fobj : callable or None, optional (default=None)
-Customized objective function.
-Should accept two parameters: preds, train_data,
-and return (grad, hess).
-
-preds : numpy 1-D array or numpy 2-D array (for multi-class task)
-The predicted values.
-Predicted values are returned before any transformation,
-e.g. they are raw margin instead of probability of positive class for binary task.
-train_data : Dataset
-The training dataset.
-grad : numpy 1-D array or numpy 2-D array (for multi-class task)
-The value of the first order derivative (gradient) of the loss
-with respect to the elements of preds for each sample point.
-hess : numpy 1-D array or numpy 2-D array (for multi-class task)
-The value of the second order derivative (Hessian) of the loss
-with respect to the elements of preds for each sample point.
-
-For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
-and grad and hess should be returned in the same format.
-
feval : callable, list of callable, or None, optional (default=None)
Customized evaluation function.
Each evaluation function should accept two parameters: preds, eval_data,
@@ -79,7 +54,7 @@ def train(
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
-If ``fobj`` is specified, predicted values are returned before any transformation,
+If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
A ``Dataset`` to evaluate.
@@ -118,17 +93,43 @@ def train(
List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information.

+Note
+----
+A custom objective function can be provided for the ``objective`` parameter.
+It should accept two parameters: preds, train_data and return (grad, hess).
+
+preds : numpy 1-D array or numpy 2-D array (for multi-class task)
+The predicted values.
+Predicted values are returned before any transformation,
+e.g. they are raw margin instead of probability of positive class for binary task.
+train_data : Dataset
+The training dataset.
+grad : numpy 1-D array or numpy 2-D array (for multi-class task)
+The value of the first order derivative (gradient) of the loss
+with respect to the elements of preds for each sample point.
+hess : numpy 1-D array or numpy 2-D array (for multi-class task)
+The value of the second order derivative (Hessian) of the loss
+with respect to the elements of preds for each sample point.
+
+For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
+and grad and hess should be returned in the same format.
+
Returns
-------
booster : Booster
The trained Booster model.
"""
     # create predictor first
     params = copy.deepcopy(params)
-    if fobj is not None:
-        for obj_alias in _ConfigAliases.get("objective"):
-            params.pop(obj_alias, None)
-        params['objective'] = 'none'
+    params = _choose_param_value(
+        main_param_name='objective',
+        params=params,
+        default_value=None
+    )
+    fobj = None
+    if callable(params["objective"]):
+        fobj = params["objective"]
+        params["objective"] = 'none'
     for alias in _ConfigAliases.get("num_iterations"):
         if alias in params:
             num_boost_round = params.pop(alias)
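
The hunk above first collapses any alias of `objective` into `params['objective']`, then pulls a callable out of it and leaves the string `'none'` behind for the booster. The following is a simplified, standalone sketch of that flow. `choose_objective` and `split_objective` are hypothetical names, the alias tuple is an illustrative list rather than the output of `_ConfigAliases.get("objective")`, and the tie-breaking between several aliases is an assumption that may differ from `_choose_param_value`.

```python
import copy

# Illustrative alias list (assumption); LightGBM reads this from _ConfigAliases.
OBJECTIVE_ALIASES = ("objective", "objective_type", "app", "application", "loss")


def choose_objective(params, default_value=None):
    """Collapse any objective alias into params['objective'] (simplified sketch)."""
    params = copy.deepcopy(params)
    # Remove every alias from params, remembering the values that were present.
    found = {k: params.pop(k) for k in OBJECTIVE_ALIASES if k in params}
    if "objective" in found:
        params["objective"] = found["objective"]
    elif found:
        # Fall back to the first alias present, in OBJECTIVE_ALIASES order.
        params["objective"] = next(iter(found.values()))
    else:
        params["objective"] = default_value
    return params


def split_objective(params):
    """Return (params_for_booster, fobj), following the branch shown in the diff."""
    params = choose_objective(params)
    fobj = None
    if callable(params["objective"]):
        fobj = params["objective"]
        params["objective"] = "none"
    return params, fobj


def dummy_obj(preds, train_data):
    """Placeholder objective used only to exercise the sketch."""
    return preds, preds


booster_params, fobj = split_objective({"app": dummy_obj, "num_leaves": 7})
```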
@@ -374,7 +375,7 @@ def _agg_cv_result(raw_results):

 def cv(params, train_set, num_boost_round=100,
        folds=None, nfold=5, stratified=True, shuffle=True,
-       metrics=None, fobj=None, feval=None, init_model=None,
+       metrics=None, feval=None, init_model=None,
        feature_name='auto', categorical_feature='auto',
        fpreproc=None, seed=0, callbacks=None, eval_train_metric=False,
        return_cvbooster=False):
@@ -383,7 +384,8 @@ def cv(params, train_set, num_boost_round=100,
Parameters
----------
params : dict
-Parameters for Booster.
+Parameters for training. Values passed through ``params`` take precedence over those
+supplied via arguments.
train_set : Dataset
Data to be trained on.
num_boost_round : int, optional (default=100)
@@ -403,27 +405,6 @@ def cv(params, train_set, num_boost_round=100,
metrics : str, list of str, or None, optional (default=None)
Evaluation metrics to be monitored while CV.
If not None, the metric in ``params`` will be overridden.
-fobj : callable or None, optional (default=None)
-Customized objective function.
-Should accept two parameters: preds, train_data,
-and return (grad, hess).
-
-preds : numpy 1-D array or numpy 2-D array (for multi-class task)
-The predicted values.
-Predicted values are returned before any transformation,
-e.g. they are raw margin instead of probability of positive class for binary task.
-train_data : Dataset
-The training dataset.
-grad : numpy 1-D array or numpy 2-D array (for multi-class task)
-The value of the first order derivative (gradient) of the loss
-with respect to the elements of preds for each sample point.
-hess : numpy 1-D array or numpy 2-D array (for multi-class task)
-The value of the second order derivative (Hessian) of the loss
-with respect to the elements of preds for each sample point.
-
-For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
-and grad and hess should be returned in the same format.
-
feval : callable, list of callable, or None, optional (default=None)
Customized evaluation function.
Each evaluation function should accept two parameters: preds, eval_data,
@@ -432,7 +413,7 @@ def cv(params, train_set, num_boost_round=100,
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
-If ``fobj`` is specified, predicted values are returned before any transformation,
+If custom objective function is used, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
eval_data : Dataset
A ``Dataset`` to evaluate.
@@ -474,6 +455,27 @@ def cv(params, train_set, num_boost_round=100,
return_cvbooster : bool, optional (default=False)
Whether to return Booster models trained on each fold through ``CVBooster``.

+Note
+----
+A custom objective function can be provided for the ``objective`` parameter.
+It should accept two parameters: preds, train_data and return (grad, hess).
+
+preds : numpy 1-D array or numpy 2-D array (for multi-class task)
+The predicted values.
+Predicted values are returned before any transformation,
+e.g. they are raw margin instead of probability of positive class for binary task.
+train_data : Dataset
+The training dataset.
+grad : numpy 1-D array or numpy 2-D array (for multi-class task)
+The value of the first order derivative (gradient) of the loss
+with respect to the elements of preds for each sample point.
+hess : numpy 1-D array or numpy 2-D array (for multi-class task)
+The value of the second order derivative (Hessian) of the loss
+with respect to the elements of preds for each sample point.
+
+For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
+and grad and hess should be returned in the same format.
+
Returns
-------
eval_hist : dict
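
The Note in the `cv` docstring says that for multi-class tasks `preds`, `grad` and `hess` all share the shape `[n_samples, n_classes]`. The sketch below illustrates that shape contract with a softmax cross-entropy objective. It is an assumption-laden illustration: `softmax_objective` is a hypothetical name, it works on nested lists and a plain list of labels instead of NumPy arrays and a `Dataset`, and it keeps only the diagonal of the Hessian.

```python
import math


def softmax_objective(preds, labels):
    """Softmax cross-entropy (grad, hess) for preds of shape [n_samples, n_classes]."""
    grad, hess = [], []
    for row, label in zip(preds, labels):
        shift = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - shift) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Gradient of cross-entropy w.r.t. raw scores: p_k - 1[k == label].
        grad.append([p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)])
        # Diagonal Hessian only; an assumption made to keep the sketch short.
        hess.append([p * (1.0 - p) for p in probs])
    return grad, hess


grad, hess = softmax_objective([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [0, 2])
```

Both return values keep one row per sample and one column per class, matching the layout the docstring asks for.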
@@ -486,12 +488,16 @@ def cv(params, train_set, num_boost_round=100,
     """
     if not isinstance(train_set, Dataset):
         raise TypeError("Training only accepts Dataset object")
-
     params = copy.deepcopy(params)
-    if fobj is not None:
-        for obj_alias in _ConfigAliases.get("objective"):
-            params.pop(obj_alias, None)
-        params['objective'] = 'none'
+    params = _choose_param_value(
+        main_param_name='objective',
+        params=params,
+        default_value=None
+    )
+    fobj = None
+    if callable(params["objective"]):
+        fobj = params["objective"]
+        params["objective"] = 'none'
     for alias in _ConfigAliases.get("num_iterations"):
         if alias in params:
             _log_warning(f"Found '{alias}' in params. Will use it instead of 'num_boost_round' argument")
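
The context lines at the end of this hunk show the existing rule that an alias of `num_iterations` found in `params` overrides the `num_boost_round` argument, which is the same "params take precedence" behaviour the updated docstrings describe. A small standalone sketch of that rule follows. `resolve_num_boost_round` is a hypothetical helper, the alias tuple is only an illustrative subset, and `warnings.warn` stands in for LightGBM's `_log_warning`.

```python
import warnings

# Illustrative subset of aliases (assumption); LightGBM reads these from _ConfigAliases.
NUM_ITERATIONS_ALIASES = ("num_iterations", "num_boost_round", "n_estimators")


def resolve_num_boost_round(params, num_boost_round=100):
    """Return (params_without_aliases, effective_num_boost_round)."""
    params = dict(params)  # work on a copy so the caller's dict is untouched
    for alias in NUM_ITERATIONS_ALIASES:
        if alias in params:
            warnings.warn(
                f"Found '{alias}' in params. Will use it instead of 'num_boost_round' argument"
            )
            num_boost_round = params.pop(alias)
    return params, num_boost_round
```

Calling `resolve_num_boost_round({"n_estimators": 25}, num_boost_round=100)` yields 25 rounds in this sketch, because the value in `params` wins over the keyword argument.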
8 changes: 3 additions & 5 deletions python-package/lightgbm/sklearn.py
@@ -596,11 +596,10 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
                 raise ValueError("Unknown LGBMModel type.")
         if callable(self._objective):
             if stage == "fit":
-                self._fobj = _ObjectiveFunctionWrapper(self._objective)
-            params['objective'] = 'None'  # objective = nullptr for unknown objective
+                params['objective'] = _ObjectiveFunctionWrapper(self._objective)
+            else:
+                params['objective'] = 'None'
Collaborator

Please consider checking my comment about popping the values instead of using None here.

Contributor Author

@StrikerRUS If we pop objective, the following error is raised in test_sklearn.test_multiclass_custom_objective:

lightgbm.basic.LightGBMError: Number of classes must be 1 for non-multiclass training

I think the objective task is lost somewhere when predicting

Contributor Author

We could pass the value of objective to params based on the class instance, that is,

if callable(self._objective):
    if stage == "fit":
        params['objective'] = _ObjectiveFunctionWrapper(self._objective)
    else:
        # params['objective'] = 'None'
        if isinstance(self, LGBMRegressor):
            params['objective'] =  "regression"
        elif isinstance(self, LGBMClassifier):
            if self._n_classes > 2:
                params['objective'] = "multiclass"
            else:
                params['objective'] = "binary"
        elif isinstance(self, LGBMRanker):
            params['objective'] = "lambdarank"
        else:
            raise ValueError("Unknown LGBMModel type.")

but it is cleaner just to write

params['objective'] = 'None'

What's the special meaning of None?

Collaborator

OK, I see now.
It seems that LightGBM checks the consistency of the objective and the number of classes even during prediction.

void Config::CheckParamConflict() {

Full logs:

==================================================================================================== FAILURES ====================================================================================================
________________________________________________________________________________________ test_multiclass_custom_objective ________________________________________________________________________________________

    def test_multiclass_custom_objective():
        centers = [[-4, -4], [4, 4], [-4, 4]]
        X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
        params = {'n_estimators': 10, 'num_leaves': 7}
        builtin_obj_model = lgb.LGBMClassifier(**params)
        builtin_obj_model.fit(X, y)
        builtin_obj_preds = builtin_obj_model.predict_proba(X)

        custom_obj_model = lgb.LGBMClassifier(objective=sklearn_multiclass_custom_objective, **params)
        custom_obj_model.fit(X, y)
>       custom_obj_preds = softmax(custom_obj_model.predict(X, raw_score=True))

test_sklearn.py:1299:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
d:\miniconda3\lib\site-packages\lightgbm\sklearn.py:1050: in predict
    result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
d:\miniconda3\lib\site-packages\lightgbm\sklearn.py:1063: in predict_proba
    result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
d:\miniconda3\lib\site-packages\lightgbm\sklearn.py:813: in predict
    return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
d:\miniconda3\lib\site-packages\lightgbm\basic.py:3538: in predict
    return predictor.predict(data, start_iteration, num_iteration,
d:\miniconda3\lib\site-packages\lightgbm\basic.py:813: in predict
    preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type)
d:\miniconda3\lib\site-packages\lightgbm\basic.py:903: in __pred_for_np2d
    return inner_predict(mat, start_iteration, num_iteration, predict_type)
d:\miniconda3\lib\site-packages\lightgbm\basic.py:873: in inner_predict
    _safe_call(_LIB.LGBM_BoosterPredictForMat(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ret = -1

    def _safe_call(ret: int) -> None:
        """Check the return value from C API call.

        Parameters
        ----------
        ret : int
            The return value from C API calls.
        """
        if ret != 0:
>           raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))
E           lightgbm.basic.LightGBMError: Number of classes must be 1 for non-multiclass training

d:\miniconda3\lib\site-packages\lightgbm\basic.py:142: LightGBMError
---------------------------------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------------------------------
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000075 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 510
[LightGBM] [Info] Number of data points in the train set: 1000, number of used features: 2
[LightGBM] [Info] Start training from score -1.096614
[LightGBM] [Info] Start training from score -1.099613
[LightGBM] [Info] Start training from score -1.099613
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] Using self-defined objective function
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000071 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 510
[LightGBM] [Info] Number of data points in the train set: 1000, number of used features: 2
[LightGBM] [Warning] Using self-defined objective function
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
---------------------------------------------------------------------------------------------- Captured stderr call ----------------------------------------------------------------------------------------------
[LightGBM] [Fatal] Number of classes must be 1 for non-multiclass training

Let's then pass 'None' for the custom objective during the predict phase.

Contributor Author

makes sense

         else:
-            if stage == "fit":
-                self._fobj = None
             params['objective'] = self._objective

         params.pop('importance_type', None)
@@ -756,7 +755,6 @@ def _get_meta_data(collection, name, i):
             num_boost_round=self.n_estimators,
             valid_sets=valid_sets,
             valid_names=eval_names,
-            fobj=self._fobj,
             feval=eval_metrics_callable,
             init_model=init_model,
             feature_name=feature_name,
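
In the `sklearn.py` changes above, a callable objective is wrapped and placed in `params['objective']` at fit time, while at predict time the string `'None'` is used, as agreed in the review thread. The sketch below mirrors that branch in isolation. `ObjectiveWrapper` and `process_objective` are hypothetical, simplified stand-ins: the wrapper only handles the two-argument `func(y_true, y_pred)` form and assumes the dataset exposes `get_label()`, which may not cover everything `_ObjectiveFunctionWrapper` does.

```python
class ObjectiveWrapper:
    """Adapt func(y_true, y_pred) to the (preds, dataset) -> (grad, hess) signature."""

    def __init__(self, func):
        self.func = func

    def __call__(self, preds, dataset):
        # scikit-learn style objectives take labels first, predictions second.
        return self.func(dataset.get_label(), preds)


def process_objective(objective, stage):
    """Mirror the branch in the diff: wrap at fit time, use 'None' otherwise."""
    if callable(objective):
        if stage == "fit":
            return ObjectiveWrapper(objective)
        return "None"
    return objective
```

Keeping a placeholder string at predict time, rather than popping the key, follows the outcome of the discussion above about the class-count consistency check.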
39 changes: 39 additions & 0 deletions tests/python_package_test/test_basic.py
@@ -513,6 +513,45 @@ def test_choose_param_value():
     assert original_params == expected_params


+@pytest.mark.parametrize("objective_alias", lgb.basic._ConfigAliases.get("objective"))
+def test_choose_param_value_objective(objective_alias):
+    def dummy_obj(preds, train_data):
+        return np.ones(preds.shape), np.ones(preds.shape)
+
+    def mse_obj(y_pred, dtrain):
+        y_true = dtrain.get_label()
+        grad = (y_pred - y_true)
+        hess = np.ones(len(grad))
+        return grad, hess
+
+    # If callable is found in objective
+    params = {objective_alias: dummy_obj}
+    params = lgb.basic._choose_param_value(
+        main_param_name="objective",
+        params=params,
+        default_value=None
+    )
+    assert params['objective'] == dummy_obj
+
+    # Value in params should be preferred to the default_value passed from keyword arguments
+    params = {objective_alias: dummy_obj}
+    params = lgb.basic._choose_param_value(
+        main_param_name="objective",
+        params=params,
+        default_value=mse_obj
+    )
+    assert params['objective'] == dummy_obj
+
+    # None of objective or its aliases in params, but default_value is callable.
+    params = {}
+    params = lgb.basic._choose_param_value(
+        main_param_name="objective",
+        params=params,
+        default_value=mse_obj
+    )
+    assert params['objective'] == mse_obj
+
+
 @pytest.mark.parametrize('collection', ['1d_np', '2d_np', 'pd_float', 'pd_str', '1d_list', '2d_list'])
 @pytest.mark.parametrize('dtype', [np.float32, np.float64])
 def test_list_to_1d_numpy(collection, dtype):
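
For reference, the `mse_obj` helper in the new test returns `y_pred - y_true` as the gradient and a vector of ones as the Hessian. That matches the derivatives of a half squared-error loss; the factor of one half is an assumption about the intended loss, since the test does not state it.

```latex
\begin{aligned}
L_i(\hat{y}_i) &= \tfrac{1}{2}\,(\hat{y}_i - y_i)^2 \\
\frac{\partial L_i}{\partial \hat{y}_i} &= \hat{y}_i - y_i \\
\frac{\partial^2 L_i}{\partial \hat{y}_i^2} &= 1
\end{aligned}
```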