diff --git a/tessellate_ipu/lax/tile_lax_array.py b/tessellate_ipu/lax/tile_lax_array.py index 5433ca9..6614c6a 100644 --- a/tessellate_ipu/lax/tile_lax_array.py +++ b/tessellate_ipu/lax/tile_lax_array.py @@ -1,10 +1,19 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from typing import Any, Dict, List, Tuple +import jax.lax +import numpy as np from jax.core import Primitive, ShapedArray from jax.lax import bitcast_convert_type_p, reshape_p -from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive +from tessellate_ipu.core import ( + IpuTileMapEquation, + make_ipu_vertex_attributes, + make_ipu_vertex_inout_info, + make_ipu_vertex_name_templated, + register_ipu_tile_primitive, +) +from tessellate_ipu.utils import DTypeLike def ipu_reshape_primitive_translation( @@ -95,3 +104,69 @@ def ipu_bitcast_convert_type_primitive_translation( # Register JAX LAX bitcast_convert_type_p primitive. register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation) + + +fill_p = Primitive("fill") +"""Fill primitive: create an array, and fill it with a constant. +Note: compared to `jax.lax.full`, it guarantees allocation of the full array instead of broadcasting. +""" + + +def fill(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike): + return fill_p.bind(shape=shape, fill_value=fill_value, dtype=dtype) + + +def fill_numpy_impl(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike): + return np.full(shape, fill_value, dtype=dtype) + + +def fill_abstract_eval(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike): + aval = jax.lax.full(shape, fill_value=fill_value, dtype=dtype) + return ShapedArray(aval.shape, dtype=aval.dtype) + + +def ipu_fill_primitive_translation_ipu( + p: Primitive, + tiles: Tuple[int, ...], + inavals: List[ShapedArray], + attributes: Dict[str, Any] = None, +) -> IpuTileMapEquation: + """IPU tile translation for `fill` + + Args: + p: JAX primitive. + tiles: Collection of tiles. + inavals: Input shaped arrays. + attributes: Op attributes. + Returns: + IPU tile map primitive structure. + """ + assert len(inavals) == 0 + assert attributes is not None + shape = attributes["shape"] + fill_value = attributes["fill_value"] + dtype = attributes["dtype"] + + outaval = fill_abstract_eval(shape, fill_value, dtype) + # Translation rule to IPU vertex + vname = make_ipu_vertex_name_templated("popops::Fill", outaval.dtype) + attrs_i32, attrs_f32 = make_ipu_vertex_attributes(**{"in": fill_value}) + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=[], + outputs_info=[make_ipu_vertex_inout_info("out", outaval)], + attributes_i32=attrs_i32, + attributes_f32=attrs_f32, + ) + return ipu_prim_info + + +fill_p.map_primitive = False +# Register the primal implementation with JAX +fill_p.def_impl(fill_numpy_impl) +# Register the abstract evaluation with JAX +fill_p.def_abstract_eval(fill_abstract_eval) +# Register tile IPU translation. +register_ipu_tile_primitive(fill_p, ipu_fill_primitive_translation_ipu)