Skip to content

Commit

Permalink
Improve TessellateIPU gather support. (#27)
Browse files Browse the repository at this point in the history
Now fully supporting `gather` on the first axis, for any input shape.
Should allow to programmatically implement GPT-like on device embeddings using TessellateIPU.
  • Loading branch information
balancap authored Sep 26, 2023
1 parent b7bb361 commit 5b0b37b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 21 deletions.
39 changes: 28 additions & 11 deletions tessellate_ipu/lax/tile_lax_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,25 @@ def make_gather_vertex_fullname(dtype: DTypeLike) -> str:
return make_ipu_vertex_name_templated(basename, dtype)


def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers):
def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers, inshape: Tuple[int]):
"""Check `gather` dimension_numbers is supported on TessellateIPU.
At the moment: basically only supporting a single configuration!
We need to expand on this at some point!
"""
dim_numbers_default = GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
if dimension_numbers != dim_numbers_default:
raise NotImplementedError(f"TessellateIPU `gather` only support dimension numbers: {dim_numbers_default}.")
if dimension_numbers.start_index_map != (0,):
raise NotImplementedError(
f"TessellateIPU `gather` only supports `start_index_map` (0,), not {dimension_numbers}."
)
if dimension_numbers.collapsed_slice_dims != (0,):
raise NotImplementedError(
f"TessellateIPU `gather` only supports `collapse_slice_dims` (0,), not {dimension_numbers}."
)
expected_offset_dims = tuple(range(1, len(inshape)))
if dimension_numbers.offset_dims != expected_offset_dims:
raise NotImplementedError(
f"TessellateIPU only supports `gather` on the first axis. Expecting `offset_dims` {expected_offset_dims}, not {dimension_numbers}."
)


def ipu_gather_primitive_translation(
Expand All @@ -52,7 +62,7 @@ def ipu_gather_primitive_translation(
IPU tile map primitive structure.
"""
# TODO: query for JAX device.
num_context_workers = 6
# num_context_workers = 6

assert len(inavals) == 2
assert attributes is not None
Expand All @@ -67,14 +77,20 @@ def ipu_gather_primitive_translation(
fill_value = attributes.get("fill_value", None)

# Check gather attributes are supported by TessellateIPU.
assert operand.ndim == 1
assert start_indices.ndim == 2
assert slice_sizes == (1,)
assert start_indices.ndim == 2, "Only supporting gather indices of shape (N, 1)"
assert (
mode == GatherScatterMode.PROMISE_IN_BOUNDS
), "Only `PROMISE_IN_BOUNDS` gather mode supported in TessellateIPU."
assert start_indices.dtype == np.uint32, "TessellateIPU `gather` only supports `uint32` indices."
check_gather_dimension_numbers(dimension_numbers)
check_gather_dimension_numbers(dimension_numbers, operand.shape)
# Expected slice sizes if gather on first axis.
assert operand.ndim == len(slice_sizes)
expected_slice_sizes = (1, *operand.shape[1:])
if slice_sizes != expected_slice_sizes:
raise NotImplementedError(
f"TessellateIPU only supports `gather` on the first axis, i.e. with slice sizes {expected_slice_sizes}, not {slice_sizes}."
)

# Gather output aval.
outaval = p.abstract_eval(
*inavals,
Expand All @@ -91,8 +107,9 @@ def ipu_gather_primitive_translation(
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(
baseOffset=0, # unused?
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)),
regionSize=1, # TODO: understand?
# maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)),
maxElementsPerWorker=operand.size, # Need more understanding here?
regionSize=np.prod(slice_sizes), # Total slice size.
splitSingleRegion=False, # Split regions between threads? TODO: understand!
)
# TODO: should we use `split offsets` between threads?
Expand Down
24 changes: 14 additions & 10 deletions tests/lax/test_tile_lax_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.mark.ipu_hardware
class IpuTilePrimitivesLaxGather(chex.TestCase, parameterized.TestCase):
class IpuTilePrimitivesLaxGatherHwTests(chex.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.device = jax.devices("ipu")[0]
Expand All @@ -22,18 +22,22 @@ def setUp(self):
np.random.seed(123)

@parameterized.parameters(
{"num_elements": 8, "num_indices": 3},
{"num_elements": 8, "num_indices": 12},
{"num_elements": 256, "num_indices": 512},
{"data_shape": (8,), "num_indices": 3},
{"data_shape": (8,), "num_indices": 12},
{"data_shape": (256,), "num_indices": 512},
{"data_shape": (256, 17), "num_indices": 123},
{"data_shape": (256, 5, 3), "num_indices": 373},
)
def test__tile_map__gather__jitting__proper_result(self, num_elements, num_indices):
def test__tile_map__gather__first_axis_cases__jitting__proper_result(self, data_shape, num_indices):
tiles = (0,)
data = np.random.randn(num_elements).astype(np.float32)
indices = np.random.randint(low=0, high=num_elements, size=num_indices)
data = np.random.randn(*data_shape).astype(np.float32)
indices = np.random.randint(low=0, high=data_shape[0], size=num_indices)
indices = indices.reshape(-1, 1).astype(np.uint32)

# Only supported configuration!
dim_numbers = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
# First axis gather only supported configuration!
dim_numbers = jax.lax.GatherDimensionNumbers(
offset_dims=tuple(range(1, len(data_shape))), collapsed_slice_dims=(0,), start_index_map=(0,)
)

def gather_fn(data, indices):
data = tile_put_replicated(data, tiles)
Expand All @@ -43,7 +47,7 @@ def gather_fn(data, indices):
data,
indices,
dimension_numbers=dim_numbers,
slice_sizes=(1,),
slice_sizes=(1, *data_shape[1:]),
mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS,
unique_indices=False,
indices_are_sorted=False,
Expand Down

0 comments on commit 5b0b37b

Please sign in to comment.