diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml
index 0125fa50bc14..34eb4b0b2093 100644
--- a/.github/workflows/cpu-torch-latest.yml
+++ b/.github/workflows/cpu-torch-latest.yml
@@ -48,7 +48,8 @@ jobs:
       - name: Unit tests
         run: |
+          TEST_LOG_FILE="/tmp/test_log_cpu_${GITHUB_RUN_ID}.log"
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.5"
+          RUNNING_TEST_LOG_FILE=${TEST_LOG_FILE} DS_UNITTEST_FILE_STORE_DIR=/dev/shm HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.5"
           HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.5"
diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml
index 0b8f504d8b5a..8b2e3eb6a528 100644
--- a/.github/workflows/nv-torch-latest-v100.yml
+++ b/.github/workflows/nv-torch-latest-v100.yml
@@ -19,7 +19,7 @@ concurrency:

 jobs:
   unit-tests:
-    runs-on: [self-hosted, nvidia, cu121, v100]
+    runs-on: [self-hosted, nvidia, cu121, v100] # Modified to run on the test runner

     steps:
       - uses: actions/checkout@v4
@@ -44,7 +44,7 @@ jobs:
       - name: Install deepspeed
         run: |
-          pip install .[dev,1bit,autotuning]
+          pip install .[dev,1bit,1bit-mpi,autotuning]
           ds_report

       - name: Python environment
@@ -55,5 +55,26 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.5" --cuda_ver="12.1"
-          pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.5" --cuda_ver="12.1"
+          TEST_LOG_FILE="/tmp/test_log_${GITHUB_RUN_ID}.log"
+          echo "Running tests and logging to ${TEST_LOG_FILE}"
+          # Disable exit-on-error so we can capture pytest's exit code and grep the log for "Failed"
+          set +e
+          pytest -s unit/comm/test_dist.py::TestDistInferenceAllReduce
+          NCCL_SOCKET_IFNAME="" DS_UNITTEST_FILE_STORE_DIR=/dev/shm RUNNING_TEST_LOG_FILE=${TEST_LOG_FILE} pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.5" --cuda_ver="12.1"
+          PYTEST_EXIT_CODE=$?
+          if [ $PYTEST_EXIT_CODE -ne 0 ]; then
+            # Keep the log file around for debugging
+            echo "pytest failed with exit code $PYTEST_EXIT_CODE"
+            exit $PYTEST_EXIT_CODE
+          fi
+          grep "Failed" ${TEST_LOG_FILE}
+          rm -f ${TEST_LOG_FILE}
+          # Run the sequential tests with the same logging and exit-code handling
+          DS_UNITTEST_FILE_STORE_DIR=/dev/shm RUNNING_TEST_LOG_FILE=${TEST_LOG_FILE} pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.5" --cuda_ver="12.1"
+          PYTEST_EXIT_CODE=$?
+          grep "Failed" ${TEST_LOG_FILE}
+          if [ $PYTEST_EXIT_CODE -ne 0 ]; then
+            echo "pytest failed with exit code $PYTEST_EXIT_CODE"
+            exit $PYTEST_EXIT_CODE
+          fi
+          rm -f ${TEST_LOG_FILE}
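The workflows talk to the test harness purely through environment variables: `RUNNING_TEST_LOG_FILE` names a shared append-only progress log and `DS_UNITTEST_FILE_STORE_DIR` moves the torch.distributed file store onto tmpfs; the `grep "Failed"` steps then surface failure lines that the harness writes. A minimal sketch of the consuming side, mirroring what `tests/unit/common.py` does further down (the helper name is illustrative, not part of the diff):

```python
import os

# Both variables are optional hand-offs from CI: when unset, the harness
# skips progress logging and falls back to pytest's tmpdir file store.
RUNNING_TEST_LOG_FILE = os.environ.get("RUNNING_TEST_LOG_FILE", None)
DS_UNITTEST_FILE_STORE_DIR = os.environ.get("DS_UNITTEST_FILE_STORE_DIR", None)

def failed_lines(log_path: str) -> list[str]:
    """Python equivalent of the workflow's `grep "Failed" ${TEST_LOG_FILE}`."""
    with open(log_path) as f:
        return [line.rstrip("\n") for line in f if "Failed" in line]
```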
diff --git a/tests/conftest.py b/tests/conftest.py
index 45e8434a021b..6a35cfe177cd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -70,13 +70,47 @@ def pytest_runtest_call(item):
         item.runtest = lambda: True  # Dummy function so test is not run twice


+def write_to_log_with_lock(log_file_path: str, header: str, msg: str):
+    import fcntl
+    with open(log_file_path, 'a+') as f:
+        try:
+            fcntl.flock(f, fcntl.LOCK_EX)
+            f.write(f"{header} {msg}\n")
+            f.flush()
+        finally:
+            fcntl.flock(f, fcntl.LOCK_UN)
+
+
+dist_test_class = None
+
+
 # We allow DistributedTest to reuse distributed environments. When the last
 # test for a class is run, we want to make sure those distributed environments
 # are destroyed.
 def pytest_runtest_teardown(item, nextitem):
-    if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
+    RUNNING_TEST_LOG_FILE = os.environ.get("RUNNING_TEST_LOG_FILE", "/tmp/running_test.log")
+
+    global dist_test_class
+    # The last test item might not have .cls, so we keep a class instance
+    # (and thus its _pool_cache) around for the final teardown.
+    if item.cls is not None:
         dist_test_class = item.cls()
+
+    def get_xdist_worker_id():
+        xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
+        if xdist_worker is not None:
+            xdist_worker_id = xdist_worker.replace('gw', '')
+            return int(xdist_worker_id)
+        return None
+
+    if RUNNING_TEST_LOG_FILE:
+        reuse_dist_env = getattr(item.cls, "reuse_dist_env", False)
+        write_to_log_with_lock(RUNNING_TEST_LOG_FILE, f"pytest_runtest_teardown,xdist={get_xdist_worker_id()}",
+                               f"reuse_dist_env={reuse_dist_env} nextitem={nextitem}")
+
+    if not nextitem and dist_test_class is not None and dist_test_class._pool_cache is not None:
         for num_procs, pool in dist_test_class._pool_cache.items():
+            write_to_log_with_lock(RUNNING_TEST_LOG_FILE, f"pytest_runtest_teardown,xdist={get_xdist_worker_id()}",
+                                   f"closing pool num_procs={num_procs} nextitem={nextitem}")
             dist_test_class._close_pool(pool, num_procs, force=True)
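Both `conftest.py` and `common.py` serialize log writes with `fcntl.flock` because several pytest-xdist workers (and their child processes) append to the same file. A self-contained sketch of why the exclusive lock matters; without it, concurrent appends can interleave partial lines:

```python
import fcntl
from multiprocessing import Pool

LOG = "/tmp/flock_demo.log"  # throwaway path for this demo

def append_locked(msg: str):
    # Same pattern as write_to_log_with_lock: hold an exclusive (LOCK_EX)
    # advisory lock for the duration of the write, then release it.
    with open(LOG, "a") as f:
        try:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.write(msg + "\n")
            f.flush()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)

if __name__ == "__main__":
    with Pool(8) as pool:
        pool.map(append_locked, [f"worker message {i}" for i in range(64)])
    with open(LOG) as f:
        assert len(f.readlines()) == 64  # every line arrives whole
```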
diff --git a/tests/unit/comm/test_dist.py b/tests/unit/comm/test_dist.py
index 861ba5c7be1a..1cd6cc11212f 100644
--- a/tests/unit/comm/test_dist.py
+++ b/tests/unit/comm/test_dist.py
@@ -112,12 +112,7 @@ def test(self, distributed_fixture, class_tmpdir, val1, val2):
 class TestDistAllReduce(DistributedTest):
     device_count = get_accelerator().device_count()
-    if device_count >= 4:
-        world_size = [1, 2, 4]
-    elif device_count >= 2:
-        world_size = [1, 2]
-    else:
-        world_size = [1]
+    world_size = 2

     def test(self):
         x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
@@ -130,12 +125,7 @@ def test(self):
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 class TestDistInferenceAllReduce(DistributedTest):
     device_count = get_accelerator().device_count()
-    if device_count >= 4:
-        world_size = [1, 2, 4]
-    elif device_count >= 2:
-        world_size = [1, 2]
-    else:
-        world_size = [1]
+    world_size = 2

     def test(self, dtype):
         x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
@@ -143,7 +133,9 @@ def test(self, dtype):
         result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
         result = result.to(dtype)
         x = x.to(dtype)
+        print(f"Rank {dist.get_rank()} x: {x}")
         dist.inference_all_reduce(x)
+        print(f"AR Rank {dist.get_rank()} x: {x}")
         assert torch.all(x == result)
diff --git a/tests/unit/common.py b/tests/unit/common.py
index 1498b0400ee1..c447b22f57f5 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -11,6 +11,9 @@
 import subprocess
 from abc import ABC, abstractmethod
 from pathlib import Path
+import fcntl
+import traceback
+from enum import Enum

 import torch
 import torch.multiprocessing as mp
@@ -24,6 +27,16 @@

 # Worker timeout for tests that hang
 DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))
+RUNNING_TEST_LOG_FILE = os.environ.get("RUNNING_TEST_LOG_FILE", None)
+DS_UNITTEST_FILE_STORE_DIR = os.environ.get("DS_UNITTEST_FILE_STORE_DIR", None)
+
+
+class TestResultType(Enum):
+    SUCCESS = 0
+    UNSET = 1
+    ERROR = 2
+    SKIP = 3
+    TIMEOUT = 4


 def is_rocm_pytorch():
@@ -126,6 +139,80 @@ def set_accelerator_visible():
         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


+def write_to_log_with_lock(log_file_path: str, header: str, msg: str):
+    with open(log_file_path, 'a+') as f:
+        try:
+            fcntl.flock(f, fcntl.LOCK_EX)
+            f.write(f"{header} {msg}\n")
+            f.flush()
+        finally:
+            fcntl.flock(f, fcntl.LOCK_UN)
+
+
+def make_test_tag(request):
+    if request is None:
+        return f"[xdist_worker={get_xdist_worker_id()}][NO_REQUEST]"
+
+    class_name = request.cls.__name__ if request.cls else "NO_CLASS"
+    test_name = request.node.name
+    return f"[xdist_worker={get_xdist_worker_id()}][{class_name}][{test_name}]"
+
+
+class LogTestRun(ABC):
+
+    def __init__(self, log_file, tag, num_procs):
+        self.log_file = log_file
+        self.num_procs = num_procs
+        self.header = tag
+
+    def write(self, msg):
+        write_to_log_with_lock(self.log_file, self.header, msg)
+
+    def __enter__(self):
+        if self.log_file is None:
+            return
+        self._enter()
+        self.start_time = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.log_file is None:
+            return
+
+        self.elapsed_time = time.time() - self.start_time
+        self._exit(exc_type, exc_val, exc_tb)
+
+    @abstractmethod
+    def _enter(self):
+        ...
+
+    @abstractmethod
+    def _exit(self, exc_type, exc_val, exc_tb):
+        ...
+
+
+class LogTestRunBaseProcess(LogTestRun):
+
+    def __init__(self, log_file, tag, num_procs):
+        super().__init__(log_file, tag, num_procs)
+
+    def _enter(self):
+        self.write(f"Running with {self.num_procs} processes")
+
+    def _exit(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            tb_str = ''.join(traceback.format_tb(exc_tb))
+            if exc_type == Skipped:
+                self.write(
+                    f"Skipping with {self.num_procs} processes. elapsed_time={self.elapsed_time:.2f}s exc_type={exc_type} exc_val={exc_val}"
+                )
+            else:
+                self.write(
+                    f"Failed with {self.num_procs} processes. elapsed_time={self.elapsed_time:.2f}s exc_type={exc_type} exc_val={exc_val} {tb_str}"
+                )
+            return False
+        self.write(f"Finished with {self.num_procs} processes. elapsed_time={self.elapsed_time:.2f}s")
+
+
 class DistributedExec(ABC):
     """
     Base class for distributed execution of functions/methods. Contains common
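`LogTestRun` is a small context-manager template: `__enter__` stamps a start time, `__exit__` computes the elapsed time and delegates to `_exit`, and returning `False` from `_exit` lets any exception keep propagating after it has been logged. A hypothetical stand-alone usage, matching how `_dist_run` wraps its phases further down (the test body is a placeholder):

```python
tag = "[xdist_worker=0][TestExample][test_something]"
with LogTestRunBaseProcess("/tmp/running_test.log", f"{tag} [exec]", num_procs=2):
    run_the_actual_test()  # placeholder for self.run(**self._fixture_kwargs)
# On entry the log gains "Running with 2 processes"; on success,
# "Finished with 2 processes. elapsed_time=..."; on failure, a "Failed ..."
# line plus the traceback -- and the exception still propagates.
```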
@@ -136,7 +223,7 @@ class DistributedExec(ABC):
     init_distributed = True
     set_dist_env = True
     requires_cuda_env = True
-    reuse_dist_env = False
+    reuse_dist_env = True
     non_daemonic_procs = False
     _pool_cache = {}
     exec_timeout = DEEPSPEED_TEST_TIMEOUT
@@ -151,7 +238,8 @@ def __call__(self, request):
         if self.requires_cuda_env and not get_accelerator().is_available():
             pytest.skip("only supported in accelerator environments.")

-        self._launch_with_file_store(request, world_size)
+        tag = make_test_tag(request)
+        self._launch_with_file_store(request, world_size, tag)

     def _get_fixture_kwargs(self, request, func):
         if not request:
@@ -167,7 +255,7 @@ def _get_fixture_kwargs(self, request, func):
             pass  # test methods can have kwargs that are not fixtures
         return fixture_kwargs

-    def _launch_daemonic_procs(self, num_procs, init_method):
+    def _launch_daemonic_procs(self, num_procs, init_method, tag):
         # Create process pool or use cached one
         master_port = None
@@ -186,27 +274,68 @@ def _launch_daemonic_procs(self, num_procs, init_method):
             master_port = get_master_port()

         # Run the test
-        args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)]
-        skip_msgs_async = pool.starmap_async(self._dist_run, args)
+        args = [(local_rank, num_procs, master_port, init_method, tag) for local_rank in range(num_procs)]
+
+        if RUNNING_TEST_LOG_FILE:
+            write_to_log_with_lock(RUNNING_TEST_LOG_FILE, tag,
+                                   f"Starting child processes: reuse_dist_env={self.reuse_dist_env}")
+        RETRY_COUNT = 10
+        fork_process_result = TestResultType.UNSET
         try:
-            skip_msgs = skip_msgs_async.get(self.exec_timeout)
-        except mp.TimeoutError:
-            # Shortcut to exit pytest in the case of a hanged test. This
-            # usually means an environment error and the rest of tests will
-            # hang (causing super long unit test runtimes)
-            pytest.exit("Test hanged, exiting", returncode=1)
+            for _ in range(RETRY_COUNT):
+                try:
+                    skip_msgs_async = pool.starmap_async(self._dist_run, args)
+                    test_results = skip_msgs_async.get(self.exec_timeout)
+
+                    if any("NCCL error" in msg for result_type, msg in test_results
+                           if result_type == TestResultType.ERROR):
+                        write_to_log_with_lock(RUNNING_TEST_LOG_FILE, tag,
+                                               "NCCL error in _launch_daemonic_procs, retrying")
+                        # will be caught by the except block below
+                        raise RuntimeError("NCCL error")
+
+                    fork_process_result = TestResultType.SUCCESS
+                    break
+                except mp.TimeoutError as e:
+                    write_to_log_with_lock(RUNNING_TEST_LOG_FILE, tag,
+                                           f"Timeout in _launch_daemonic_procs: {e}, retrying")
+                    fork_process_result = TestResultType.TIMEOUT
+                    # pytest.exit("Test hanged, exiting", returncode=1)
+                except Exception as e:
+                    write_to_log_with_lock(RUNNING_TEST_LOG_FILE, tag,
+                                           f"Exception in _launch_daemonic_procs: {e}, retrying")
+                    fork_process_result = TestResultType.ERROR
+                self._close_pool(pool, num_procs)
+                write_to_log_with_lock(RUNNING_TEST_LOG_FILE, tag, "Pool closed")
+                # Must be sufficiently shorter than DEEPSPEED_TEST_TIMEOUT
+                time.sleep(10 + 10 * torch.rand(1).item())
+                pool = mp.Pool(processes=num_procs)
+
+                if self.reuse_dist_env:
+                    self._pool_cache[num_procs] = pool
         finally:
             # Regardless of the outcome, ensure proper teardown
             # Tear down distributed environment and close process pools
             self._close_pool(pool, num_procs)

+        if RUNNING_TEST_LOG_FILE:
+            write_to_log_with_lock(RUNNING_TEST_LOG_FILE, tag, f"Child processes finished: {fork_process_result}")
+
+        if fork_process_result == TestResultType.TIMEOUT or fork_process_result == TestResultType.ERROR:
+            pytest.fail(f"Test failed with error: {fork_process_result}", pytrace=False)
+
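The retry loop above hinges on `AsyncResult.get(timeout)` raising `multiprocessing.TimeoutError` when a child hangs; the pool is then discarded and rebuilt, since workers stuck in a collective cannot be reused. A stripped-down sketch of that pattern (illustrative names, not the DeepSpeed API):

```python
import multiprocessing as mp
import random
import time

def flaky_task(rank: int) -> int:
    return rank * 2

def run_with_retries(num_procs: int = 2, retries: int = 3, timeout: float = 30.0):
    pool = mp.Pool(processes=num_procs)
    try:
        for attempt in range(retries):
            try:
                async_result = pool.starmap_async(flaky_task, [(r,) for r in range(num_procs)])
                return async_result.get(timeout)  # raises mp.TimeoutError on a hang
            except mp.TimeoutError:
                # A hung worker poisons the whole pool: terminate and rebuild.
                pool.terminate()
                pool.join()
                time.sleep(1 + random.random())  # jittered backoff, as in the diff above
                pool = mp.Pool(processes=num_procs)
        raise RuntimeError(f"gave up after {retries} attempts")
    finally:
        pool.terminate()
        pool.join()
```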
{fork_process_result}", pytrace=False) + # If we skipped a test, propagate that to this process + + skip_msgs = [msg for result_type, msg in test_results if result_type == TestResultType.SKIP] if any(skip_msgs): assert len(set(skip_msgs)) == 1, "Multiple different skip messages received" pytest.skip(skip_msgs[0]) - def _launch_non_daemonic_procs(self, num_procs, init_method): + err_msgs = [msg for result_type, msg in test_results if result_type == TestResultType.ERROR] + if any(err_msgs): + pytest.fail(f"Test failed with error: {err_msgs[0]}", pytrace=False) + + def _launch_non_daemonic_procs(self, num_procs, init_method, tag): assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes" master_port = get_master_port() @@ -215,7 +344,8 @@ def _launch_non_daemonic_procs(self, num_procs, init_method): prev_start_method = mp.get_start_method() mp.set_start_method('spawn', force=True) for local_rank in range(num_procs): - p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg)) + p = mp.Process(target=self._dist_run, + args=(local_rank, num_procs, master_port, init_method, tag, skip_msg)) p.start() processes.append(p) mp.set_start_method(prev_start_method, force=True) @@ -257,7 +387,7 @@ def _launch_non_daemonic_procs(self, num_procs, init_method): # add a check here to assert all exit messages are equal pytest.skip(skip_msg.get()) - def _launch_procs(self, num_procs, init_method): + def _launch_procs(self, num_procs, init_method, tag): # Verify we have enough accelerator devices to run this test if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: pytest.skip( @@ -266,66 +396,168 @@ def _launch_procs(self, num_procs, init_method): if get_accelerator().device_name() == 'xpu': self.non_daemonic_procs = True + + if self.non_daemonic_procs: self.reuse_dist_env = False + if RUNNING_TEST_LOG_FILE: + write_to_log_with_lock( + RUNNING_TEST_LOG_FILE, tag, + f"_launch_procs non_daemonic_procs={self.non_daemonic_procs} reuse_dist_env={self.reuse_dist_env}") + # Set start method to `forkserver` (or `fork`) mp.set_start_method('forkserver', force=True) if self.non_daemonic_procs: - self._launch_non_daemonic_procs(num_procs, init_method) + self._launch_non_daemonic_procs(num_procs, init_method, tag) else: - self._launch_daemonic_procs(num_procs, init_method) + self._launch_daemonic_procs(num_procs, init_method, tag) + + def init_process_group_exclusively(self, local_rank, num_procs, init_method): + xdist_worker_id = get_xdist_worker_id() + xdist_worker_id = xdist_worker_id if xdist_worker_id is not None else -1 + RETRY_INTERVAL = 1 + LOCK_FILE_NAME = "worker_dist_init.lock" + + def acquire_lock_with_pgid(worker_id): + if local_rank == 0: + import errno + try: + fd = os.open(LOCK_FILE_NAME, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + os.write(fd, str(worker_id).encode()) + os.close(fd) + # print(f"Lock acquired by process group {worker_id}.") + return True + except OSError as e: + if e.errno == errno.EEXIST: + return False + else: + raise e + else: + try: + with open(LOCK_FILE_NAME, "r") as f: + existing_wid = int(f.read().strip()) + return existing_wid == xdist_worker_id + except FileNotFoundError: + return False + + def release_lock(): + try: + os.remove(LOCK_FILE_NAME) + except FileNotFoundError: + print("Lock file already deleted.") - def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""): - if dist.is_initialized(): - if get_accelerator().is_available(): - # 
@@ -266,66 +396,168 @@ def _launch_procs(self, num_procs, init_method):

         if get_accelerator().device_name() == 'xpu':
             self.non_daemonic_procs = True
+
+        if self.non_daemonic_procs:
             self.reuse_dist_env = False

+        if RUNNING_TEST_LOG_FILE:
+            write_to_log_with_lock(
+                RUNNING_TEST_LOG_FILE, tag,
+                f"_launch_procs non_daemonic_procs={self.non_daemonic_procs} reuse_dist_env={self.reuse_dist_env}")
+
         # Set start method to `forkserver` (or `fork`)
         mp.set_start_method('forkserver', force=True)

         if self.non_daemonic_procs:
-            self._launch_non_daemonic_procs(num_procs, init_method)
+            self._launch_non_daemonic_procs(num_procs, init_method, tag)
         else:
-            self._launch_daemonic_procs(num_procs, init_method)
+            self._launch_daemonic_procs(num_procs, init_method, tag)
+
+    def init_process_group_exclusively(self, local_rank, num_procs, init_method):
+        xdist_worker_id = get_xdist_worker_id()
+        xdist_worker_id = xdist_worker_id if xdist_worker_id is not None else -1
+        RETRY_INTERVAL = 1
+        LOCK_FILE_NAME = "worker_dist_init.lock"
+
+        def acquire_lock_with_pgid(worker_id):
+            if local_rank == 0:
+                import errno
+                try:
+                    fd = os.open(LOCK_FILE_NAME, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+                    os.write(fd, str(worker_id).encode())
+                    os.close(fd)
+                    # print(f"Lock acquired by process group {worker_id}.")
+                    return True
+                except OSError as e:
+                    if e.errno == errno.EEXIST:
+                        return False
+                    else:
+                        raise e
+            else:
+                try:
+                    with open(LOCK_FILE_NAME, "r") as f:
+                        existing_wid = int(f.read().strip())
+                        return existing_wid == xdist_worker_id
+                except FileNotFoundError:
+                    return False
+
+        def release_lock():
+            try:
+                os.remove(LOCK_FILE_NAME)
+            except FileNotFoundError:
+                print("Lock file already deleted.")

-    def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""):
-        if dist.is_initialized():
-            if get_accelerator().is_available():
-                # local_rank might not match the rank in the previous run if you are reusing the environment
-                get_accelerator().set_device(dist.get_rank())
-        else:
-            """ Initialize deepspeed.comm and execute the user function. """
-            if self.set_dist_env:
-                os.environ['MASTER_ADDR'] = '127.0.0.1'
-                os.environ['MASTER_PORT'] = str(master_port)
-                os.environ['LOCAL_RANK'] = str(local_rank)
-                # NOTE: unit tests don't support multi-node so local_rank == global rank
-                os.environ['RANK'] = str(local_rank)
-                # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
-                # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly
-                os.environ['LOCAL_SIZE'] = str(num_procs)
-                os.environ['WORLD_SIZE'] = str(num_procs)
-
-            # turn off NCCL logging if set
-            os.environ.pop('NCCL_DEBUG', None)
-
-            if get_accelerator().is_available():
-                set_accelerator_visible()
-
-            if get_accelerator().is_available():
-                get_accelerator().set_device(local_rank)
-
-            if self.init_distributed:
-                deepspeed.init_distributed(dist_backend=self.backend,
-                                           init_method=init_method,
-                                           rank=local_rank,
-                                           world_size=num_procs)
-                dist.barrier()
+        while not acquire_lock_with_pgid(xdist_worker_id):
+            time.sleep(RETRY_INTERVAL)

         try:
-            self.run(**self._fixture_kwargs)
+            print("Processing with lock...")
+            from datetime import timedelta
+            timeout = timedelta(seconds=60)
+            deepspeed.init_distributed(dist_backend=self.backend,
+                                       init_method=init_method,
+                                       rank=local_rank,
+                                       world_size=num_procs,
+                                       timeout=timeout)
+            dist.broadcast(torch.tensor([0], device=get_accelerator().current_device()), 0)
+            dist.barrier()
+            print("Processing completed.")
+
+        finally:
+            if local_rank == 0:
+                release_lock()
+
+    def _dist_run(self, local_rank, num_procs, master_port, init_method, tag, skip_msg=""):
+        tag = f"{tag} [pid={os.getpid()},master_port={master_port},local_rank={local_rank},num_procs={num_procs}]"
+        prev_current_device = get_accelerator().current_device()
+        current_device = -1
+        with LogTestRunBaseProcess(
+                RUNNING_TEST_LOG_FILE,
+                f"{tag} [setup _dist_run][dist_initialized={dist.is_initialized()},set_dist_env={self.set_dist_env},init_distributed={self.init_distributed},backend={self.backend},init_method={init_method}]",
+                num_procs):
+            if dist.is_initialized():
+                if get_accelerator().is_available():
+                    # local_rank might not match the rank in the previous run if you are reusing the environment
+                    get_accelerator().set_device(dist.get_rank())
+            else:
+                """ Initialize deepspeed.comm and execute the user function. """
+                if self.set_dist_env:
+                    os.environ['MASTER_ADDR'] = '127.0.0.1'
+                    os.environ['MASTER_PORT'] = str(master_port)
+                    os.environ['LOCAL_RANK'] = str(local_rank)
+                    # NOTE: unit tests don't support multi-node so local_rank == global rank
+                    os.environ['RANK'] = str(local_rank)
+                    # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
+                    # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly
+                    os.environ['LOCAL_SIZE'] = str(num_procs)
+                    os.environ['WORLD_SIZE'] = str(num_procs)
+
+                # turn off NCCL logging if set
+                os.environ.pop('NCCL_DEBUG', None)
+
+                if get_accelerator().is_available():
+                    set_accelerator_visible()
+
+                if get_accelerator().is_available():
+                    get_accelerator().set_device(local_rank)
+
+            print(f"self.init_distributed={self.init_distributed}, dist.is_initialized()={dist.is_initialized()}")
+            if self.init_distributed and not dist.is_initialized():
+                try:
+                    from datetime import timedelta
+
+                    deepspeed.init_distributed(dist_backend=self.backend,
+                                               init_method=init_method,
+                                               rank=local_rank,
+                                               world_size=num_procs,
+                                               timeout=timedelta(seconds=60))
+                    dist.broadcast(torch.tensor([0], device=get_accelerator().current_device()), 0)
+                    dist.barrier()
+                    # self.init_process_group_exclusively(local_rank, num_procs, init_method)
+                except BaseException as e:
+                    msg = e.msg if "msg" in dir(e) else str(e)
+                    return TestResultType.ERROR, msg
+
+        current_device = get_accelerator().current_device()
+
+        visible_devs = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+
+        test_result = TestResultType.UNSET
+        try:
+            with LogTestRunBaseProcess(
+                    RUNNING_TEST_LOG_FILE,
+                    f"{tag} [exec _dist_run][prev_dev={prev_current_device},dev={current_device},visible_devs=[{visible_devs}]]",
+                    num_procs):
+                self.run(**self._fixture_kwargs)
+            test_result = TestResultType.SUCCESS
         except BaseException as e:
-            if isinstance(e, Skipped):
+            msg = e.msg if "msg" in dir(e) else str(e)
+            with LogTestRunBaseProcess(RUNNING_TEST_LOG_FILE, f"{tag} [exception _dist_run] {e.__class__} msg={msg}",
+                                       num_procs):
+                if isinstance(e, Skipped):
+                    test_result = TestResultType.SKIP
+                else:
+                    test_result = TestResultType.ERROR
+
             if self.non_daemonic_procs:
-                skip_msg.put(e.msg)
+                skip_msg.put(msg)
             else:
-                skip_msg = e.msg
-            else:
-                raise e
+                skip_msg = msg
+
+        return test_result, skip_msg

-        return skip_msg
+    def _launch_with_file_store(self, request, world_size, tag):
+        import tempfile
+
+        use_custom_file_store_dir = DS_UNITTEST_FILE_STORE_DIR is not None
+        if use_custom_file_store_dir:
+            shm_dir = tempfile.mkdtemp(prefix="ds_test_", dir="/dev/shm")
+            tmpdir = Path(shm_dir)
+            dist_file_store = tmpdir / "dist_file_store"
+        else:
+            tmpdir = request.getfixturevalue("tmpdir")
+            dist_file_store = tmpdir.join("dist_file_store")

-    def _launch_with_file_store(self, request, world_size):
-        tmpdir = request.getfixturevalue("tmpdir")
-        dist_file_store = tmpdir.join("dist_file_store")
         assert not os.path.exists(dist_file_store)
         init_method = f"file://{dist_file_store}"
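`init_process_group_exclusively` (kept in the tree, though its call site is commented out above) serializes process-group initialization across xdist workers with an atomic lock file: `os.open` with `O_CREAT | O_EXCL` either creates the file or fails with `EEXIST`, so exactly one worker wins. The same primitive in isolation (the lock path is hypothetical):

```python
import errno
import os
import time

LOCK_FILE = "/tmp/demo_init.lock"  # hypothetical path for this sketch

def try_acquire(owner: str) -> bool:
    try:
        # O_CREAT | O_EXCL is atomic: creation fails if the file exists,
        # so at most one caller can hold the lock at a time.
        fd = os.open(LOCK_FILE, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        os.write(fd, owner.encode())
        os.close(fd)
        return True
    except OSError as e:
        if e.errno == errno.EEXIST:
            return False
        raise

def release():
    try:
        os.remove(LOCK_FILE)
    except FileNotFoundError:
        pass

if __name__ == "__main__":
    while not try_acquire(f"pid-{os.getpid()}"):
        time.sleep(1)
    try:
        print("critical section: initialize the process group here")
    finally:
        release()
```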
""" + if self.set_dist_env: + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(master_port) + os.environ['LOCAL_RANK'] = str(local_rank) + # NOTE: unit tests don't support multi-node so local_rank == global rank + os.environ['RANK'] = str(local_rank) + # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE + # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly + os.environ['LOCAL_SIZE'] = str(num_procs) + os.environ['WORLD_SIZE'] = str(num_procs) + + # turn off NCCL logging if set + os.environ.pop('NCCL_DEBUG', None) + + if get_accelerator().is_available(): + set_accelerator_visible() + + if get_accelerator().is_available(): + get_accelerator().set_device(local_rank) + + print(f"self.init_distributed={self.init_distributed}, dist.is_initialized()={dist.is_initialized()}") + if self.init_distributed and not dist.is_initialized(): + try: + from datetime import timedelta + + deepspeed.init_distributed(dist_backend=self.backend, + init_method=init_method, + rank=local_rank, + world_size=num_procs, + timeout=timedelta(seconds=60)) + dist.broadcast(torch.tensor([0], device=get_accelerator().current_device()), 0) + dist.barrier() + # self.init_process_group_exclusively(local_rank, num_procs, init_method) + except BaseException as e: + msg = e.msg if "msg" in dir(e) else str(e) + return TestResultType.ERROR, msg + + current_device = get_accelerator().current_device() + + visible_devs = os.environ.get("CUDA_VISIBLE_DEVICES", None) + + test_result = TestResultType.UNSET + try: + with LogTestRunBaseProcess( + RUNNING_TEST_LOG_FILE, + f"{tag} [exec _dist_run][prev_dev={prev_current_device},dev={current_device},visible_devs=[{visible_devs}]]", + num_procs): + self.run(**self._fixture_kwargs) + test_result = TestResultType.SUCCESS except BaseException as e: - if isinstance(e, Skipped): + msg = e.msg if "msg" in dir(e) else str(e) + with LogTestRunBaseProcess(RUNNING_TEST_LOG_FILE, f"{tag} [exception _dist_run] {e.__class__} msg={msg}", + num_procs): + if isinstance(e, Skipped): + test_result = TestResultType.SKIP + else: + test_result = TestResultType.ERROR + if self.non_daemonic_procs: skip_msg.put(e.msg) else: - skip_msg = e.msg - else: - raise e + skip_msg = msg + + return test_result, skip_msg - return skip_msg + def _launch_with_file_store(self, request, world_size, tag): + import tempfile + + use_custom_file_store_dir = DS_UNITTEST_FILE_STORE_DIR is not None + if use_custom_file_store_dir: + shm_dir = tempfile.mkdtemp(prefix="ds_test_", dir="/dev/shm") + tmpdir = Path(shm_dir) + dist_file_store = tmpdir / "dist_file_store" + else: + tmpdir = request.getfixturevalue("tmpdir") + dist_file_store = tmpdir.join("dist_file_store") - def _launch_with_file_store(self, request, world_size): - tmpdir = request.getfixturevalue("tmpdir") - dist_file_store = tmpdir.join("dist_file_store") assert not os.path.exists(dist_file_store) init_method = f"file://{dist_file_store}" @@ -333,10 +565,12 @@ def _launch_with_file_store(self, request, world_size): world_size = [world_size] for procs in world_size: try: - self._launch_procs(procs, init_method) + self._launch_procs(procs, init_method, tag) finally: if os.path.exists(dist_file_store): os.remove(dist_file_store) + if use_custom_file_store_dir and os.path.exists(tmpdir): + os.rmdir(shm_dir) time.sleep(0.5) def _dist_destroy(self): @@ -346,8 +580,12 @@ def _dist_destroy(self): def _close_pool(self, pool, num_procs, force=False): if force or not self.reuse_dist_env: - msg = 
@@ -346,8 +580,12 @@ def _close_pool(self, pool, num_procs, force=False):
         if force or not self.reuse_dist_env:
-            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
-            pool.close()
+            ft_destroy = pool.starmap_async(self._dist_destroy, [() for _ in range(num_procs)])
+            try:
+                ft_destroy.get(self.exec_timeout)
+                pool.close()
+            except mp.TimeoutError:
+                pool.terminate()
             pool.join()

@@ -484,7 +722,8 @@ def __call__(self, request):
         else:
             world_size = self._fixture_kwargs.get("world_size", self.world_size)

-        self._launch_with_file_store(request, world_size)
+        tag = make_test_tag(request)
+        self._launch_with_file_store(request, world_size, tag)

     def _get_current_test_func(self, request):
         # DistributedTest subclasses may have multiple test methods
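The reworked `_close_pool` above replaces the blocking `starmap` with `starmap_async` plus a bounded `get`, so a rank hung in `dist.barrier()` during teardown cannot wedge the whole session: on timeout the pool is forcibly terminated instead of drained. The distinction in miniature:

```python
import multiprocessing as mp
import time

def slow_teardown():
    time.sleep(60)  # stands in for a rank stuck in dist.barrier()

if __name__ == "__main__":
    pool = mp.Pool(2)
    result = pool.starmap_async(slow_teardown, [() for _ in range(2)])
    try:
        result.get(timeout=2)   # bounded wait for a graceful teardown
        pool.close()            # graceful: accept no new work, drain, exit
    except mp.TimeoutError:
        pool.terminate()        # forceful: kill the stuck workers
    pool.join()
```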