From 7f35db97611f3e918336579176dc3805e59cc347 Mon Sep 17 00:00:00 2001 From: The oryx Authors Date: Sun, 25 Aug 2024 09:05:18 -0700 Subject: [PATCH] Add basic support for shard_map to oryx This doesn't allow sowing inside shard_map, but does allow sowing if it occurs before or after the shard_map. It adds tests for all three cases, ensuring an error is raised if a sow occurs inside shard_map. PiperOrigin-RevId: 667331298 --- oryx/core/interpreters/harvest.py | 19 ++++ oryx/core/interpreters/harvest_test.py | 143 ++++++++++++++++++++++++- 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index f787ce5..aca9273 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -156,6 +156,7 @@ def f(x): from jax._src import pjit from jax._src import sharding_impls from jax._src.lax import control_flow as lcf +from jax.experimental import shard_map import jax.extend.linear_util as lu from jax.interpreters import ad from jax.interpreters import batching @@ -443,6 +444,11 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, return context.process_custom_jvp_call(self, primitive, fun, jvp, tracers, symbolic_zeros=symbolic_zeros) + def process_shard_map(self, primitive, f, tracers, **params): + out_flat = primitive.bind(f, *[t.val for t in tracers], **params) + out_tracers = map(self.pure, out_flat) + return out_tracers + def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): context = trace_util.get_dynamic_context(self) return context.post_process_custom_jvp_call(self, out_tracers, jvp_was_run) @@ -1704,3 +1710,16 @@ def harvest(f, kwargs = dict( tag=tag, allowlist=allowlist, blocklist=blocklist, exclusive=exclusive) return call_and_reap(plant(f, **kwargs), **kwargs) + + +# Handle shard_map +@shard_map.register_check(sow_p) +def _sow_check(mesh, *in_rep, name, tag, mode, tree): + del mesh, name, tag, mode, tree + return in_rep[0] # 
TODO(conmy): does this limit use to one output only? + + +@shard_map.register_rewrite(sow_p) +def _sow_rewrite(mesh, in_rep, *args, name, tag, mode, tree): + raise ValueError('Detected sow calls inside a shard_map.' + ' This is not currently supported.') diff --git a/oryx/core/interpreters/harvest_test.py b/oryx/core/interpreters/harvest_test.py index f020a5a..05b0e08 100644 --- a/oryx/core/interpreters/harvest_test.py +++ b/oryx/core/interpreters/harvest_test.py @@ -38,9 +38,10 @@ from jax import config from jax import lax from jax._src import pjit +from jax.experimental import mesh_utils +from jax.experimental import shard_map import jax.numpy as jnp import numpy as np - from oryx.core import trace_util from oryx.core.interpreters import harvest from oryx.internal import test_util @@ -1019,6 +1020,146 @@ def branch3(x): self.assertEqual(out, 8.) +class ShardMapTest(test_util.TestCase): + + def setUp(self): + super().setUp() + self.devices = mesh_utils.create_device_mesh((1, 2)) + self.mesh = jax.sharding.Mesh(self.devices, axis_names=('x', 'y')) + self.a = jnp.arange(8 * 16.0).reshape(8, 16) + self.b = jnp.arange(16 * 4.0).reshape(16, 4) + + def _f(a, b): + @functools.partial( + shard_map.shard_map, + mesh=self.mesh, + in_specs=(jax.sharding.PartitionSpec('x', 'y'), + jax.sharding.PartitionSpec('y', None)), + out_specs=jax.sharding.PartitionSpec('x', None), + ) + def oryx_shmap_matmul(a_block, b_block): + # a_block: f32[2, 8] + # b_block: f32[8, 4] + c_partialsum = jnp.dot(a_block, b_block) + c_block = jax.lax.psum(c_partialsum, 'y') + # c_block: f32[2, 4] + return c_block + + shmapped_val = oryx_shmap_matmul(a, b) + sowed_val = sow(shmapped_val, name='shmapped_val', tag='intermediate') + return 2.0 * sowed_val + + self.f = _f + + def _f_with_sow_before_shmap(a, b): + sowed_val = sow(a, name='a', tag='intermediate') + + @functools.partial( + shard_map.shard_map, + mesh=self.mesh, + in_specs=(jax.sharding.PartitionSpec('x', 'y'), + jax.sharding.PartitionSpec('y', 
None)), + out_specs=jax.sharding.PartitionSpec('x', None), + ) + def oryx_shmap_matmul(a_block, b_block): + # a_block: f32[2, 8] + # b_block: f32[8, 4] + c_partial_sum = jnp.dot(a_block, b_block) + c_block = jax.lax.psum(c_partial_sum, 'y') + # c_block: f32[2, 4] + return c_block + + shmapped_val = oryx_shmap_matmul(sowed_val, b) + return 2.0 * shmapped_val + + self.f_with_sow_before_shmap = _f_with_sow_before_shmap + + def _f_with_sow_inside_shmap(a, b): + @functools.partial( + shard_map.shard_map, + mesh=self.mesh, + in_specs=(jax.sharding.PartitionSpec('x', 'y'), + jax.sharding.PartitionSpec('y', None)), + out_specs=jax.sharding.PartitionSpec('x', None), + ) + def oryx_shmap_matmul(a_block, b_block): + # a_block: f32[2, 8] + # b_block: f32[8, 4] + c_partial_sum = jnp.dot(a_block, b_block) + c_block = sow( + jax.lax.psum(c_partial_sum, 'y'), name='c_block', tag='intermediate' + ) + # c_block: f32[2, 4] + return c_block + + return 2.0 * oryx_shmap_matmul(a, b) + + self.f_with_sow_inside_shmap = _f_with_sow_inside_shmap + + def test_reap(self): + reap_dict = reap(self.f, tag='intermediate')(self.a, self.b) + self.assertEqual( + list(reap_dict.keys()), ['shmapped_val'], msg='Wrong reap dict keys' + ) + + self.assertFalse( + np.isclose( + reap_dict['shmapped_val'], 2.0 * jnp.dot(self.a, self.b) + ).any(), + msg=( + 'Reaped value is close to 2.0 * matmul but that' + ' should be the output of the function, not the' + ' intermediate reaped value.' 
+ ), + ) + np.testing.assert_allclose( + reap_dict['shmapped_val'], jnp.dot(self.a, self.b) + ) + + def test_plant(self): + shampped_val_for_planting = 0.5 * jnp.dot(self.a, self.b) + f_output_planted = plant(self.f, tag='intermediate')( + dict(shmapped_val=shampped_val_for_planting), self.a, self.b + ) + np.testing.assert_allclose( + f_output_planted, 2.0 * shampped_val_for_planting + ) + + def test_reap_before_shmap(self): + reap_dict = reap(self.f_with_sow_before_shmap, tag='intermediate')( + self.a, self.b + ) + self.assertEqual(list(reap_dict.keys()), ['a'], msg='Wrong reap dict keys') + np.testing.assert_allclose(reap_dict['a'], self.a) + + def test_plant_before_shmap(self): + a_val_for_planting = 0.5 * self.a + f_output_planted = plant(self.f_with_sow_before_shmap, tag='intermediate')( + dict(a=a_val_for_planting), self.a, self.b + ) + np.testing.assert_allclose( + f_output_planted, 2.0 * jnp.dot(a_val_for_planting, self.b) + ) + + def test_reap_inside_shmap_fails(self): + with self.assertRaisesRegex( + ValueError, + 'Detected sow calls inside a shard_map.' + ' This is not currently supported.', + ): + reap(self.f_with_sow_inside_shmap, tag='intermediate')(self.a, self.b) + + def test_plant_inside_shmap_fails(self): + with self.assertRaisesRegex( + ValueError, + 'Detected sow calls inside a shard_map.' + ' This is not currently supported.', + ): + plant(self.f_with_sow_inside_shmap, tag='intermediate')( + dict(c_block=15.0 * jnp.dot(self.a, self.b)), self.a, self.b + ) + + if __name__ == '__main__': os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2' absltest.main()