Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Sep 22, 2023
1 parent 339e6e9 commit 8fe8b53
Showing 1 changed file with 97 additions and 1 deletion.
98 changes: 97 additions & 1 deletion tessellate_ipu/lax/tile_lax_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
scatter_max_p,
scatter_min_p,
scatter_mul_p,
scatter_p
)

from tessellate_ipu.core import (
Expand All @@ -23,7 +24,102 @@
from tessellate_ipu.utils import DTypeLike


def make_gather_vertex_fullname(dtype: DTypeLike) -> str:
def make_scatter_vertex_fullname(dtype: DTypeLike) -> str:
"""Generate popops Gather/MultiSlice vertex name."""
basename = "popops::MultiSlice"
return make_ipu_vertex_name_templated(basename, dtype)


def check_scatter_dimension_numbers(dimension_numbers: ScatterDimensionNumbers):
"""Check `gather` dimension_numbers is supported on TessellateIPU.
At the moment: basically only supporting a single configuration!
We need to expand on this at some point!
"""
dim_numbers_default = ScatterDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
if dimension_numbers != dim_numbers_default:
raise NotImplementedError(f"TessellateIPU `gather` only support dimension numbers: {dim_numbers_default}.")


def ipu_scatter_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `gather` primitive translation rule to IPU vertex.
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input data + start indices arrays.
attributes: Gather operator attributes
Returns:
IPU tile map primitive structure.
"""
# TODO: query for JAX device.
num_context_workers = 6

assert len(inavals) == 2
assert attributes is not None
operand, start_indices = inavals
# Extract gather attributes
dimension_numbers = attributes["dimension_numbers"]
slice_sizes = attributes["slice_sizes"]
# Default values from JAX LAX interface.
indices_are_sorted = attributes.get("indices_are_sorted", False)
unique_indices = attributes.get("unique_indices", False)
mode = attributes.get("mode", GatherScatterMode.PROMISE_IN_BOUNDS)
fill_value = attributes.get("fill_value", None)

# Check gather attributes are supported by TessellateIPU.
assert operand.ndim == 1
assert start_indices.ndim == 2
assert slice_sizes == (1,)
assert (
mode == GatherScatterMode.PROMISE_IN_BOUNDS
), "Only `PROMISE_IN_BOUNDS` gather mode supported in TessellateIPU."
assert start_indices.dtype == np.uint32, "TessellateIPU `gather` only supports `uint32` indices."
check_gather_dimension_numbers(dimension_numbers)
# Gather output aval.
outaval = p.abstract_eval(
*inavals,
dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
fill_value=fill_value,
)[0]

vname = make_gather_vertex_fullname(operand.dtype)
# Construct poplibs MultiSlice vertex attributes.
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(
baseOffset=0, # unused?
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)),
regionSize=1, # TODO: understand?
splitSingleRegion=True, # Split regions between threads?
)
# TODO: should we use `split offsets` between threads?
# For now: need to do it manually at the Python `tile_map` level.
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=[make_ipu_vertex_in_info("baseT", operand), make_ipu_vertex_in_info("offsets", start_indices)],
outputs_info=[make_ipu_vertex_out_info("subT", outaval)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
return ipu_prim_info


# Register JAX gather primitive.
register_ipu_tile_primitive(scatter_add_p, ipu_scatter_primitive_translation)
register_ipu_tile_primitive(scatter_mul_p, ipu_scatter_primitive_translation)
register_ipu_tile_primitive(scatter_min_p, ipu_scatter_primitive_translation)
register_ipu_tile_primitive(scatter_max_p, ipu_scatter_primitive_translation)

0 comments on commit 8fe8b53

Please sign in to comment.