Skip to content

Commit

Permalink
[python] keep consistent state for Dataset fields (#2390)
Browse files Browse the repository at this point in the history
* keep consistent state for Dataset fields

* hotfix
  • Loading branch information
StrikerRUS authored Sep 9, 2019
1 parent f52be9b commit 9f6e441
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,7 @@ def set_label(self, label):
if self.handle is not None:
label = list_to_1d_numpy(_label_from_pandas(label), name='label')
self.set_field('label', label)
self.label = self.get_field('label') # original values can be modified at cpp side
return self

def set_weight(self, weight):
Expand All @@ -1360,6 +1361,7 @@ def set_weight(self, weight):
if self.handle is not None and weight is not None:
weight = list_to_1d_numpy(weight, name='weight')
self.set_field('weight', weight)
self.weight = self.get_field('weight') # original values can be modified at cpp side
return self

def set_init_score(self, init_score):
Expand All @@ -1379,6 +1381,7 @@ def set_init_score(self, init_score):
if self.handle is not None and init_score is not None:
init_score = list_to_1d_numpy(init_score, np.float64, name='init_score')
self.set_field('init_score', init_score)
self.init_score = self.get_field('init_score') # original values can be modified at cpp side
return self

def set_group(self, group):
Expand Down
31 changes: 31 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,34 @@ def test_cegb_scaling_equalities(self):
with open(p2name, 'rt') as f:
p2txt = f.read()
self.assertEqual(p1txt, p2txt)

def test_consistent_state_for_dataset_fields(self):

def check_asserts(data):
np.testing.assert_allclose(data.label, data.get_label())
np.testing.assert_allclose(data.label, data.get_field('label'))
self.assertFalse(np.isnan(data.label[0]))
self.assertFalse(np.isinf(data.label[1]))
np.testing.assert_allclose(data.weight, data.get_weight())
np.testing.assert_allclose(data.weight, data.get_field('weight'))
self.assertFalse(np.isnan(data.weight[0]))
self.assertFalse(np.isinf(data.weight[1]))
np.testing.assert_allclose(data.init_score, data.get_init_score())
np.testing.assert_allclose(data.init_score, data.get_field('init_score'))
self.assertFalse(np.isnan(data.init_score[0]))
self.assertFalse(np.isinf(data.init_score[1]))
self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
data.label[0])))
self.assertAlmostEqual(data.label[1], data.weight[1])

X, y = load_breast_cancer(True)
sequence = np.ones(y.shape[0])
sequence[0] = np.nan
sequence[1] = np.inf
lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct()
check_asserts(lgb_data)
lgb_data = lgb.Dataset(X, y).construct()
lgb_data.set_label(sequence)
lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence)
check_asserts(lgb_data)

0 comments on commit 9f6e441

Please sign in to comment.