diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index ae4e3332c..d9cd6dfb2 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -17,6 +17,7 @@ import dask_cuda from dask_cuda.explicit_comms import comms from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle +from dask_cuda.local_cuda_cluster import IncreasedCloseTimeoutNanny mp = mp.get_context("spawn") # type: ignore ucp = pytest.importorskip("ucp") @@ -35,6 +36,7 @@ def _test_local_cluster(protocol): dashboard_address=None, n_workers=4, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster) as client: @@ -56,6 +58,7 @@ def _test_dataframe_merge_empty_partitions(nrows, npartitions): dashboard_address=None, n_workers=npartitions, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -102,6 +105,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): dashboard_address=None, n_workers=n_workers, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster) as client: @@ -204,6 +208,7 @@ def check_shuffle(): dashboard_address=None, n_workers=2, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -221,6 +226,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers): dashboard_address=None, n_workers=n_workers, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -327,6 +333,7 @@ def test_lock_workers(): dashboard_address=None, n_workers=4, threads_per_worker=5, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: ps = []