From 38a40cd9b7d7b14d67ed07d4d21647bbb2023d8e Mon Sep 17 00:00:00 2001 From: zdevito Date: Mon, 9 Jun 2025 10:13:04 -0700 Subject: [PATCH] [4/n tensor engine] testing for tensor engine hook in the actor mesh based controller to our test suite as an additional backend to suss out bugs Differential Revision: [D76171866](https://our.internmc.facebook.com/intern/diff/D76171866/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D76171866/)! [ghstack-poisoned] --- python/monarch/_testing.py | 67 ++++++++++++++++++++------- python/tests/builtins/test_log.py | 2 +- python/tests/builtins/test_random.py | 2 +- python/tests/test_coalescing.py | 2 +- python/tests/test_controller.py | 4 +- python/tests/test_remote_functions.py | 2 +- 6 files changed, 55 insertions(+), 24 deletions(-) diff --git a/python/monarch/_testing.py b/python/monarch/_testing.py index cf898577..61681fdb 100644 --- a/python/monarch/_testing.py +++ b/python/monarch/_testing.py @@ -10,7 +10,7 @@ import tempfile import time from contextlib import contextmanager, ExitStack -from typing import Callable, Generator, Optional +from typing import Any, Callable, Dict, Generator, Literal, Optional import monarch_supervisor from monarch.common.client import Client @@ -18,6 +18,8 @@ from monarch.common.invocation import DeviceException, RemoteException from monarch.common.shape import NDSlice from monarch.controller.backend import ProcessBackend +from monarch.mesh_controller import spawn_tensor_engine +from monarch.proc_mesh import proc_mesh, ProcMesh from monarch.python_local_mesh import PythonLocalContext from monarch.rust_local_mesh import ( local_mesh, @@ -50,6 +52,7 @@ def __init__(self): self.cleanup = ExitStack() self._py_process_cache = {} self._rust_process_cache = None + self._proc_mesh_cache: Dict[Any, ProcMesh] = {} @contextmanager def _get_context(self, num_hosts, gpu_per_host): @@ -75,16 +78,14 @@ 
def _processes(self, num_hosts, gpu_per_host): @contextmanager def local_py_device_mesh( - self, num_hosts, gpu_per_host, activate=True + self, + num_hosts, + gpu_per_host, ) -> Generator[DeviceMesh, None, None]: ctx, hosts, processes = self._processes(num_hosts, gpu_per_host) dm = world_mesh(ctx, hosts, gpu_per_host, _processes=processes) try: - if activate: - with dm.activate(): - yield dm - else: - yield dm + yield dm dm.client.shutdown(destroy_pg=False) except Exception: # abnormal exit, so we just make sure we do not try to communicate in destructors, @@ -97,7 +98,6 @@ def local_rust_device_mesh( self, num_hosts, gpu_per_host, - activate: bool = True, controller_params=None, ) -> Generator[DeviceMesh, None, None]: # Create a new system and mesh for test. @@ -115,11 +115,7 @@ def local_rust_device_mesh( controller_params=controller_params, ) as dm: try: - if activate: - with dm.activate(): - yield dm - else: - yield dm + yield dm dm.exit() except Exception: dm.client._shutdown = True @@ -129,21 +125,56 @@ def local_rust_device_mesh( # pyre-ignore: Undefined attribute dm.client.inner._actor.stop() + @contextmanager + def local_engine_on_proc_mesh( + self, + num_hosts, + gpu_per_host, + ) -> Generator[DeviceMesh, None, None]: + key = (num_hosts, gpu_per_host) + if key not in self._proc_mesh_cache: + self._proc_mesh_cache[key] = proc_mesh( + hosts=num_hosts, gpus=gpu_per_host + ).get() + + dm = spawn_tensor_engine(self._proc_mesh_cache[key]) + dm = dm.rename(hosts="host", gpus="gpu") + try: + yield dm + except Exception as e: + # abnormal exit, so we just make sure we do not try to communicate in destructors, + # but we do not wait for workers to exit since we do not know what state they are in. 
+ dm.client._shutdown = True + raise + @contextmanager def local_device_mesh( - self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None + self, + num_hosts, + gpu_per_host, + activate=True, + backend: Literal["py", "rs", "mesh"] = "py", + controller_params=None, ) -> Generator[DeviceMesh, None, None]: start = time.time() - if rust: + if backend == "rs": generator = self.local_rust_device_mesh( - num_hosts, gpu_per_host, activate, controller_params=controller_params + num_hosts, gpu_per_host, controller_params=controller_params ) + elif backend == "py": + generator = self.local_py_device_mesh(num_hosts, gpu_per_host) + elif backend == "mesh": + generator = self.local_engine_on_proc_mesh(num_hosts, gpu_per_host) else: - generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate) + raise ValueError(f"invalid backend: {backend}") with generator as dm: end = time.time() logging.info("initialized mesh in {:.2f}s".format(end - start)) - yield dm + if activate: + with dm.activate(): + yield dm + else: + yield dm start = time.time() end = time.time() logging.info("shutdown mesh in {:.2f}s".format(end - start)) diff --git a/python/tests/builtins/test_log.py b/python/tests/builtins/test_log.py index 01081b06..da4c1d4d 100644 --- a/python/tests/builtins/test_log.py +++ b/python/tests/builtins/test_log.py @@ -30,7 +30,7 @@ def local_device_mesh(cls, num_hosts, gpu_per_host, backend_type, activate=True) num_hosts, gpu_per_host, activate, - rust=backend_type == BackendType.RS, + backend=str(backend_type), ) @patch("monarch.builtins.log.logger") diff --git a/python/tests/builtins/test_random.py b/python/tests/builtins/test_random.py index 24e3da38..e92a2ff3 100644 --- a/python/tests/builtins/test_random.py +++ b/python/tests/builtins/test_random.py @@ -44,7 +44,7 @@ def local_device_mesh(cls, num_hosts, gpu_per_host, backend_type, activate=True) num_hosts, gpu_per_host, activate, - rust=backend_type == BackendType.RS, + 
backend=str(backend_type), ) def test_set_manual_seed_remote(self, backend_type): diff --git a/python/tests/test_coalescing.py b/python/tests/test_coalescing.py index 43e5d407..86568fc4 100644 --- a/python/tests/test_coalescing.py +++ b/python/tests/test_coalescing.py @@ -78,7 +78,7 @@ def local_device_mesh( num_hosts, gpu_per_host, activate, - rust=backend_type == BackendType.RS, + backend=str(backend_type), ) @property diff --git a/python/tests/test_controller.py b/python/tests/test_controller.py index 41d6a17e..e9eb355a 100644 --- a/python/tests/test_controller.py +++ b/python/tests/test_controller.py @@ -96,7 +96,7 @@ def local_rust_device_mesh( torch.cuda.device_count() < 2, reason="Not enough GPUs, this test requires at least 2 GPUs", ) -@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS]) +@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS, "mesh"]) # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times # out is not counted as a failure, so we set a more restrictive timeout to # ensure we see a hard failure in CI. @@ -114,7 +114,7 @@ def local_device_mesh( N, gpu_per_host, activate, - rust=backend_type == BackendType.RS, + backend=str(backend_type), ) def test_errors(self, backend_type): diff --git a/python/tests/test_remote_functions.py b/python/tests/test_remote_functions.py index d6395ceb..058b8dfa 100644 --- a/python/tests/test_remote_functions.py +++ b/python/tests/test_remote_functions.py @@ -169,7 +169,7 @@ def local_device_mesh( num_hosts, gpu_per_host, activate, - rust=backend_type == BackendType.RS, + backend=str(backend_type), )