[tests][dask] fix workers without data test (fixes #5537) (#5544)
jmoralez authored Nov 21, 2022
1 parent 2d4654a · commit 93f2da4
Showing 1 changed file with 36 additions and 30 deletions.
tests/python_package_test/test_dask.py (36 additions, 30 deletions)
@@ -11,6 +11,7 @@
 from urllib.parse import urlparse
 
 import pytest
+from sklearn.metrics import accuracy_score, r2_score
 
 import lightgbm as lgb

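The metrics imported above support the rewritten assertion at the end of this diff: instead of comparing distributed predictions against a separately trained local model, the test now scores the distributed model's in-sample predictions directly. A minimal sketch of that task-dependent scoring, assuming the same metric choices and 0.9 threshold as the test (the helper name quality_score is illustrative, not from the diff):

import numpy as np
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score, r2_score


def quality_score(task: str, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # hypothetical helper: pick a goodness-of-fit metric for the given task
    if task == 'regression':
        return r2_score(y_true, y_pred)  # 1.0 is a perfect fit
    if task.endswith('classification'):
        return accuracy_score(y_true, y_pred)  # fraction of correct labels
    return spearmanr(y_pred, y_true).correlation  # rank correlation, for 'ranking'


# usage mirroring the test: assert quality_score(task, y, dask_preds) > 0.9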
@@ -75,6 +76,13 @@ def cluster2():
     dask_cluster.close()
 
 
+@pytest.fixture(scope='module')
+def cluster_three_workers():
+    dask_cluster = LocalCluster(n_workers=3, threads_per_worker=1, dashboard_address=None)
+    yield dask_cluster
+    dask_cluster.close()
+
+
 @pytest.fixture()
 def listen_port():
     listen_port.port += 10
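For context on the fixture above: scope='module' means the three-worker cluster is created once, shared by every parametrization of the test, and closed only after the module's last test, so the many task/output combinations don't each pay cluster startup costs. A minimal sketch of how a test consumes such a fixture (the fixture and test names here are illustrative; only cluster_three_workers appears in the diff):

import pytest
from distributed import Client, LocalCluster


@pytest.fixture(scope='module')
def small_cluster():
    # one shared cluster per test module, torn down after the last test
    cluster = LocalCluster(n_workers=3, threads_per_worker=1, dashboard_address=None)
    yield cluster
    cluster.close()


def test_sees_three_workers(small_cluster):
    with Client(small_cluster) as client:
        assert len(client.scheduler_info()['workers']) == 3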
@@ -1503,56 +1511,54 @@ def f(part):
 
 @pytest.mark.parametrize('task', tasks)
 @pytest.mark.parametrize('output', data_output)
-def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster):
-    pytest.skip("skipping due to timeout issues discussed in https://github.com/microsoft/LightGBM/pull/5510")
+def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster_three_workers):
     if task == 'ranking' and output == 'scipy_csr_matrix':
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
 
-    with Client(cluster) as client:
-        def collection_to_single_partition(collection):
-            """Merge the parts of a Dask collection into a single partition."""
-            if collection is None:
-                return
-            if isinstance(collection, da.Array):
-                return collection.rechunk(*collection.shape)
-            return collection.repartition(npartitions=1)
-
-        X, y, w, g, dX, dy, dw, dg = _create_data(
+    with Client(cluster_three_workers) as client:
+        _, y, _, _, dX, dy, dw, dg = _create_data(
             objective=task,
             output=output,
-            group=None
+            group=None,
+            n_samples=1_000,
+            chunk_size=200,
         )
 
         dask_model_factory = task_to_dask_factory[task]
-        local_model_factory = task_to_local_factory[task]
 
-        dX = collection_to_single_partition(dX)
-        dy = collection_to_single_partition(dy)
-        dw = collection_to_single_partition(dw)
-        dg = collection_to_single_partition(dg)
+        workers = list(client.scheduler_info()['workers'].keys())
+        assert len(workers) == 3
+        first_two_workers = workers[:2]
 
-        n_workers = len(client.scheduler_info()['workers'])
-        assert n_workers > 1
-        assert dX.npartitions == 1
+        dX = client.persist(dX, workers=first_two_workers)
+        dy = client.persist(dy, workers=first_two_workers)
+        dw = client.persist(dw, workers=first_two_workers)
+        wait([dX, dy, dw])
+
+        workers_with_data = set()
+        for coll in (dX, dy, dw):
+            for with_data in client.who_has(coll).values():
+                workers_with_data.update(with_data)
+                assert workers[2] not in with_data
+        assert len(workers_with_data) == 2
 
         params = {
             'time_out': 5,
             'random_state': 42,
-            'num_leaves': 10
+            'num_leaves': 10,
+            'n_estimators': 20,
         }
 
         dask_model = dask_model_factory(tree='data', client=client, **params)
         dask_model.fit(dX, dy, group=dg, sample_weight=dw)
         dask_preds = dask_model.predict(dX).compute()
 
-        local_model = local_model_factory(**params)
-        if task == 'ranking':
-            local_model.fit(X, y, group=g, sample_weight=w)
+        if task == 'regression':
+            score = r2_score(y, dask_preds)
+        elif task.endswith('classification'):
+            score = accuracy_score(y, dask_preds)
         else:
-            local_model.fit(X, y, sample_weight=w)
-        local_preds = local_model.predict(X)
-
-        assert assert_eq(dask_preds, local_preds)
+            score = spearmanr(dask_preds, y).correlation
+        assert score > 0.9
 
 
 @pytest.mark.parametrize('task', tasks)
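The core of the new test body is pinning every chunk of the training data to two of the three workers and then proving, via the scheduler, that the third worker holds nothing, so training must cope with a worker that has no data. A standalone sketch of that placement-and-verification pattern, assuming a plain dask.array in place of the test's _create_data output:

import dask.array as da
from distributed import Client, LocalCluster, wait

with LocalCluster(n_workers=3, threads_per_worker=1, dashboard_address=None) as cluster:
    with Client(cluster) as client:
        workers = list(client.scheduler_info()['workers'].keys())

        # five 200-row chunks, mirroring n_samples=1_000 / chunk_size=200 above
        X = da.random.random((1_000, 4), chunks=(200, 4))

        # restrict every chunk to the first two workers, then wait for placement
        X = client.persist(X, workers=workers[:2])
        wait(X)

        # who_has maps each stored key to the workers currently holding it
        workers_with_data = set()
        for holders in client.who_has(X).values():
            workers_with_data.update(holders)

        assert workers[2] not in workers_with_data
        assert workers_with_data.issubset(set(workers[:2]))

With the data laid out this way, distributed training runs with only two of the three workers contributing rows, which is exactly the situation the test's name promises to cover.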
