From d5dfd3cc58d891bdecfa8221238dbe4b565f1551 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 28 Sep 2023 14:46:44 +0100 Subject: [PATCH] Implement and integrate cumsum/max/min/prod TessellateIPU vertex. Only supporting 1D input arrays at the moment. Note: performance is not yet optimal, as using a single thread + no vectorized instructions. --- docs/operations.md | 14 ++-- .../vertex/tile_cumulative_ops_vertex.cpp | 72 +++++++++++++++++ tessellate_ipu/lax/__init__.py | 1 + tessellate_ipu/lax/tile_lax_cumulative_ops.py | 78 +++++++++++++++++++ tests/lax/test_tile_lax_cumulative_ops.py | 47 +++++++++++ 5 files changed, 205 insertions(+), 7 deletions(-) create mode 100644 tessellate_ipu/core/vertex/tile_cumulative_ops_vertex.cpp create mode 100644 tessellate_ipu/lax/tile_lax_cumulative_ops.py create mode 100644 tests/lax/test_tile_lax_cumulative_ops.py diff --git a/docs/operations.md b/docs/operations.md index 3b040e4..90ac0f7 100644 --- a/docs/operations.md +++ b/docs/operations.md @@ -32,19 +32,19 @@ | `ceil` | :white_check_mark: | :white_check_mark: | | | `clamp` | :white_check_mark: | :x: | | | `collapse` | :white_check_mark: | :white_check_mark: | | -| `complex` | :x: | n/a | Complex not supported in IPU XLA backend | +| `complex` | :x: | n/a | `complex` dtype not supported in IPU XLA backend | | `concatenate` | :white_check_mark: | n/a | | -| `conj` | :x: | n/a | Complex not supported in IPU XLA backend | +| `conj` | :x: | n/a | `complex` dtype not supported in IPU XLA backend | | `conv` | :x: | n/a | | | `convert_element_type` | :white_check_mark: | :x: | | | `conv_general_dilated` | :x: | n/a | | | `conv_transpose` | :x: | n/a | | | `cos` | :white_check_mark: | :white_check_mark: | | | `cosh` | :white_check_mark: | :white_check_mark: | | -| `cummax` | :x: | :x: | | -| `cummin` | :x: | :x: | | -| `cumprod` | :x: | :x: | | -| `cumsum` | :x: | :x: | | +| `cummax` | :white_check_mark: | :x: | Only supporting 1D input tensor | +| `cummin` | :white_check_mark: | :x: | Only supporting 1D input tensor | +| `cumprod` | :white_check_mark: | :x: | Only supporting 1D input tensor | +| `cumsum` | :white_check_mark: | :x: | Only supporting 1D input tensor | | `digamma` | :x: | :x: | | | `div` | :white_check_mark: | :white_check_mark: | | | `dot` | :white_check_mark: | n/a | | @@ -67,7 +67,7 @@ | `gt` | :white_check_mark: | n/a | | | `igamma` | :x: | :x: | | | `igammac` | :x: | :x: | | -| `imag` | :x: | :x: | | +| `imag` | :x: | :x: | `complex` dtype not supported in IPU XLA backend | | `index_in_dim` | :x: | n/a | | | `index_take` | :x: | n/a | | | `iota` | :white_check_mark: | n/a | | diff --git a/tessellate_ipu/core/vertex/tile_cumulative_ops_vertex.cpp b/tessellate_ipu/core/vertex/tile_cumulative_ops_vertex.cpp new file mode 100644 index 0000000..9bac724 --- /dev/null +++ b/tessellate_ipu/core/vertex/tile_cumulative_ops_vertex.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
+#include <algorithm>
+#include <limits>
+#include <poplar/Vertex.hpp>
+
+#include "intrinsics_utils.hpp"
+using namespace poplar;
+
+namespace tl {
+
+enum OpType : int { SUM = 0, MIN = 1, MAX = 2, PROD = 3 };
+
+template <typename T, OpType OP>
+T initial_accumulator_value() {
+  if constexpr (OP == SUM) {
+    return T(0);
+  } else if constexpr (OP == MIN) {
+    return std::numeric_limits<T>::max();
+  } else if constexpr (OP == MAX) {
+    return std::numeric_limits<T>::lowest();
+  } else {
+    return T(1);
+  }
+}
+template <typename T, OpType OP>
+T cumulative_op(T acc, T rhs) {
+  if constexpr (OP == SUM) {
+    return acc + rhs;
+  } else if constexpr (OP == MIN) {
+    return std::min(acc, rhs);
+  } else if constexpr (OP == MAX) {
+    return std::max(acc, rhs);
+  } else {
+    return acc * rhs;
+  }
+}
+
+/**
+ * @brief Cumulative op vertex.
+ * Simple first implementation: single thread, no vectorized instructions.
+ */
+template <typename T, OpType OP>
+class CumulativeOp : public Vertex {
+ public:
+  Input<Vector<T>> in;
+  Output<Vector<T>> out;
+
+  bool compute() {
+    T accumulator = initial_accumulator_value<T, OP>();
+    const int32_t size = in.size();
+    for (int32_t idx = 0; idx < size; ++idx) {
+      accumulator = cumulative_op<T, OP>(accumulator, in[idx]);
+      out[idx] = accumulator;
+    }
+    return true;
+  }
+};
+
+// Explicit instantiations for the supported (dtype, op) combinations.
+template class CumulativeOp<float, OpType::SUM>;
+template class CumulativeOp<int, OpType::SUM>;
+
+template class CumulativeOp<float, OpType::MIN>;
+template class CumulativeOp<int, OpType::MIN>;
+
+template class CumulativeOp<float, OpType::MAX>;
+template class CumulativeOp<int, OpType::MAX>;
+
+template class CumulativeOp<float, OpType::PROD>;
+template class CumulativeOp<int, OpType::PROD>;
+
+}  // namespace tl
diff --git a/tessellate_ipu/lax/__init__.py b/tessellate_ipu/lax/__init__.py
index fd76019..58dbe27 100644
--- a/tessellate_ipu/lax/__init__.py
+++ b/tessellate_ipu/lax/__init__.py
@@ -39,6 +39,7 @@
     scaled_sub_p,
     sub_inplace_p,
 )
+from .tile_lax_cumulative_ops import cummax_p, cummin_p, cumprod_p, cumsum_p
 from .tile_lax_dot import IpuConvVertexType
 from .tile_lax_gather import gather_p
 from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p
diff --git a/tessellate_ipu/lax/tile_lax_cumulative_ops.py b/tessellate_ipu/lax/tile_lax_cumulative_ops.py
new file mode 100644
index 0000000..1bdd233
--- /dev/null
+++ b/tessellate_ipu/lax/tile_lax_cumulative_ops.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import os
+from typing import Any, Dict, List, Tuple
+
+from jax.core import Primitive, ShapedArray
+from jax.lax import cummax_p, cummin_p, cumprod_p, cumsum_p
+
+from tessellate_ipu.core import (
+    IpuTileMapEquation,
+    make_ipu_vertex_in_info,
+    make_ipu_vertex_name_templated,
+    make_ipu_vertex_out_info,
+    register_ipu_tile_primitive,
+)
+
+_cumop_primitive_to_opcode: Dict[Primitive, int] = {
+    cumsum_p: 0,
+    cummin_p: 1,
+    cummax_p: 2,
+    cumprod_p: 3,
+}
+
+
+def get_cumulative_ops_gp_filename() -> str:
+    return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_cumulative_ops_vertex.cpp")
+
+
+def ipu_cumop_primitive_translation(
+    p: Primitive,
+    tiles: Tuple[int, ...],
+    inavals: List[ShapedArray],
+    attributes: Dict[str, Any] = None,
+) -> IpuTileMapEquation:
+    """IPU cumulative ops LAX primitive translation rule to IPU vertex.
+
+    Args:
+        p: JAX primitive.
+        tiles: Collection of tiles.
+        inavals: Input shaped arrays.
+        attributes: Op attributes (`axis` and `reverse`).
+    Returns:
+        IPU tile map primitive structure.
+    """
+    assert len(inavals) == 1
+    assert attributes is not None
+    inaval = inavals[0]
+    axis = attributes.get("axis", 0)
+    reverse = attributes.get("reverse", False)
+
+    # Subset of configuration supported.
+ assert axis == 0 + assert not reverse + assert len(inaval.shape) == 1 + outaval = inaval + # TessellateIPU custom cumulative vertices. + opcode = _cumop_primitive_to_opcode[p] + vname = make_ipu_vertex_name_templated("tl::CumulativeOp", inaval.dtype, opcode) + inputs_info = [make_ipu_vertex_in_info("in", inaval)] + outputs_info = [make_ipu_vertex_out_info("out", outaval)] + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=inputs_info, + outputs_info=outputs_info, + attributes_i32=[], + attributes_f32=[], + gp_filename=get_cumulative_ops_gp_filename(), + perf_estimate=inaval.size * 6, + ) + return ipu_prim_info + + +# Register JAX LAX cumulative-ops primitive. +register_ipu_tile_primitive(cumsum_p, ipu_cumop_primitive_translation) +register_ipu_tile_primitive(cummax_p, ipu_cumop_primitive_translation) +register_ipu_tile_primitive(cummin_p, ipu_cumop_primitive_translation) +register_ipu_tile_primitive(cumprod_p, ipu_cumop_primitive_translation) diff --git a/tests/lax/test_tile_lax_cumulative_ops.py b/tests/lax/test_tile_lax_cumulative_ops.py new file mode 100644 index 0000000..649da3d --- /dev/null +++ b/tests/lax/test_tile_lax_cumulative_ops.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import partial + +import chex +import jax +import numpy as np +import numpy.testing as npt +from absl.testing import parameterized + +from tessellate_ipu import tile_map, tile_put_replicated +from tessellate_ipu.lax import cummax_p, cummin_p, cumprod_p, cumsum_p + + +class IpuTilePrimitivesLaxCumsumTests(chex.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.device = jax.devices("ipu")[0] + self.num_tiles = self.device.num_tiles + # Not very clean, but for better reproducibility. + np.random.seed(123) + + @parameterized.parameters( + {"N": 16, "dtype": np.float32, "cumop": cumsum_p}, + {"N": 16, "dtype": np.int32, "cumop": cumsum_p}, + {"N": 16, "dtype": np.float32, "cumop": cummax_p}, + {"N": 16, "dtype": np.int32, "cumop": cummax_p}, + {"N": 16, "dtype": np.float32, "cumop": cummin_p}, + {"N": 16, "dtype": np.int32, "cumop": cummin_p}, + {"N": 16, "dtype": np.float32, "cumop": cumprod_p}, + {"N": 16, "dtype": np.int32, "cumop": cumprod_p}, + ) + def test__tile_map__cumulative_op__jitting__proper_result(self, N, dtype, cumop): + tiles = (0,) + data = (np.random.randn(N)).astype(dtype) + + def compute_fn(data): + data = tile_put_replicated(data, tiles) + return tile_map(cumop, data, axis=0, reverse=False) + + cpu_compute_fn = partial(jax.jit, backend="cpu")(compute_fn) + ipu_compute_fn = partial(jax.jit, backend="ipu")(compute_fn) + + cpu_output = cpu_compute_fn(data) + ipu_output = ipu_compute_fn(data) + assert ipu_output.tiles == tiles + assert ipu_output.dtype == data.dtype + npt.assert_array_almost_equal(ipu_output, cpu_output, decimal=5)
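
Usage note: the new primitives are dispatched through the existing TessellateIPU `tile_map` API, exactly as in the test above. Below is a minimal sketch (function and variable names are illustrative; it assumes an IPU device or IPU model is attached, and only `axis=0`, `reverse=False` on 1D arrays is supported for now):

    from functools import partial

    import jax
    import numpy as np

    from tessellate_ipu import tile_map, tile_put_replicated
    from tessellate_ipu.lax import cumsum_p

    tiles = (0,)
    data = np.arange(1, 9, dtype=np.float32)  # [1, 2, ..., 8]

    @partial(jax.jit, backend="ipu")
    def cumsum_on_tile(x):
        # Replicate the 1D array on the chosen tile(s), then map the cumulative-sum vertex.
        x = tile_put_replicated(x, tiles)
        return tile_map(cumsum_p, x, axis=0, reverse=False)

    out = cumsum_on_tile(data)
    # Tile-sharded output (leading tile axis); per-tile result: [1, 3, 6, 10, 15, 21, 28, 36].

Note that the opcode integers in `_cumop_primitive_to_opcode` must stay in sync with the `OpType` enum in the C++ vertex, since `make_ipu_vertex_name_templated("tl::CumulativeOp", inaval.dtype, opcode)` builds the vertex name from them.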