Add `fill` TessellateIPU primitive, mapping to the popops Fill vertex. (#38)
Compared to the `full` LAX operation, the `fill_p` TessellateIPU primitive
ensures that the full array is allocated and filled with a given value (not broadcast).

As a useful example of combining `fill` with `scatter`, we implement
`tile_sharded_identity`, which creates an identity matrix with proper tile mapping,
without introducing additional communication.
balancap authored Sep 29, 2023
1 parent 40572ba commit af0cc1f
Showing 3 changed files with 197 additions and 5 deletions.
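
A minimal usage sketch of the two new helpers (not part of the diff below; it assumes an attached IPU device or the IPU model, and the tile IDs are illustrative):

from functools import partial

import jax
import numpy as np

from tessellate_ipu.lax import tile_fill, tile_sharded_identity

tiles = (0, 1, 2)  # illustrative tile IDs, not prescribed by the commit

@partial(jax.jit, backend="ipu")
def build():
    # Each tile gets its own fully allocated (4, 5) slice, filled with 1.0.
    ones = tile_fill((4, 5), 1.0, dtype=np.float32, tiles=tiles)
    # (3, 3) identity matrix, with row k living on tiles[k]; no inter-tile exchange.
    eye = tile_sharded_identity(np.float32, tiles)
    return ones, eye

ones, eye = build()  # ones: (3, 4, 5) TileShardedArray; eye: (3, 3)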
2 changes: 1 addition & 1 deletion tessellate_ipu/lax/__init__.py
@@ -25,7 +25,7 @@
)

from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
-from .tile_lax_array import bitcast_convert_type_p, reshape_p
+from .tile_lax_array import bitcast_convert_type_p, fill, fill_p, reshape_p, tile_fill, tile_sharded_identity
from .tile_lax_binary import (
add_inplace_p,
atan2_inplace_p,
146 changes: 143 additions & 3 deletions tessellate_ipu/lax/tile_lax_array.py
@@ -1,10 +1,26 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax.lax
import numpy as np
from jax.core import Primitive, ShapedArray
-from jax.lax import bitcast_convert_type_p, reshape_p
+from jax.interpreters import mlir
+from jax.interpreters.mlir import LoweringRuleContext, ir
+from jax.lax import bitcast_convert_type_p, reshape_p, scatter_p

-from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive
+from tessellate_ipu.core import (
+    IpuTileMapEquation,
+    TileShardedArray,
+    make_ipu_vertex_attributes,
+    make_ipu_vertex_inout_info,
+    make_ipu_vertex_name_templated,
+    make_ipu_vertex_out_info,
+    register_ipu_tile_primitive,
+    tile_constant_replicated,
+    tile_constant_sharded,
+    tile_map,
+)
+from tessellate_ipu.utils import DTypeLike


def ipu_reshape_primitive_translation(
@@ -95,3 +111,127 @@ def ipu_bitcast_convert_type_primitive_translation(

# Register JAX LAX bitcast_convert_type_p primitive.
register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation)


fill_p = Primitive("fill")
"""Fill primitive: create an array, and fill it with a constant.
Note: compared to `jax.lax.full`, it guarantees allocation of the full array instead of broadcasting.
"""


def fill(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
"""Fill a tensor with given shape and value."""
return fill_p.bind(shape=shape, fill_value=fill_value, dtype=dtype)


def fill_numpy_impl(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
return np.full(shape, fill_value, dtype=dtype)


def fill_abstract_eval(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
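# Reuse jax.lax.full to canonicalize the output shape and dtype.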
aval = jax.lax.full(shape, fill_value=fill_value, dtype=dtype)
return ShapedArray(aval.shape, dtype=aval.dtype)


def ipu_fill_primitive_translation_ipu(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Optional[Dict[str, Any]] = None,
) -> IpuTileMapEquation:
"""IPU tile translation for `fill`
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: Op attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 0
assert attributes is not None
shape = attributes["shape"]
fill_value = attributes["fill_value"]
dtype = attributes["dtype"]

outaval = fill_abstract_eval(shape, fill_value, dtype)
# Translation rule to IPU vertex
vname = make_ipu_vertex_name_templated("popops::Fill", outaval.dtype)
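# "in" is a Python keyword: pass the vertex's fill-value attribute via dict unpacking.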
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(**{"in": fill_value})
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=[],
outputs_info=[make_ipu_vertex_out_info("out", outaval)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
return ipu_prim_info


def fill_mlir_translation_default(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
"""`fill` default MLIR translation, for CPU/GPU/IPU/... backends."""
outaval = ctx.avals_out[0]
fill_value = params["fill_value"]

def fill_fn(*inputs):
return jax.lax.full(outaval.shape, fill_value, outaval.dtype)

# Lower to MLIR using JAX tooling. TODO: cache lowering?
fill_lower_fn = mlir.lower_fun(fill_fn, multiple_results=False)
return fill_lower_fn(ctx, *args)


fill_p.map_primitive = False
# Register the primal implementation with JAX.
fill_p.def_impl(fill_numpy_impl)
# Register the abstract evaluation with JAX.
fill_p.def_abstract_eval(fill_abstract_eval)
# Default MLIR translation for all backends.
mlir.register_lowering(fill_p, fill_mlir_translation_default)
# Register TessellateIPU translation.
register_ipu_tile_primitive(fill_p, ipu_fill_primitive_translation_ipu)


def tile_fill(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike, tiles: Tuple[int, ...]) -> TileShardedArray:
"""Tile `fill` a tensor with given shape and value."""
return tile_map(fill_p, shape=shape, fill_value=fill_value, dtype=dtype, tiles=tiles) # type:ignore


def tile_sharded_identity(dtype: DTypeLike, tiles: Tuple[int, ...]) -> TileShardedArray:
"""Create a tile sharded identity matrix, i.e. sharded on tiles across the first axis.
Args:
dtype: Dtype of the identity matrix.
tiles: Sharding tiles.
Returns:
Sharded identity matrix of shape (N, N), with N = len(tiles).
"""
with jax.named_scope("tile_sharded_identity"):
N = len(tiles)
# Build zero matrix + update diagonal entries.
arr = tile_fill((N,), 0, dtype=dtype, tiles=tiles)
# Requires constants for indices + updates. Something more efficient?
indices = tile_constant_sharded(np.arange(0, N, dtype=np.uint32).reshape(N, 1, 1), tiles=tiles)
updates = tile_constant_replicated(np.array([1], dtype=dtype), tiles=tiles)
# Not the simplest way ever of updating diagonal terms!
scatter_dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)
)
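# Per tile k, scatter writes the scalar update 1 at index k of its zero-filled
# row, producing a one-hot row; stacked across tiles, the rows form the identity.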
arr = tile_map(
scatter_p,
arr,
indices,
updates,
dimension_numbers=scatter_dnums,
indices_are_sorted=False,
unique_indices=False,
mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS,
update_jaxpr=None,
update_consts=None,
) # type:ignore
return arr
54 changes: 53 additions & 1 deletion tests/lax/test_tile_lax_array.py
@@ -5,9 +5,10 @@
import jax
import numpy as np
import numpy.testing as npt
+from absl.testing import parameterized

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
-from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p
+from tessellate_ipu.lax import bitcast_convert_type_p, fill, reshape_p, tile_fill, tile_sharded_identity


class IpuTileArrayPrimitiveTests(chex.TestCase):
@@ -54,3 +55,54 @@ def compute_fn(input):
assert output_ipu.tiles == tiles
assert output_ipu.dtype == np.int32
npt.assert_array_equal(output_ipu, output_cpu)

@parameterized.parameters(
[
(np.int32,),
(np.float16,),
(np.float32,),
]
)
def test__tile_map__fill__ipu_jitting__proper_result(self, dtype):
tiles = (3, 1, 5)
shape = (4, 5)
fill_value = 1

def compute_fn():
return tile_fill(shape, fill_value, dtype=dtype, tiles=tiles)

compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)
output_ipu = compute_fn_ipu()
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == dtype
npt.assert_array_equal(output_ipu, np.full((len(tiles), *shape), fill_value, dtype=dtype))

def test__tile_map__fill__cpu_jitting__proper_result(self):
shape = (4, 5)
fill_value = 2

def compute_fn():
return fill(shape, fill_value, np.float32)

fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
output_cpu = fn_cpu()
assert output_cpu.dtype == np.float32
npt.assert_array_equal(output_cpu, np.full(shape, fill_value, dtype=np.float32))

def test__tile_sharded_identity__ipu_jitting__proper_result(self):
dtype = np.float32
tiles = (1, 2, 5)
N = len(tiles)

def fn():
# Comparison point with the "obvious" way using JAX Numpy.
# return tile_put_sharded(jax.numpy.identity(N, dtype), tiles=tiles)
return tile_sharded_identity(dtype, tiles)

fn_ipu = partial(jax.jit, backend="ipu")(fn)
output_ipu = fn_ipu()
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == dtype
npt.assert_array_equal(output_ipu, np.identity(N, dtype=dtype))
