diff --git a/tests/conftest.py b/tests/conftest.py index 56fe7075..19db7a80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ import gc import glob import random +from unittest.mock import patch import dask import pandas as pd @@ -23,6 +24,7 @@ from merlin.core.compat import cudf from merlin.core.compat import numpy as np +from merlin.dataloader.loader_base import LoaderBase if cudf: try: @@ -241,3 +243,17 @@ def np_embeddings_from_pq(rev_embeddings_from_dataframe, tmpdir_factory): paths, embeddings_file, lookup_ids_file ) return npy_filename, lookup_filename + + +@pytest.fixture(scope="function", autouse=True) +def cleanup_dataloader(): + """After each test runs, call .stop() on every dataloader that was iterated during the test. + This avoids issues with background threads hanging around and interfering with subsequent tests. + This happens when a dataloader is partially consumed (not all batches are iterated through). + """ + with patch.object( + LoaderBase, "__iter__", side_effect=LoaderBase.__iter__, autospec=True + ) as patched: + yield + for call in patched.call_args_list: + call.args[0].stop()