[4/n tensor engine] testing for tensor engine #199

Closed · wants to merge 12 commits
51 changes: 30 additions & 21 deletions monarch_extension/src/mesh_controller.rs
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

use std::collections::VecDeque;
use std::iter::repeat_n;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
@@ -49,7 +50,7 @@ use crate::convert::convert;
struct _Controller {
controller_instance: Arc<Mutex<InstanceWrapper<ControllerMessage>>>,
workers: RootActorMesh<'static, WorkerActor>,
pending_messages: Vec<PyObject>,
pending_messages: VecDeque<PyObject>,
history: history::History,
}

@@ -64,7 +65,7 @@ impl _Controller {
) -> PyResult<()> {
for (seq, response) in responses {
let message = crate::client::WorkerResponse::new(seq, response);
self.pending_messages.push(message.into_py(py));
self.pending_messages.push_back(message.into_py(py));
}
Ok(())
}
@@ -86,7 +87,7 @@ impl _Controller {
} => {
let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)?
.into_py(py);
self.pending_messages.push(dm);
self.pending_messages.push_back(dm);
}
ControllerMessage::Status {
seq,
@@ -112,15 +113,19 @@ impl _Controller {
})
}
fn send_slice(&mut self, slice: Slice, message: WorkerMessage) -> PyResult<()> {
let shape = Shape::new(
(0..slice.sizes().len()).map(|i| format!("d{i}")).collect(),
slice,
)
.unwrap();
let worker_slice = SlicedActorMesh::new(&self.workers, shape);
worker_slice
.cast(ndslice::Selection::True, message)
self.workers
.cast_slices(vec![slice], message)
.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
// let shape = Shape::new(
// (0..slice.sizes().len()).map(|i| format!("d{i}")).collect(),
// slice,
// )
// .unwrap();
// println!("SENDING TO {:?} {:?}", &shape, &message);
// let worker_slice = SlicedActorMesh::new(&self.workers, shape);
// worker_slice
// .cast(ndslice::Selection::True, message)
// .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
}
}

@@ -161,13 +166,17 @@ impl _Controller {
let workers = py_proc_mesh
.spawn(&format!("tensor_engine_workers_{}", id), &param)
.await?;
workers.cast(ndslice::Selection::True, AssignRankMessage::AssignRank())?;
//workers.cast(ndslice::Selection::True, )?;
workers.cast_slices(
vec![py_proc_mesh.shape().slice().clone()],
AssignRankMessage::AssignRank(),
)?;
Ok(workers)
})?;
Ok(Self {
workers: workers?,
controller_instance: Arc::new(Mutex::new(controller_instance)),
pending_messages: Vec::new(),
pending_messages: VecDeque::new(),
history: history::History::new(world_size),
})
}
@@ -218,7 +227,7 @@ impl _Controller {
if self.pending_messages.is_empty() {
self.fill_messages(py, timeout_msec)?;
}
Ok(self.pending_messages.pop())
Ok(self.pending_messages.pop_front())
}

fn _debugger_attach(&mut self, pdb_actor: PyActorId) -> PyResult<()> {
@@ -246,14 +255,14 @@ impl _Controller {
.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?;
Ok(())
}
fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<Vec<PyObject>> {
fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<()> {
self.send_slice(
self.workers.proc_mesh().shape().slice().clone(),
WorkerMessage::Exit { error: None },
)?;
let instance = self.controller_instance.clone();
let result =
signal_safe_block_on(py, async move { instance.lock().await.drain_and_stop() })??;
for r in result {
self.add_message(r)?;
}
Ok(std::mem::take(&mut self.pending_messages))
let _ = signal_safe_block_on(py, async move { instance.lock().await.drain_and_stop() })??;
Ok(())
}
}

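The queue change above is behavioral, not cosmetic: `Vec::push` + `Vec::pop` removes from the back, so pending worker responses were handed to Python newest-first (LIFO), while `push_back` + `pop_front` on a `VecDeque` delivers them oldest-first (FIFO), in the order the controller received them. A minimal Python sketch of the same distinction using `collections.deque` (illustrative only, not code from this PR):

from collections import deque

# LIFO: the old Vec::push / Vec::pop behavior.
stack = []
for seq in (0, 1, 2):
    stack.append(seq)
assert [stack.pop() for _ in range(3)] == [2, 1, 0]  # newest first

# FIFO: the new VecDeque::push_back / VecDeque::pop_front behavior.
queue = deque()
for seq in (0, 1, 2):
    queue.append(seq)
assert [queue.popleft() for _ in range(3)] == [0, 1, 2]  # oldest first
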
17 changes: 15 additions & 2 deletions monarch_hyperactor/src/shape.rs
@@ -123,15 +123,15 @@ impl From<Shape> for PyShape {
frozen
)]

struct PyPoint {
pub struct PyPoint {
rank: usize,
shape: Py<PyShape>,
}

#[pymethods]
impl PyPoint {
#[new]
fn new(rank: usize, shape: Py<PyShape>) -> Self {
pub fn new(rank: usize, shape: Py<PyShape>) -> Self {
PyPoint { rank, shape }
}
fn __getitem__(&self, py: Python, label: &str) -> PyResult<usize> {
@@ -150,6 +150,19 @@ impl PyPoint {
)))
}
}

fn size(&self, py: Python<'_>, label: &str) -> PyResult<usize> {
let shape = &self.shape.bind(py).get().inner;
if let Some(index) = shape.labels().iter().position(|l| l == label) {
Ok(shape.slice().sizes()[index])
} else {
Err(PyErr::new::<PyValueError, _>(format!(
"Dimension '{}' not found",
label
)))
}
}

fn __len__(&self, py: Python) -> usize {
self.shape.bind(py).get().__len__()
}
1 change: 1 addition & 0 deletions monarch_tensor_worker/Cargo.toml
@@ -19,6 +19,7 @@ hyperactor = { version = "0.0.0", path = "../hyperactor" }
hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }
itertools = "0.14.0"
monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" }
monarch_messages = { version = "0.0.0", path = "../monarch_messages" }
monarch_types = { version = "0.0.0", path = "../monarch_types" }
ndslice = { version = "0.0.0", path = "../ndslice" }
16 changes: 15 additions & 1 deletion monarch_tensor_worker/src/lib.rs
@@ -69,6 +69,8 @@ use hyperactor::message::Unbind;
use hyperactor::reference::ActorId;
use hyperactor_mesh::actor_mesh::Cast;
use itertools::Itertools;
use monarch_hyperactor::shape::PyPoint;
use monarch_hyperactor::shape::PyShape;
use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
@@ -89,6 +91,9 @@ use monarch_types::PyTree;
use ndslice::Slice;
use pipe::PipeActor;
use pipe::PipeParams;
use pyo3::Py;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
use serde::Serialize;
use sorted_vec::SortedVec;
@@ -253,10 +258,19 @@ impl Actor for WorkerActor {
impl Handler<Cast<AssignRankMessage>> for WorkerActor {
async fn handle(
&mut self,
_this: &Instance<Self>,
this: &Instance<Self>,
message: Cast<AssignRankMessage>,
) -> anyhow::Result<()> {
self.rank = message.rank.0;
Python::with_gil(|py| {
let mesh_controller = py.import_bound("monarch.mesh_controller").unwrap();
let shape: PyShape = message.shape.into();
let shape: Py<PyShape> = Py::new(py, shape).unwrap();
let p: PyPoint = PyPoint::new(message.rank.0, shape);
mesh_controller
.call_method1("_initialize_env", (p, this.proc().proc_id().to_string()))
.unwrap();
});
Ok(())
}
}
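On rank assignment the worker now re-enters Python: it wraps its rank and the cast's shape into a `Point` and passes that, plus its proc id string, to `monarch.mesh_controller._initialize_env`. Only the hook's name and its `(point, proc_id)` argument pair are visible in this diff; the sketch below is a hypothetical Python-side counterpart, assuming the hook derives per-process environment variables from the point (the dimension labels "hosts"/"gpus" are likewise assumptions):

import os

from monarch._rust_bindings.monarch_hyperactor.shape import Point

def _initialize_env(worker_point: Point, proc_id: str) -> None:
    # Hypothetical body: turn the worker's coordinates into the usual
    # torch.distributed-style environment variables.
    hosts = worker_point.size("hosts")  # extent of the "hosts" dimension
    gpus = worker_point.size("gpus")    # extent of the "gpus" dimension
    os.environ["RANK"] = str(worker_point.rank)
    os.environ["WORLD_SIZE"] = str(hosts * gpus)
    os.environ["LOCAL_RANK"] = str(worker_point["gpus"])  # coordinate, not size
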
1 change: 1 addition & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi
@@ -151,6 +151,7 @@ class Point(collections.abc.Mapping):
def __new__(cls, rank: int, shape: "Shape") -> "Point": ...
def __getitem__(self, label: str) -> int: ...
def __len__(self) -> int: ...
def size(self, label: str) -> int: ...
@property
def rank(self) -> int: ...
@property
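The stub mirrors the Rust change: `Point.size(label)` returns the extent of a named dimension, while `point[label]` returns the point's coordinate along it. A small usage sketch — the `Shape`/`Slice` constructor signatures here are assumptions for illustration:

from monarch._rust_bindings.monarch_hyperactor.shape import Point, Shape, Slice

# A 2 x 4 mesh laid out row-major as hosts x gpus (illustrative labels).
shape = Shape(["hosts", "gpus"], Slice(offset=0, sizes=[2, 4], strides=[4, 1]))
p = Point(5, shape)  # rank 5 -> coordinates (1, 1)

assert p["hosts"] == 1      # coordinate along "hosts"
assert p.size("gpus") == 4  # extent of the "gpus" dimension
# p.size("nope") raises ValueError: "Dimension 'nope' not found"
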
68 changes: 50 additions & 18 deletions python/monarch/_testing.py
@@ -10,14 +10,16 @@
import tempfile
import time
from contextlib import contextmanager, ExitStack
from typing import Callable, Generator, Optional
from typing import Any, Callable, Dict, Generator, Literal, Optional

import monarch_supervisor
from monarch.common.client import Client
from monarch.common.device_mesh import DeviceMesh
from monarch.common.invocation import DeviceException, RemoteException
from monarch.common.shape import NDSlice
from monarch.controller.backend import ProcessBackend
from monarch.mesh_controller import spawn_tensor_engine
from monarch.proc_mesh import proc_mesh, ProcMesh
from monarch.python_local_mesh import PythonLocalContext
from monarch.rust_local_mesh import (
local_mesh,
@@ -50,6 +52,7 @@ def __init__(self):
self.cleanup = ExitStack()
self._py_process_cache = {}
self._rust_process_cache = None
self._proc_mesh_cache: Dict[Any, ProcMesh] = {}

@contextmanager
def _get_context(self, num_hosts, gpu_per_host):
@@ -75,16 +78,14 @@ def _processes(self, num_hosts, gpu_per_host):

@contextmanager
def local_py_device_mesh(
self, num_hosts, gpu_per_host, activate=True
self,
num_hosts,
gpu_per_host,
) -> Generator[DeviceMesh, None, None]:
ctx, hosts, processes = self._processes(num_hosts, gpu_per_host)
dm = world_mesh(ctx, hosts, gpu_per_host, _processes=processes)
try:
if activate:
with dm.activate():
yield dm
else:
yield dm
yield dm
dm.client.shutdown(destroy_pg=False)
except Exception:
# abnormal exit, so we just make sure we do not try to communicate in destructors,
@@ -97,7 +98,6 @@ def local_rust_device_mesh(
self,
num_hosts,
gpu_per_host,
activate: bool = True,
controller_params=None,
) -> Generator[DeviceMesh, None, None]:
# Create a new system and mesh for test.
@@ -115,11 +115,7 @@
controller_params=controller_params,
) as dm:
try:
if activate:
with dm.activate():
yield dm
else:
yield dm
yield dm
dm.exit()
except Exception:
dm.client._shutdown = True
@@ -129,21 +125,57 @@
# pyre-ignore: Undefined attribute
dm.client.inner._actor.stop()

@contextmanager
def local_engine_on_proc_mesh(
self,
num_hosts,
gpu_per_host,
) -> Generator[DeviceMesh, None, None]:
key = (num_hosts, gpu_per_host)
if key not in self._proc_mesh_cache:
self._proc_mesh_cache[key] = proc_mesh(
hosts=num_hosts, gpus=gpu_per_host
).get()

dm = spawn_tensor_engine(self._proc_mesh_cache[key])
dm = dm.rename(hosts="host", gpus="gpu")
try:
yield dm
dm.exit()
except Exception as e:
# abnormal exit, so we just make sure we do not try to communicate in destructors,
# but we do not wait for workers to exit since we do not know what state they are in.
dm.client._shutdown = True
raise

@contextmanager
def local_device_mesh(
self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None
self,
num_hosts,
gpu_per_host,
activate=True,
backend: Literal["py", "rs", "mesh"] = "py",
controller_params=None,
) -> Generator[DeviceMesh, None, None]:
start = time.time()
if rust:
if backend == "rs":
generator = self.local_rust_device_mesh(
num_hosts, gpu_per_host, activate, controller_params=controller_params
num_hosts, gpu_per_host, controller_params=controller_params
)
elif backend == "py":
generator = self.local_py_device_mesh(num_hosts, gpu_per_host)
elif backend == "mesh":
generator = self.local_engine_on_proc_mesh(num_hosts, gpu_per_host)
else:
generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate)
raise ValueError(f"invalid backend: {backend}")
with generator as dm:
end = time.time()
logging.info("initialized mesh in {:.2f}s".format(end - start))
yield dm
if activate:
with dm.activate():
yield dm
else:
yield dm
start = time.time()
end = time.time()
logging.info("shutdown mesh in {:.2f}s".format(end - start))
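`local_device_mesh` now selects the controller by name rather than a boolean `rust` flag, and the `activate` handling is hoisted out of the per-backend helpers into this shared wrapper, so each helper just yields a raw mesh. A sketch of how a test migrates, written as a helper taking a testing-context instance since the enclosing class's name falls outside this excerpt:

def run_on_mesh(ctx) -> None:
    # `ctx` is an instance of this module's testing-context class.
    # Old: ctx.local_device_mesh(2, 2, rust=True); new: backend by name.
    with ctx.local_device_mesh(2, 2, backend="mesh") as dm:
        ...  # dm is already active: activate=True is the default

    with ctx.local_device_mesh(2, 2, activate=False, backend="rs") as dm:
        with dm.activate():  # caller opts in manually when activate=False
            ...
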
16 changes: 15 additions & 1 deletion python/monarch/common/client.py
@@ -103,6 +103,13 @@ def __init__(
# workers.
self.last_processed_seq = -1

# an error that we have received but know for certain has not
# been propagated to a future. This will be reported on shutdown
# to avoid hiding the error. This is best effort: we only keep
# the error until the point that a future is dependent on
# _any_ error, not necessarily the tracked one.
self._pending_shutdown_error = None

self.recorder = Recorder()

self.pending_results: Dict[
@@ -174,6 +181,8 @@ def shutdown(
destroy_pg: bool = True,
error_reason: Optional[RemoteException | DeviceException | Exception] = None,
) -> None:
if self.has_shutdown:
return
logger.info("shutting down the client gracefully")

atexit.unregister(self._atexit)
@@ -303,6 +312,7 @@ def _handle_pending_result(self, output: MessageResult) -> None:

if error is not None:
logging.info("Received error for seq %s: %s", seq, error)
self._pending_shutdown_error = error
# We should not have set result if we have an error.
assert result is None
if not isinstance(error, RemoteException):
@@ -326,7 +336,11 @@

fut, _ = self.pending_results[seq]
if fut is not None:
fut._set_result(result if error is None else error)
if error is None:
fut._set_result(result)
else:
fut._set_result(error)
self._pending_shutdown_error = None
elif result is not None:
logger.debug(f"{seq}: unused result {result}")
elif error is not None:
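The `_pending_shutdown_error` field is deliberately best-effort: an incoming error is remembered so shutdown can surface it, but the moment any error is delivered into a future, tracking stops, because the caller now holds a handle that will raise. A toy model of that rule (hypothetical names, independent of the real `Client`):

class ErrorTracker:
    """Toy model of the best-effort _pending_shutdown_error rule above."""

    def __init__(self):
        self.pending_shutdown_error = None

    def on_result(self, future, result, error):
        if error is not None:
            # Remember the error in case no future ever observes it.
            self.pending_shutdown_error = error
        if future is not None:
            future.set_result(result if error is None else error)
            if error is not None:
                # Some future now carries *an* error, so shutdown no
                # longer needs to report it; clear the tracked one.
                self.pending_shutdown_error = None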