Skip to content

Commit

Permalink
Add basic support for shard_map to oryx
Browse files Browse the repository at this point in the history
This doesn't allow sowing inside shard map, but does allow sowing if it occurs before or after the shard map. It adds tests for all these three cases, ensuring errors are thrown if sows are inside shard_map

PiperOrigin-RevId: 667331298
  • Loading branch information
The oryx Authors committed Aug 28, 2024
1 parent 21722fa commit 7f35db9
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 1 deletion.
19 changes: 19 additions & 0 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def f(x):
from jax._src import pjit
from jax._src import sharding_impls
from jax._src.lax import control_flow as lcf
from jax.experimental import shard_map
import jax.extend.linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
Expand Down Expand Up @@ -443,6 +444,11 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
return context.process_custom_jvp_call(self, primitive, fun, jvp, tracers,
symbolic_zeros=symbolic_zeros)

def process_shard_map(self, primitive, f, tracers, **params):
out_flat = primitive.bind(f, *[t.val for t in tracers], **params)
out_tracers = map(self.pure, out_flat)
return out_tracers

def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
context = trace_util.get_dynamic_context(self)
return context.post_process_custom_jvp_call(self, out_tracers, jvp_was_run)
Expand Down Expand Up @@ -1704,3 +1710,16 @@ def harvest(f,
kwargs = dict(
tag=tag, allowlist=allowlist, blocklist=blocklist, exclusive=exclusive)
return call_and_reap(plant(f, **kwargs), **kwargs)


# Handle shard_map
@shard_map.register_check(sow_p)
def _sow_check(mesh, *in_rep, name, tag, mode, tree):
del mesh, name, tag, mode, tree
return in_rep[0] # TODO(conmy): does this limit use to one output only?


@shard_map.register_rewrite(sow_p)
def _sow_rewrite(mesh, in_rep, *args, name, tag, mode, tree):
raise ValueError('Detected sow calls inside a shard_map.'
' This is not currently supported.')
143 changes: 142 additions & 1 deletion oryx/core/interpreters/harvest_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
from jax import config
from jax import lax
from jax._src import pjit
from jax.experimental import mesh_utils
from jax.experimental import shard_map
import jax.numpy as jnp
import numpy as np

from oryx.core import trace_util
from oryx.core.interpreters import harvest
from oryx.internal import test_util
Expand Down Expand Up @@ -1019,6 +1020,146 @@ def branch3(x):
self.assertEqual(out, 8.)


class ShardMapTest(test_util.TestCase):

def setUp(self):
super().setUp()
self.devices = mesh_utils.create_device_mesh((1, 2))
self.mesh = jax.sharding.Mesh(self.devices, axis_names=('x', 'y'))
self.a = jnp.arange(8 * 16.0).reshape(8, 16)
self.b = jnp.arange(16 * 4.0).reshape(16, 4)

def _f(a, b):
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(jax.sharding.PartitionSpec('x', 'y'),
jax.sharding.PartitionSpec('y', None)),
out_specs=jax.sharding.PartitionSpec('x', None),
)
def oryx_shmap_matmul(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block

shmapped_val = oryx_shmap_matmul(a, b)
sowed_val = sow(shmapped_val, name='shmapped_val', tag='intermediate')
return 2.0 * sowed_val

self.f = _f

def _f_with_sow_before_shmap(a, b):
sowed_val = sow(a, name='a', tag='intermediate')

@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(jax.sharding.PartitionSpec('x', 'y'),
jax.sharding.PartitionSpec('y', None)),
out_specs=jax.sharding.PartitionSpec('x', None),
)
def oryx_shmap_matmul(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partial_sum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partial_sum, 'y')
# c_block: f32[2, 4]
return c_block

shmapped_val = oryx_shmap_matmul(sowed_val, b)
return 2.0 * shmapped_val

self.f_with_sow_before_shmap = _f_with_sow_before_shmap

def _f_with_sow_inside_shmap(a, b):
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(jax.sharding.PartitionSpec('x', 'y'),
jax.sharding.PartitionSpec('y', None)),
out_specs=jax.sharding.PartitionSpec('x', None),
)
def oryx_shmap_matmul(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partial_sum = jnp.dot(a_block, b_block)
c_block = sow(
jax.lax.psum(c_partial_sum, 'y'), name='c_block', tag='intermediate'
)
# c_block: f32[2, 4]
return c_block

return 2.0 * oryx_shmap_matmul(a, b)

self.f_with_sow_inside_shmap = _f_with_sow_inside_shmap

def test_reap(self):
reap_dict = reap(self.f, tag='intermediate')(self.a, self.b)
self.assertEqual(
list(reap_dict.keys()), ['shmapped_val'], msg='Wrong reap dict keys'
)

self.assertFalse(
np.isclose(
reap_dict['shmapped_val'], 2.0 * jnp.dot(self.a, self.b)
).any(),
msg=(
'Reaped value is close to 2.0 * matmul but that'
' should be the output of the function, not the'
' intermediate reaped value.'
),
)
np.testing.assert_allclose(
reap_dict['shmapped_val'], jnp.dot(self.a, self.b)
)

def test_plant(self):
shampped_val_for_planting = 0.5 * jnp.dot(self.a, self.b)
f_output_planted = plant(self.f, tag='intermediate')(
dict(shmapped_val=shampped_val_for_planting), self.a, self.b
)
np.testing.assert_allclose(
f_output_planted, 2.0 * shampped_val_for_planting
)

def test_reap_before_shmap(self):
reap_dict = reap(self.f_with_sow_before_shmap, tag='intermediate')(
self.a, self.b
)
self.assertEqual(list(reap_dict.keys()), ['a'], msg='Wrong reap dict keys')
np.testing.assert_allclose(reap_dict['a'], self.a)

def test_plant_before_shmap(self):
a_val_for_planting = 0.5 * self.a
f_output_planted = plant(self.f_with_sow_before_shmap, tag='intermediate')(
dict(a=a_val_for_planting), self.a, self.b
)
np.testing.assert_allclose(
f_output_planted, 2.0 * jnp.dot(a_val_for_planting, self.b)
)

def test_reap_inside_shmap_fails(self):
with self.assertRaisesRegex(
ValueError,
'Detected sow calls inside a shard_map.'
' This is not currently supported.',
):
reap(self.f_with_sow_inside_shmap, tag='intermediate')(self.a, self.b)

def test_plant_inside_shmap_fails(self):
with self.assertRaisesRegex(
ValueError,
'Detected sow calls inside a shard_map.'
' This is not currently supported.',
):
plant(self.f_with_sow_inside_shmap, tag='intermediate')(
dict(c_block=15.0 * jnp.dot(self.a, self.b)), self.a, self.b
)


if __name__ == '__main__':
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
absltest.main()

0 comments on commit 7f35db9

Please sign in to comment.