Skip to content

Commit

Permalink
* Make _old_env thread local so that it can be used in multiple threads.
Browse files Browse the repository at this point in the history
* Make nesting for `with mesh` work properly by using a stack.

* Allow `Mesh` to be used as a decorator:

  ```
  @pxla.Mesh(mesh_devices, ('x', 'y'))
  def dec():
    return pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)(x)
  ```

PiperOrigin-RevId: 429607613
  • Loading branch information
yashk2810 authored and jax authors committed Feb 18, 2022
1 parent 8cb1692 commit 1486be7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
24 changes: 16 additions & 8 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from __future__ import annotations

from contextlib import contextmanager
from contextlib import contextmanager, ContextDecorator
from collections import defaultdict, OrderedDict
import dataclasses
from functools import partial
Expand Down Expand Up @@ -1804,10 +1804,9 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,

# ------------------- xmap -------------------

class Mesh:
class Mesh(ContextDecorator):
devices: np.ndarray
axis_names: Tuple[MeshAxisName, ...]
_old_env: ResourceEnv

def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
assert devices.ndim == len(axis_names)
Expand Down Expand Up @@ -1836,14 +1835,14 @@ def __setattr__(self, name, value):
super().__setattr__(name, value)

def __enter__(self):
self._old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = self._old_env.with_mesh(
Mesh(self.devices, self.axis_names))
new_env = _old_env.stack[-1].with_mesh(Mesh(self.devices, self.axis_names))
_old_env.stack.append(new_env)
thread_resources.env = new_env
return thread_resources.env.physical_mesh

def __exit__(self, exc_type, exc_value, traceback):
thread_resources.env = self._old_env
del self._old_env
_old_env.stack.pop()
thread_resources.env = _old_env.stack[-1]
return False

@property
Expand Down Expand Up @@ -1982,6 +1981,15 @@ def __init__(self):

thread_resources = _ThreadResourcesLocalState()

# TODO(yashkatariya): Merge this into `_ThreadResourcesLocalState` by
# maintaining a stack there and pointing `self.env` to `self.stack[-1]`.
# Do this after the old `mesh` context manager is deprecated.
class _ThreadLocalOldEnv(threading.local):
def __init__(self):
self.stack = [EMPTY_ENV]

_old_env = _ThreadLocalOldEnv()


def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
if aval is core.abstract_unit:
Expand Down
32 changes: 32 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax.errors import JAXTypeError
from jax import lax
# TODO(skye): do we still wanna call this PartitionSpec?
from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap, mesh
from jax.experimental import global_device_array
Expand Down Expand Up @@ -212,6 +213,37 @@ def f(x, y):
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
check_dtypes=False)

def testDifferentNestedMesh(self):
with jtu.create_global_mesh((2, 1), ("x", "y")) as m1:
with jtu.create_global_mesh((2, 2), ("a", "b")) as m2:
self.assertEqual(pxla.thread_resources.env.physical_mesh, m2)
self.assertEqual(pxla.thread_resources.env.physical_mesh, m1)
self.assertEqual(pxla.thread_resources.env.physical_mesh,
pxla.EMPTY_ENV.physical_mesh)

def testSameNestedMesh(self):
mesh = jtu.create_global_mesh((2, 1), ("a", "b"))
with mesh as m1:
with mesh as m2:
self.assertEqual(pxla.thread_resources.env.physical_mesh, m2)
self.assertEqual(pxla.thread_resources.env.physical_mesh, m1)
self.assertEqual(pxla.thread_resources.env.physical_mesh,
pxla.EMPTY_ENV.physical_mesh)

def testMeshDecorator(self):
x = jnp.arange(8)
mesh_shape = (2, 2)
size = prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)

@maps.Mesh(mesh_devices, ('x', 'y'))
def dec():
return pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)(x)
out = dec()
self.assertArraysEqual(out, x)

@jtu.with_mesh([('x', 2), ('y', 2)])
def testTwoMeshAxisSharding(self):
@partial(pjit,
Expand Down

0 comments on commit 1486be7

Please sign in to comment.