
Unexpected Behavior for Early Stopping with Custom Metric #2371

Closed
pford221 opened this issue Sep 1, 2019 · 18 comments

Comments

@pford221 (Contributor) commented Sep 1, 2019

Hello -

When passing a custom metric to the feval argument of lgb.train(), I'm getting what I believe to be unexpected, or at least undesirable, behavior: the Booster is not returning the best iteration as determined by the validation set, as can be seen in the reproducible example and output below.

Environment:

lightgbm: 2.2.3
os: Linux CPU
ami: 4.14.138-114.102.amzn2.x86_64

Reproducible example

import numpy as np
import lightgbm as lgb

np.random.seed(90210)
N = 100
x = np.random.uniform(size = (N, 7))

lp = (x[:,0]*1.34 + x[:,1]*x[:,2]*0.89 + x[:,2]*0.05 + x[:,2]*x[:,3]*1.34 + x[:,3]*0.31
      + np.random.normal(loc = 2, scale = 0.5, size = 100))

y = np.exp(lp - lp.mean()) / (1 + np.exp(lp - lp.mean()))
y = (y > 0.5).astype(np.uint8)

in_trn = np.random.binomial(1, p = 0.6, size = 100).astype(bool)
x_trn = x[ in_trn]
x_val = x[~in_trn]
y_trn = y[ in_trn]
y_val = y[~in_trn]

l_trn = lgb.Dataset(x_trn, y_trn)
l_val = lgb.Dataset(x_val, y_val)

def feval(prd, dmat):
    # custom metric: reward true positives, penalize false positives
    act = dmat.get_label()
    cls = (prd > 0.5).astype(np.uint8)
    tp = (cls * act).sum()
    fp = ((cls == 1) & (act == 0)).sum()
    savings = (100 * tp - 75 * fp) / len(act)
    return 'savings', savings, True  # True => higher is better

params = {
        'learning_rate':0.05,
        'reg_lambda': 1,
        'feature_fraction': 0.75,
        'bagging_fraction': 0.7,
        'bagging_freq': 1,
        'boosting_type': 'gbdt',
        'objective': 'binary',
        'reg_alpha': 2,
        'num_leaves':  31,
        'min_data_in_leaf': 1,
        'feature_fraction_seed': 201,
        'bagging_seed': 427,
        'metric': 'None',
        }

mod = lgb.train(params, l_trn, num_boost_round = 1000, valid_sets = [l_trn,l_val],
                        valid_names = ['trn','val'], early_stopping_rounds = 100,
                        verbose_eval = 1, feval = feval)

Output

[1] trn's savings: 16.6667 val's savings: 5.40541
Training until validation scores don't improve for 100 rounds.
[2] trn's savings: 30.5556 val's savings: 21.6216
[3] trn's savings: 34.127 val's savings: 14.8649
[4] trn's savings: 36.9048 val's savings: 14.8649
[5] trn's savings: 36.1111 val's savings: 16.8919
[6] trn's savings: 34.5238 val's savings: 18.2432
[7] trn's savings: 35.7143 val's savings: 18.2432
[8] trn's savings: 34.5238 val's savings: 20.2703
[9] trn's savings: 32.1429 val's savings: 18.2432
[10] trn's savings: 36.9048 val's savings: 20.2703
[11] trn's savings: 40.0794 val's savings: 20.2703
[12] trn's savings: 44.8413 val's savings: 22.2973
[13] trn's savings: 44.8413 val's savings: 19.5946
[14] trn's savings: 43.254 val's savings: 19.5946
[15] trn's savings: 44.8413 val's savings: 21.6216
[16] trn's savings: 44.8413 val's savings: 21.6216
[17] trn's savings: 43.6508 val's savings: 21.6216
[18] trn's savings: 42.0635 val's savings: 21.6216
[19] trn's savings: 44.8413 val's savings: 19.5946
[20] trn's savings: 44.8413 val's savings: 21.6216
[21] trn's savings: 46.0317 val's savings: 21.6216
[22] trn's savings: 46.0317 val's savings: 21.6216

....
....
[110] trn's savings: 44.4444 val's savings: 23.6486
[111] trn's savings: 43.254 val's savings: 23.6486
[112] trn's savings: 43.254 val's savings: 23.6486
[113] trn's savings: 43.254 val's savings: 23.6486
[114] trn's savings: 43.254 val's savings: 23.6486
[115] trn's savings: 43.254 val's savings: 23.6486
[116] trn's savings: 43.254 val's savings: 23.6486
[117] trn's savings: 43.254 val's savings: 23.6486
[118] trn's savings: 43.254 val's savings: 23.6486
[119] trn's savings: 43.254 val's savings: 23.6486
[120] trn's savings: 43.254 val's savings: 23.6486
[121] trn's savings: 43.254 val's savings: 23.6486
Early stopping, best iteration is:
[21] trn's savings: 46.0317 val's savings: 21.6216

Other Notes
When I remove the training set from valid_sets and update valid_names accordingly, I do indeed get the correct iteration returned.

mod = lgb.train(params, l_trn, num_boost_round = 1000, valid_sets = l_val,
                        valid_names = ['val'], early_stopping_rounds = 100,
                        verbose_eval = 1, feval = feval)

[1] val's savings: 5.40541
Training until validation scores don't improve for 100 rounds.
[2] val's savings: 21.6216
[3] val's savings: 14.8649
[4] val's savings: 14.8649
[5] val's savings: 16.8919
[6] val's savings: 18.2432
[7] val's savings: 18.2432
[8] val's savings: 20.2703
[9] val's savings: 18.2432
[10] val's savings: 20.2703
[11] val's savings: 20.2703
...
...
[72] val's savings: 21.6216
[73] val's savings: 21.6216
[74] val's savings: 21.6216
[75] val's savings: 21.6216
[76] val's savings: 21.6216
[77] val's savings: 21.6216
[78] val's savings: 21.6216
[79] val's savings: 23.6486
[80] val's savings: 23.6486
[81] val's savings: 23.6486
[82] val's savings: 23.6486
[83] val's savings: 23.6486
[84] val's savings: 23.6486
...
...
[172] val's savings: 23.6486
[173] val's savings: 23.6486
[174] val's savings: 23.6486
[175] val's savings: 23.6486
[176] val's savings: 23.6486
[177] val's savings: 23.6486
[178] val's savings: 23.6486
[179] val's savings: 23.6486
Early stopping, best iteration is:
[79] val's savings: 23.6486

@StrikerRUS (Collaborator) commented Sep 1, 2019

@pford221 Hi! This behavior is expected.

From your example

early_stopping_rounds = 100

so, 121 - 100 = 21.

Use mod.best_iteration to access the best iteration.
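
For example, a minimal sketch continuing the reproducible example above (assuming mod is the Booster returned by lgb.train):

print(mod.best_iteration)  # the iteration early stopping settled on (21 in the log above)

# the Booster keeps all boosted trees, so cap prediction at the best iteration
val_pred = mod.predict(x_val, num_iteration = mod.best_iteration)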

@pford221 (Contributor, Author) commented Sep 1, 2019

@StrikerRUS thanks for the prompt reply, but what makes it expected? It's not returning the best iteration as determined by the validation set. That's the problem, not the arithmetic of 21 + 100 = 121.

See the edit to the issue in Other Notes for expected (or at least desirable) behavior.

mod.best_iteration would return 21, which is not the best iteration.

@StrikerRUS (Collaborator) commented:

mod.best_iteration would return 21, which is not the best iteration.

Why do you think that 21 is not the best iteration? Please post the entire recorded log to check it.

@pford221 (Contributor, Author) commented Sep 1, 2019

Because the quantity we're trying to maximize on the validation set, the custom feval metric, still increases at iteration 79.

[1] trn's savings: 16.6667 val's savings: 5.40541
Training until validation scores don't improve for 100 rounds.
[2] trn's savings: 30.5556 val's savings: 21.6216
[3] trn's savings: 34.127 val's savings: 14.8649
[4] trn's savings: 36.9048 val's savings: 14.8649
[5] trn's savings: 36.1111 val's savings: 16.8919
[6] trn's savings: 34.5238 val's savings: 18.2432
[7] trn's savings: 35.7143 val's savings: 18.2432
[8] trn's savings: 34.5238 val's savings: 20.2703
[9] trn's savings: 32.1429 val's savings: 18.2432
[10] trn's savings: 36.9048 val's savings: 20.2703
[11] trn's savings: 40.0794 val's savings: 20.2703
[12] trn's savings: 44.8413 val's savings: 22.2973
[13] trn's savings: 44.8413 val's savings: 19.5946
[14] trn's savings: 43.254 val's savings: 19.5946
[15] trn's savings: 44.8413 val's savings: 21.6216
[16] trn's savings: 44.8413 val's savings: 21.6216
[17] trn's savings: 43.6508 val's savings: 21.6216
[18] trn's savings: 42.0635 val's savings: 21.6216
[19] trn's savings: 44.8413 val's savings: 19.5946
[20] trn's savings: 44.8413 val's savings: 21.6216
[21] trn's savings: 46.0317 val's savings: 21.6216
[22] trn's savings: 46.0317 val's savings: 21.6216
[23] trn's savings: 44.4444 val's savings: 21.6216
[24] trn's savings: 42.8571 val's savings: 21.6216
[25] trn's savings: 42.8571 val's savings: 21.6216
[26] trn's savings: 42.8571 val's savings: 21.6216
[27] trn's savings: 42.8571 val's savings: 21.6216
[28] trn's savings: 42.8571 val's savings: 21.6216
[29] trn's savings: 41.2698 val's savings: 21.6216
[30] trn's savings: 42.8571 val's savings: 21.6216
[31] trn's savings: 42.8571 val's savings: 21.6216
[32] trn's savings: 42.8571 val's savings: 21.6216
[33] trn's savings: 42.8571 val's savings: 21.6216
[34] trn's savings: 41.2698 val's savings: 21.6216
[35] trn's savings: 42.8571 val's savings: 21.6216
[36] trn's savings: 42.8571 val's savings: 21.6216
[37] trn's savings: 42.8571 val's savings: 21.6216
[38] trn's savings: 42.8571 val's savings: 21.6216
[39] trn's savings: 42.8571 val's savings: 21.6216
[40] trn's savings: 42.8571 val's savings: 21.6216
[41] trn's savings: 42.8571 val's savings: 21.6216
[42] trn's savings: 42.8571 val's savings: 21.6216
[43] trn's savings: 42.8571 val's savings: 21.6216
[44] trn's savings: 42.8571 val's savings: 21.6216
[45] trn's savings: 42.8571 val's savings: 21.6216
[46] trn's savings: 42.8571 val's savings: 21.6216
[47] trn's savings: 42.8571 val's savings: 21.6216
[48] trn's savings: 42.8571 val's savings: 21.6216
[49] trn's savings: 42.8571 val's savings: 21.6216
[50] trn's savings: 42.8571 val's savings: 21.6216
[51] trn's savings: 42.8571 val's savings: 21.6216
[52] trn's savings: 42.8571 val's savings: 21.6216
[53] trn's savings: 42.8571 val's savings: 21.6216
[54] trn's savings: 42.8571 val's savings: 21.6216
[55] trn's savings: 42.8571 val's savings: 21.6216
[56] trn's savings: 41.6667 val's savings: 21.6216
[57] trn's savings: 41.6667 val's savings: 21.6216
[58] trn's savings: 41.6667 val's savings: 21.6216
[59] trn's savings: 41.6667 val's savings: 21.6216
[60] trn's savings: 41.6667 val's savings: 21.6216
[61] trn's savings: 42.8571 val's savings: 21.6216
[62] trn's savings: 42.8571 val's savings: 21.6216
[63] trn's savings: 42.8571 val's savings: 21.6216
[64] trn's savings: 42.8571 val's savings: 21.6216
[65] trn's savings: 41.6667 val's savings: 21.6216
[66] trn's savings: 41.6667 val's savings: 21.6216
[67] trn's savings: 41.6667 val's savings: 21.6216
[68] trn's savings: 41.6667 val's savings: 21.6216
[69] trn's savings: 42.8571 val's savings: 21.6216
[70] trn's savings: 42.8571 val's savings: 21.6216
[71] trn's savings: 42.8571 val's savings: 21.6216
[72] trn's savings: 42.8571 val's savings: 21.6216
[73] trn's savings: 42.8571 val's savings: 21.6216
[74] trn's savings: 42.8571 val's savings: 21.6216
[75] trn's savings: 42.8571 val's savings: 21.6216
[76] trn's savings: 42.8571 val's savings: 21.6216
[77] trn's savings: 41.6667 val's savings: 21.6216
[78] trn's savings: 42.8571 val's savings: 21.6216
[79] trn's savings: 42.8571 val's savings: 23.6486
[80] trn's savings: 42.8571 val's savings: 23.6486
[81] trn's savings: 42.8571 val's savings: 23.6486
[82] trn's savings: 42.8571 val's savings: 23.6486
[83] trn's savings: 42.8571 val's savings: 23.6486
[84] trn's savings: 42.8571 val's savings: 23.6486
[85] trn's savings: 42.8571 val's savings: 23.6486
[86] trn's savings: 42.8571 val's savings: 23.6486
[87] trn's savings: 42.8571 val's savings: 23.6486
[88] trn's savings: 42.8571 val's savings: 23.6486
[89] trn's savings: 42.8571 val's savings: 23.6486
[90] trn's savings: 42.8571 val's savings: 23.6486
[91] trn's savings: 42.8571 val's savings: 23.6486
[92] trn's savings: 42.8571 val's savings: 23.6486
[93] trn's savings: 42.8571 val's savings: 23.6486
[94] trn's savings: 42.8571 val's savings: 23.6486
[95] trn's savings: 42.8571 val's savings: 23.6486
[96] trn's savings: 42.8571 val's savings: 23.6486
[97] trn's savings: 42.8571 val's savings: 23.6486
[98] trn's savings: 42.8571 val's savings: 23.6486
[99] trn's savings: 42.8571 val's savings: 23.6486
[100] trn's savings: 42.8571 val's savings: 23.6486
[101] trn's savings: 42.8571 val's savings: 23.6486
[102] trn's savings: 42.8571 val's savings: 23.6486
[103] trn's savings: 41.6667 val's savings: 23.6486
[104] trn's savings: 43.254 val's savings: 23.6486
[105] trn's savings: 43.254 val's savings: 23.6486
[106] trn's savings: 43.254 val's savings: 23.6486
[107] trn's savings: 43.254 val's savings: 23.6486
[108] trn's savings: 43.254 val's savings: 23.6486
[109] trn's savings: 44.4444 val's savings: 23.6486
[110] trn's savings: 44.4444 val's savings: 23.6486
[111] trn's savings: 43.254 val's savings: 23.6486
[112] trn's savings: 43.254 val's savings: 23.6486
[113] trn's savings: 43.254 val's savings: 23.6486
[114] trn's savings: 43.254 val's savings: 23.6486
[115] trn's savings: 43.254 val's savings: 23.6486
[116] trn's savings: 43.254 val's savings: 23.6486
[117] trn's savings: 43.254 val's savings: 23.6486
[118] trn's savings: 43.254 val's savings: 23.6486
[119] trn's savings: 43.254 val's savings: 23.6486
[120] trn's savings: 43.254 val's savings: 23.6486
[121] trn's savings: 43.254 val's savings: 23.6486
Early stopping, best iteration is:
[21] trn's savings: 46.0317 val's savings: 21.6216

@pford221 (Contributor, Author) commented Sep 1, 2019

OK, I think I see the issue. The training set trn is being monitored for early stopping at the same time as the validation set val. If either of them fails to improve for 100 iterations, training stops.

Apologies for raising the issue, but does that mean there is no way to passively monitor the training metric without having it count towards early stopping?
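
A minimal sketch of one possible workaround for now (reusing the reproducible example above): let only val drive early stopping, then re-compute the custom metric on the training data afterwards.

mod = lgb.train(params, l_trn, num_boost_round = 1000, valid_sets = [l_val],
                valid_names = ['val'], early_stopping_rounds = 100,
                verbose_eval = 1, feval = feval)

# re-compute the 'savings' metric on the training data at the chosen iteration
trn_prob = mod.predict(x_trn, num_iteration = mod.best_iteration)
cls = (trn_prob > 0.5).astype(np.uint8)
tp = (cls * y_trn).sum()
fp = ((cls == 1) & (y_trn == 0)).sum()
print('trn savings:', (100 * tp - 75 * fp) / len(y_trn))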

@StrikerRUS (Collaborator) commented:

The expected behavior is described here:

Early stopping requires at least one set in valid_sets. If there is more than one, it will use all of them except the training data:
https://lightgbm.readthedocs.io/en/latest/Python-Intro.html#early-stopping

So, it's expected that the training dataset shouldn't be taken into account.

According to your logs, it seems that something really is going wrong and the early stopping criterion is being met on the training dataset.

StrikerRUS reopened this Sep 1, 2019

@guolinke (Collaborator) commented Sep 2, 2019

@StrikerRUS
when the metric and the objective point in the same direction, the metric on the training set will usually keep improving, so it is unlikely to trigger early stopping. But if they do not, it is hard to say.

@StrikerRUS (Collaborator) commented:

@guolinke What do you think about actually excluding the training set from early stopping checks, as our docs say?

@guolinke (Collaborator) commented Sep 3, 2019

@StrikerRUS yeah, I think we can do it.

@StrikerRUS (Collaborator) commented:

@guolinke Nice!
#2209 starts the work towards that:

if (((env.evaluation_result_list[i][0] == "cv_agg"
      and env.evaluation_result_list[i][1].split(" ")[0] == "train")
     or env.evaluation_result_list[i][0] == env.model._train_data_name)):
    continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)

After merging it, I'll help with the rest of the Python code.

@StrikerRUS (Collaborator) commented:

@pford221 It seems that the recently merged #2209 should fix all cases in which training data can cause early stopping:

if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
     or env.evaluation_result_list[i][0] == env.model._train_data_name)):
    _final_iteration_check(env, eval_name_splitted, i)
    continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)

Can you please check the latest master and provide your feedback?
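
A quick way to check (a sketch reusing the reproducible example from the top of this issue; 79 is the best iteration from the val-only run, so that is what we would expect once the training data is ignored):

mod = lgb.train(params, l_trn, num_boost_round = 1000, valid_sets = [l_trn, l_val],
                valid_names = ['trn', 'val'], early_stopping_rounds = 100,
                verbose_eval = 1, feval = feval)
# with the fix, early stopping should be driven by 'val' only, so this
# should match the best iteration of the val-only run above (79)
print(mod.best_iteration)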

@StrikerRUS (Collaborator) commented:

@guolinke What about the cpp code? Is it easy to ignore training data there?
For the R-package, I guess we need to create a separate feature request issue, as we have some delay in the R-package development.

@guolinke (Collaborator) commented:

I think the training metric is already excluded on the cpp side; refer to:

for (auto& sub_metric : training_metrics_) {
  auto name = sub_metric->GetName();
  auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
  for (size_t k = 0; k < name.size(); ++k) {
    std::stringstream tmp_buf;
    tmp_buf << "Iteration:" << iter
            << ", training " << name[k]
            << " : " << scores[k];
    Log::Info(tmp_buf.str().c_str());
    if (early_stopping_round_ > 0) {
      msg_buf << tmp_buf.str() << '\n';
    }
  }
}

@StrikerRUS (Collaborator) commented:

@guolinke Ah, nice!

And what about the case where the user specifies the path to the training data in the valid param? I know that for the CLI version the right way to monitor metrics over the training data is to specify the is_provide_training_metric param, but it can happen by mistake. Do we have a guard against applying early stopping on training data when it is passed to the validation sets?
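
For reference, a minimal sketch of the CLI setup being described (assuming the usual key = value config file format; the file names are placeholders):

# train.conf (placeholder file names)
task = train
objective = binary
data = train.txt
valid = valid.txt
metric = auc
early_stopping_round = 100
# report metrics on the training data without listing it in valid
is_provide_training_metric = true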

@guolinke (Collaborator) commented:

@StrikerRUS A quick fix for that is to check the file path of valid: if it is the same as the training data, we could set is_provide_training_metric to true and set that valid entry to "".
For the same data under a different file path, the only way is to check the content, but I think that is very costly and not worth doing.

@StrikerRUS (Collaborator) commented Sep 22, 2019

@guolinke Yeah, a path check is quite enough, I think! That was more curiosity on my part than anything else. Since the CLI version has no description similar to the Python API's

If there's more than one, will check all of them. But the training data is ignored anyway.

(especially the word anyway), and passing the training data there would be incorrect usage anyway, nothing needs to be done.

@StrikerRUS (Collaborator) commented:

@pford221 Feel free to re-open if the issue is still present in version 2.3.0.
For the R-package, I have created a separate issue #2472.

lock bot locked as resolved and limited conversation to collaborators Mar 10, 2020