diff --git a/docs/operations.md b/docs/operations.md
index 278a131..9647eef 100644
--- a/docs/operations.md
+++ b/docs/operations.md
@@ -20,7 +20,7 @@
 | `bessel_i0e` | :x: | :x: | |
 | `bessel_i1e` | :x: | :x: | |
 | `betainc` | :x: | :x: | |
-| `bitcase_convert_type` | :white_check_mark: | :question: | |
+| `bitcast_convert_type` | :white_check_mark: | :white_check_mark: | Only same-size dtypes supported. |
 | `bitwise_not` | :white_check_mark: | :x: | |
 | `bitwise_and` | :white_check_mark: | :x: | |
 | `bitwise_or` | :white_check_mark: | :x: | |
@@ -90,7 +90,7 @@
 | `real` | :x: | :x: | |
 | `reciprocal` | :white_check_mark: | :x: | |
 | `reduce` | :white_check_mark: | :x: | |
-| `reshape` | :x: | :x: | |
+| `reshape` | :white_check_mark: | :white_check_mark: | `dimensions` argument not supported. |
 | `rem` | :white_check_mark: | :white_check_mark: | |
 | `rev` | :x: | :x: | |
 | `round` | :white_check_mark: | :white_check_mark: | |
@@ -113,7 +113,7 @@
 | `sort_key_val` | :x: | :x: | |
 | `sqrt` | :white_check_mark: | :white_check_mark: | |
 | `square` | :white_check_mark: | :x: | |
-| `squeeze` | :white_check_mark: | :x: | |
+| `squeeze` | :white_check_mark: | :white_check_mark: | |
 | `sub` | :white_check_mark: | :white_check_mark: | |
 | `tan` | :white_check_mark: | :white_check_mark: | |
 | `tie_in` | :x: | :x: | Deprecated in JAX |
diff --git a/tessellate_ipu/lax/__init__.py b/tessellate_ipu/lax/__init__.py
index ce4d382..fd76019 100644
--- a/tessellate_ipu/lax/__init__.py
+++ b/tessellate_ipu/lax/__init__.py
@@ -25,6 +25,7 @@
 )
 from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
+from .tile_lax_array import bitcast_convert_type_p, reshape_p
 from .tile_lax_binary import (
     add_inplace_p,
     atan2_inplace_p,
diff --git a/tessellate_ipu/lax/tile_lax_array.py b/tessellate_ipu/lax/tile_lax_array.py
new file mode 100644
index 0000000..5433ca9
--- /dev/null
+++ b/tessellate_ipu/lax/tile_lax_array.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from typing import Any, Dict, List, Optional, Tuple
+
+from jax.core import Primitive, ShapedArray
+from jax.lax import bitcast_convert_type_p, reshape_p
+
+from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive
+
+
+def ipu_reshape_primitive_translation(
+    p: Primitive,
+    tiles: Tuple[int, ...],
+    inavals: List[ShapedArray],
+    attributes: Optional[Dict[str, Any]] = None,
+) -> IpuTileMapEquation:
+    """IPU `reshape` LAX primitive translation rule to IPU vertex.
+
+    Args:
+        p: JAX primitive.
+        tiles: Collection of tiles.
+        inavals: Input shaped arrays.
+        attributes: Primitive attributes (`new_sizes` and `dimensions`).
+    Returns:
+        IPU tile map primitive structure.
+    """
+    assert len(inavals) == 1
+    assert attributes is not None
+    inaval = inavals[0]
+    new_sizes = attributes["new_sizes"]
+    dimensions = attributes.get("dimensions", None)
+    if dimensions is not None:
+        raise NotImplementedError("TessellateIPU `reshape` does not support a custom `dimensions` argument.")
+
+    outaval = ShapedArray(new_sizes, dtype=inaval.dtype, weak_type=inaval.weak_type)
+    # Empty vertex name trick => identity function with inout argument, just doing reshaping.
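+    # (The empty vertex name is handled in `TileMapEquation::add` on the C++ side:
+    # the input tensor is forwarded unchanged, with only a view-level reshape.)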
+ vname = "" + inputs_info = [ + make_ipu_vertex_inout_info("x", inaval), + ] + outputs_info = [make_ipu_vertex_inout_info("x", outaval)] + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=inputs_info, + outputs_info=outputs_info, + attributes_i32=[], + attributes_f32=[], + ) + return ipu_prim_info + + +# Register JAX LAX reshape primitive. +register_ipu_tile_primitive(reshape_p, ipu_reshape_primitive_translation) + + +def ipu_bitcast_convert_type_primitive_translation( + p: Primitive, + tiles: Tuple[int, ...], + inavals: List[ShapedArray], + attributes: Dict[str, Any] = None, +) -> IpuTileMapEquation: + """IPU `bitcast_convert_type` LAX primitive translation rule to IPU vertex. + + Args: + p: JAX primitive. + tiles: Collection of tiles. + inavals: Input shaped arrays. + attributes: (unused) attributes. + Returns: + IPU tile map primitive structure. + """ + assert len(inavals) == 1 + assert attributes is not None + inaval = inavals[0] + new_dtype = attributes["new_dtype"] + outaval = ShapedArray(inaval.shape, dtype=new_dtype, weak_type=inaval.dtype) + # Empty vertex name trick => identity function with inout argument, just doing reshaping. + vname = "" + inputs_info = [ + make_ipu_vertex_inout_info("x", inaval), + ] + outputs_info = [make_ipu_vertex_inout_info("x", outaval)] + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=inputs_info, + outputs_info=outputs_info, + attributes_i32=[], + attributes_f32=[], + ) + return ipu_prim_info + + +# Register JAX LAX bitcast_convert_type_p primitive. +register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation) diff --git a/tessellate_ipu/lib/tile_map_ops.cpp b/tessellate_ipu/lib/tile_map_ops.cpp index 6a0cc22..920d0c7 100644 --- a/tessellate_ipu/lib/tile_map_ops.cpp +++ b/tessellate_ipu/lib/tile_map_ops.cpp @@ -125,10 +125,29 @@ std::vector TileMapEquation::add( const poplar::DebugContext& debug_prefix) const { // All input tensors: i.e. add constant tensors. const auto inputs_all = this->allocateInputTensors(graph, inputs); - // No vertex => assume identity function, i.e. forward inputs. + + // No vertex => assume identity function. + // Forwarding inputs, with just potential change of shape and dtype. if (this->vname.empty()) { - return inputs_all; + // Check inputs/outputs consistent. + if (this->numInputs() != this->numOutputs()) { + throw std::logic_error( + "Inconsistent number of inputs/outputs for an identity function."); + } + // Generate output tensors (potential reshaping + change of dtype). + std::vector outputs_all; + outputs_all.reserve(inputs_all.size()); + for (size_t idx = 0; idx < inputs_all.size(); ++idx) { + const auto& in = inputs_all[idx]; + const auto& outinfo = outputs_info[idx]; + const auto outshape = shapePrependAxis(tiles.size(), outinfo.aval.shape); + const auto outdtype = toPoplar(outinfo.aval.dtype); + auto out = in.reshape(outshape).reinterpret(outdtype); + outputs_all.push_back(out); + } + return outputs_all; } + // Usual path => map a vertex. const auto outputs = this->allocateOutputTensors(graph, inputs); this->add(graph, prog, inputs_all, outputs, debug_prefix); return outputs; @@ -147,6 +166,7 @@ std::size_t TileMapEquation::numInOuts() const { throw std::logic_error( "Inconsistent number of in/outs in the IPU tile map equation."); } + // TODO: add checking on tensor size (not necessarily shape). 
   return num_inouts0;
 }

diff --git a/tests/lax/test_tile_lax_array.py b/tests/lax/test_tile_lax_array.py
new file mode 100644
index 0000000..0cdb0ac
--- /dev/null
+++ b/tests/lax/test_tile_lax_array.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
+from functools import partial
+
+import chex
+import jax
+import numpy as np
+import numpy.testing as npt
+
+from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
+from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p
+
+
+class IpuTileArrayPrimitiveTests(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        np.random.seed(42)
+
+    def test__tile_map__reshape__ipu_jitting__proper_result(self):
+        tiles = (3, 4, 5)
+        dtype = np.float32
+        inshape = (len(tiles), 6, 4)
+        indata = np.random.randn(*inshape).astype(dtype)
+
+        def compute_fn(input):
+            input = tile_put_sharded(input, tiles)
+            return tile_map(reshape_p, input, new_sizes=(3, 8), dimensions=None)
+
+        compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
+        compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)
+
+        output_cpu = compute_fn_cpu(indata)
+        output_ipu = compute_fn_ipu(indata)
+        assert isinstance(output_ipu, TileShardedArray)
+        assert output_ipu.tiles == tiles
+        assert output_ipu.dtype == indata.dtype
+        npt.assert_array_equal(output_ipu, output_cpu)
+
+    def test__tile_map__bitcast_convert_type__ipu_jitting__proper_result(self):
+        tiles = (3, 4, 5)
+        dtype = np.float32
+        inshape = (len(tiles), 6, 4)
+        indata = np.random.randn(*inshape).astype(dtype)
+
+        def compute_fn(input):
+            input = tile_put_sharded(input, tiles)
+            return tile_map(bitcast_convert_type_p, input, new_dtype=np.int32)
+
+        compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
+        compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)
+
+        output_cpu = compute_fn_cpu(indata)
+        output_ipu = compute_fn_ipu(indata)
+        assert isinstance(output_ipu, TileShardedArray)
+        assert output_ipu.tiles == tiles
+        assert output_ipu.dtype == np.int32
+        npt.assert_array_equal(output_ipu, output_cpu)
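For reference, a minimal usage sketch chaining the two new primitives, following the `tile_put_sharded`/`tile_map` pattern of the tests above. Tile ids and per-tile shapes are illustrative, and an IPU backend is assumed to be available for `jax.jit`:

```python
import jax
import numpy as np

from tessellate_ipu import tile_map, tile_put_sharded
from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p

tiles = (0, 1)  # illustrative tile ids


def compute_fn(x):
    # Shard axis 0 of `x` across the given tiles.
    x = tile_put_sharded(x, tiles)
    # Per-tile reshape (8,) -> (4, 2); `dimensions` must be None.
    x = tile_map(reshape_p, x, new_sizes=(4, 2), dimensions=None)
    # Same-size dtype bitcast: float32 -> int32.
    return tile_map(bitcast_convert_type_p, x, new_dtype=np.int32)


indata = np.random.randn(len(tiles), 8).astype(np.float32)
output = jax.jit(compute_fn, backend="ipu")(indata)
```

As the `docs/operations.md` entries note, `reshape` rejects a custom `dimensions` argument and `bitcast_convert_type` only supports same-size dtypes; both therefore compile to the empty-vertex identity path, i.e. a view-level reshape/reinterpret with no data movement.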