From 1486be7b77d065ae16b1d56eb291d4d66786f9e5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 18 Feb 2022 11:15:56 -0800 Subject: [PATCH] * Make _old_env thread local so that it can be used in multiple threads. * Make nesting for `with mesh` work properly by using a stack. * Allow `Mesh` to be used as a decorator: ``` @pxla.Mesh(mesh_devices, ('x', 'y')) def dec(): return pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)(x) ``` PiperOrigin-RevId: 429607613 --- jax/interpreters/pxla.py | 24 ++++++++++++++++-------- tests/pjit_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 6b65095478f1..f248a541d8b9 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -30,7 +30,7 @@ from __future__ import annotations -from contextlib import contextmanager +from contextlib import contextmanager, ContextDecorator from collections import defaultdict, OrderedDict import dataclasses from functools import partial @@ -1804,10 +1804,9 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, # ------------------- xmap ------------------- -class Mesh: +class Mesh(ContextDecorator): devices: np.ndarray axis_names: Tuple[MeshAxisName, ...] - _old_env: ResourceEnv def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]): assert devices.ndim == len(axis_names) @@ -1836,14 +1835,14 @@ def __setattr__(self, name, value): super().__setattr__(name, value) def __enter__(self): - self._old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) - thread_resources.env = self._old_env.with_mesh( - Mesh(self.devices, self.axis_names)) + new_env = _old_env.stack[-1].with_mesh(Mesh(self.devices, self.axis_names)) + _old_env.stack.append(new_env) + thread_resources.env = new_env return thread_resources.env.physical_mesh def __exit__(self, exc_type, exc_value, traceback): - thread_resources.env = self._old_env - del self._old_env + _old_env.stack.pop() + thread_resources.env = _old_env.stack[-1] return False @property @@ -1982,6 +1981,15 @@ def __init__(self): thread_resources = _ThreadResourcesLocalState() +# TODO(yashkatariya): Merge this into `_ThreadResourcesLocalState` by +# maintaining a stack there and pointing `self.env` to `self.stack[-1]`. +# Do this after the old `mesh` context manager is deprecated. +class _ThreadLocalOldEnv(threading.local): + def __init__(self): + self.stack = [EMPTY_ENV] + +_old_env = _ThreadLocalOldEnv() + def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval): if aval is core.abstract_unit: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b05bce60e915..9ce6c573eea7 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -29,6 +29,7 @@ from jax.errors import JAXTypeError from jax import lax # TODO(skye): do we still wanna call this PartitionSpec? +from jax.experimental import maps from jax.experimental import PartitionSpec as P from jax.experimental.maps import xmap, mesh from jax.experimental import global_device_array @@ -212,6 +213,37 @@ def f(x, y): self.assertAllClose(actual.device_buffers[3].to_py(), split1, check_dtypes=False) + def testDifferentNestedMesh(self): + with jtu.create_global_mesh((2, 1), ("x", "y")) as m1: + with jtu.create_global_mesh((2, 2), ("a", "b")) as m2: + self.assertEqual(pxla.thread_resources.env.physical_mesh, m2) + self.assertEqual(pxla.thread_resources.env.physical_mesh, m1) + self.assertEqual(pxla.thread_resources.env.physical_mesh, + pxla.EMPTY_ENV.physical_mesh) + + def testSameNestedMesh(self): + mesh = jtu.create_global_mesh((2, 1), ("a", "b")) + with mesh as m1: + with mesh as m2: + self.assertEqual(pxla.thread_resources.env.physical_mesh, m2) + self.assertEqual(pxla.thread_resources.env.physical_mesh, m1) + self.assertEqual(pxla.thread_resources.env.physical_mesh, + pxla.EMPTY_ENV.physical_mesh) + + def testMeshDecorator(self): + x = jnp.arange(8) + mesh_shape = (2, 2) + size = prod(mesh_shape) + if len(jax.devices()) < size: + raise unittest.SkipTest(f"Test requires {size} global devices.") + mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape) + + @maps.Mesh(mesh_devices, ('x', 'y')) + def dec(): + return pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)(x) + out = dec() + self.assertArraysEqual(out, x) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testTwoMeshAxisSharding(self): @partial(pjit,