diff --git a/distributed/diagnostics/tests/test_cudf_diagnostics.py b/distributed/diagnostics/tests/test_cudf_diagnostics.py
index feb5681855..fd252590ab 100644
--- a/distributed/diagnostics/tests/test_cudf_diagnostics.py
+++ b/distributed/diagnostics/tests/test_cudf_diagnostics.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import os
 
 import pytest
@@ -24,22 +25,28 @@ def force_spill():
 
     manager = get_global_manager()
 
-    # 24 bytes
+    # Allocate a new dataframe and trigger spilling by setting a 1 byte limit
     df = cudf.DataFrame({"a": [1, 2, 3]})
+    manager.spill_to_device_limit(1)
 
-    return manager.spill_to_device_limit(1)
+    # Return the bytes spilled from GPU to CPU, read from the manager above
+    spill_totals, _ = manager.statistics.spill_totals[("gpu", "cpu")]
+    return spill_totals
 
 
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1)],
 )
-@pytest.mark.flaky(reruns=10, reruns_delay=5)
 async def test_cudf_metrics(c, s, *workers):
     w = list(s.workers.values())[0]
     assert "cudf" in w.metrics
     assert w.metrics["cudf"]["cudf-spilled"] == 0
 
-    await c.run(force_spill)
-
-    assert w.metrics["cudf"]["cudf-spilled"] == 24
+    spill_totals = (await c.run(force_spill, workers=[w.address]))[w.address]
+    assert spill_totals > 0
+    # The worker's metrics update asynchronously, so poll until they match
+    # the spilled total instead of sleeping for a fixed interval. A value
+    # that never arrives fails the test through the test timeout.
+    while w.metrics["cudf"]["cudf-spilled"] != spill_totals:
+        await asyncio.sleep(0.01)