Skip to content

Commit

Permalink
Add fill TessellateIPU primitive mapping popops vertex.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Sep 28, 2023
1 parent 1cfe2ea commit 91a5547
Showing 1 changed file with 76 additions and 1 deletion.
77 changes: 76 additions & 1 deletion tessellate_ipu/lax/tile_lax_array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Tuple

import jax.lax
import numpy as np
from jax.core import Primitive, ShapedArray
from jax.lax import bitcast_convert_type_p, reshape_p

from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive
from tessellate_ipu.core import (
IpuTileMapEquation,
make_ipu_vertex_attributes,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
register_ipu_tile_primitive,
)
from tessellate_ipu.utils import DTypeLike


def ipu_reshape_primitive_translation(
Expand Down Expand Up @@ -95,3 +104,69 @@ def ipu_bitcast_convert_type_primitive_translation(

# Register JAX LAX bitcast_convert_type_p primitive.
register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation)


fill_p = Primitive("fill")
"""Fill primitive: create an array, and fill it with a constant.
Note: compared to `jax.lax.full`, it guarantees allocation of the full array instead of broadcasting.
"""


def fill(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
return fill_p.bind(shape=shape, fill_value=fill_value, dtype=dtype)


def fill_numpy_impl(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
return np.full(shape, fill_value, dtype=dtype)


def fill_abstract_eval(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
aval = jax.lax.full(shape, fill_value=fill_value, dtype=dtype)
return ShapedArray(aval.shape, dtype=aval.dtype)


def ipu_fill_primitive_translation_ipu(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU tile translation for `fill`
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: Op attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 0
assert attributes is not None
shape = attributes["shape"]
fill_value = attributes["fill_value"]
dtype = attributes["dtype"]

outaval = fill_abstract_eval(shape, fill_value, dtype)
# Translation rule to IPU vertex
vname = make_ipu_vertex_name_templated("popops::Fill", outaval.dtype)
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(**{"in": fill_value})
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=[],
outputs_info=[make_ipu_vertex_inout_info("out", outaval)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
return ipu_prim_info


fill_p.map_primitive = False
# Register the primal implementation with JAX
fill_p.def_impl(fill_numpy_impl)
# Register the abstract evaluation with JAX
fill_p.def_abstract_eval(fill_abstract_eval)
# Register tile IPU translation.
register_ipu_tile_primitive(fill_p, ipu_fill_primitive_translation_ipu)

0 comments on commit 91a5547

Please sign in to comment.