Skip to content

Commit

Permalink
update multiclass tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Jan 13, 2022
1 parent 8abb932 commit 54df09e
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,15 +493,16 @@ def test_classifier_custom_objective(output, task, cluster):
with Client(cluster) as client:
X, y, w, _, dX, dy, dw, _ = _create_data(
objective=task,
output=output
output=output,
)

params = {
"n_estimators": 50,
"num_leaves": 31,
"min_data": 1,
"verbose": -1,
"learning_rate": 0.01,
"seed": 708,
"deterministic": True,
"force_col_wise": True
}

if task == 'binary-classification':
Expand All @@ -522,25 +523,26 @@ def test_classifier_custom_objective(output, task, cluster):
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
dask_classifier_local = dask_classifier.to_local()
p1_proba = dask_classifier.predict_proba(dX).compute()
p1_proba_local = dask_classifier_local.predict_proba(X)
p1_raw = dask_classifier.predict(dX, raw_score=True).compute()
p1_raw_local = dask_classifier_local.predict(X, raw_score=True)

# with a custom objective, prediction result is a raw score instead of predicted class
p1_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5
p1_class = p1_class.astype(np.int64)
p1_class_local = (1.0 / (1.0 + np.exp(-p1_proba_local))) > 0.5
p1_class_local = p1_class_local.astype(np.int64)
p1_proba = 1.0 / (1.0 + np.exp(-p1_raw))
p1_proba_local = 1.0 / (1.0 + np.exp(-p1_raw_local))

local_classifier = lgb.LGBMClassifier(**params)
local_classifier.fit(X, y, sample_weight=w)
p2_proba = local_classifier.predict_proba(X)
p2_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5
p2_class = p2_class.astype(np.int64)
p2_raw = local_classifier.predict(X, raw_score=True)
p2_proba = 1.0 / (1.0 + np.exp(-p2_raw))

if task == 'multiclass-classification':
p1_class = p1_class.argmax(axis=1)
p1_class_local = p1_class_local.argmax(axis=1)
p2_class = p2_class.argmax(axis=1)
if task == 'binary-classification':
p1_class = (p1_proba > 0.5).astype(np.int64)
p1_class_local = (p1_proba_local > 0.5).astype(np.int64)
p2_class = (p2_proba > 0.5).astype(np.int64)
elif task == 'multiclass-classification':
p1_class = p1_proba.argmax(axis=1)
p1_class_local = p1_proba_local.argmax(axis=1)
p2_class = p2_proba.argmax(axis=1)

# function should have been preserved
assert callable(dask_classifier.objective_)
Expand All @@ -552,7 +554,7 @@ def test_classifier_custom_objective(output, task, cluster):
assert_eq(p2_class, y)

# probability estimates should be similar
assert_eq(p1_proba, p2_proba, atol=0.03)
assert_eq(p1_proba, p2_proba, atol=0.04)


def test_group_workers_by_host():
Expand Down

0 comments on commit 54df09e

Please sign in to comment.