-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement and integrate cumsum/max/min/prod TessellateIPU vertex.
Only supporting 1D input arrays at the moment. Note: performance is not yet optimal, as using a single thread + no vectorized instructions.
Showing
5 changed files
with
205 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (c) 2022 Graphcore Ltd. All rights reserved. | ||
#include <algorithm> | ||
#include <poplar/HalfFloat.hpp> | ||
#include <poplar/Vertex.hpp> | ||
|
||
#include "intrinsics_utils.hpp" | ||
using namespace poplar; | ||
|
||
namespace tl { | ||
|
||
enum OpType : int { SUM = 0, MIN = 1, MAX = 2, PROD = 3 }; | ||
|
||
template <typename T, int OP> | ||
T initial_accumulator_value() { | ||
if constexpr (OP == SUM) { | ||
return T(0); | ||
} else if constexpr (OP == MIN) { | ||
return std::numeric_limits<T>::max(); | ||
} else if constexpr (OP == MAX) { | ||
return std::numeric_limits<T>::lowest(); | ||
} else { | ||
return T(1); | ||
} | ||
} | ||
template <typename T, int OP> | ||
T cumulative_op(T acc, T rhs) { | ||
if constexpr (OP == SUM) { | ||
return acc + rhs; | ||
} else if constexpr (OP == MIN) { | ||
return std::min(acc, rhs); | ||
} else if constexpr (OP == MAX) { | ||
return std::max(acc, rhs); | ||
} else { | ||
return acc * rhs; | ||
} | ||
} | ||
|
||
/** | ||
* @brief Cumulative op vertex. | ||
* Very simple implementation at first, no big optimization! | ||
*/ | ||
template <typename T, int OP> | ||
class CumulativeOp : public Vertex { | ||
public: | ||
Input<Vector<T, poplar::VectorLayout::SPAN>> in; | ||
Output<Vector<T, poplar::VectorLayout::ONE_PTR>> out; | ||
|
||
bool compute() { | ||
T accumulator = initial_accumulator_value<T, OP>(); | ||
const int32_t size = in.size(); | ||
for (int32_t idx = 0; idx < size; ++idx) { | ||
accumulator = cumulative_op<T, OP>(accumulator, in[idx]); | ||
out[idx] = accumulator; | ||
} | ||
return true; | ||
} | ||
}; | ||
|
||
// explicit instantiations | ||
template class CumulativeOp<int, SUM>; | ||
template class CumulativeOp<float, SUM>; | ||
|
||
template class CumulativeOp<int, MIN>; | ||
template class CumulativeOp<float, MIN>; | ||
|
||
template class CumulativeOp<int, MAX>; | ||
template class CumulativeOp<float, MAX>; | ||
|
||
template class CumulativeOp<int, PROD>; | ||
template class CumulativeOp<float, PROD>; | ||
|
||
} // namespace tl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import os | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from jax.core import Primitive, ShapedArray | ||
from jax.lax import cummax_p, cummin_p, cumprod_p, cumsum_p | ||
|
||
from tessellate_ipu.core import ( | ||
IpuTileMapEquation, | ||
make_ipu_vertex_in_info, | ||
make_ipu_vertex_name_templated, | ||
make_ipu_vertex_out_info, | ||
register_ipu_tile_primitive, | ||
) | ||
|
||
_cumop_primitive_to_opcode: Dict[Primitive, int] = { | ||
cumsum_p: 0, | ||
cummin_p: 1, | ||
cummax_p: 2, | ||
cumprod_p: 3, | ||
} | ||
|
||
|
||
def get_cumulative_ops_gp_filename() -> str: | ||
return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_cumulative_ops_vertex.cpp") | ||
|
||
|
||
def ipu_cumop_primitive_translation( | ||
p: Primitive, | ||
tiles: Tuple[int, ...], | ||
inavals: List[ShapedArray], | ||
attributes: Dict[str, Any] = None, | ||
) -> IpuTileMapEquation: | ||
"""IPU cumulative ops LAX primitive translation rule to IPU vertex. | ||
Args: | ||
p: JAX primitive. | ||
tiles: Collection of tiles. | ||
inavals: Input shaped arrays. | ||
attributes: (unused) attributes. | ||
Returns: | ||
IPU tile map primitive structure. | ||
""" | ||
assert len(inavals) == 1 | ||
assert attributes is not None | ||
inaval = inavals[0] | ||
axis = attributes.get("axis", 0) | ||
reverse = attributes.get("reverse", False) | ||
|
||
# Subset of configuration supported. | ||
assert axis == 0 | ||
assert not reverse | ||
assert len(inaval.shape) == 1 | ||
outaval = inaval | ||
# TessellateIPU custom cumulative vertices. | ||
opcode = _cumop_primitive_to_opcode[p] | ||
vname = make_ipu_vertex_name_templated("tl::CumulativeOp", inaval.dtype, opcode) | ||
inputs_info = [make_ipu_vertex_in_info("in", inaval)] | ||
outputs_info = [make_ipu_vertex_out_info("out", outaval)] | ||
ipu_prim_info = IpuTileMapEquation( | ||
vname=vname, | ||
pname=p.name, | ||
tiles=tiles, | ||
inputs_info=inputs_info, | ||
outputs_info=outputs_info, | ||
attributes_i32=[], | ||
attributes_f32=[], | ||
gp_filename=get_cumulative_ops_gp_filename(), | ||
perf_estimate=inaval.size * 6, | ||
) | ||
return ipu_prim_info | ||
|
||
|
||
# Register JAX LAX cumulative-ops primitive. | ||
register_ipu_tile_primitive(cumsum_p, ipu_cumop_primitive_translation) | ||
register_ipu_tile_primitive(cummax_p, ipu_cumop_primitive_translation) | ||
register_ipu_tile_primitive(cummin_p, ipu_cumop_primitive_translation) | ||
register_ipu_tile_primitive(cumprod_p, ipu_cumop_primitive_translation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import chex | ||
import jax | ||
import numpy as np | ||
import numpy.testing as npt | ||
from absl.testing import parameterized | ||
|
||
from tessellate_ipu import tile_map, tile_put_replicated | ||
from tessellate_ipu.lax import cummax_p, cummin_p, cumprod_p, cumsum_p | ||
|
||
|
||
class IpuTilePrimitivesLaxCumsumTests(chex.TestCase, parameterized.TestCase): | ||
def setUp(self): | ||
super().setUp() | ||
self.device = jax.devices("ipu")[0] | ||
self.num_tiles = self.device.num_tiles | ||
# Not very clean, but for better reproducibility. | ||
np.random.seed(123) | ||
|
||
@parameterized.parameters( | ||
{"N": 16, "dtype": np.float32, "cumop": cumsum_p}, | ||
{"N": 16, "dtype": np.int32, "cumop": cumsum_p}, | ||
{"N": 16, "dtype": np.float32, "cumop": cummax_p}, | ||
{"N": 16, "dtype": np.int32, "cumop": cummax_p}, | ||
{"N": 16, "dtype": np.float32, "cumop": cummin_p}, | ||
{"N": 16, "dtype": np.int32, "cumop": cummin_p}, | ||
{"N": 16, "dtype": np.float32, "cumop": cumprod_p}, | ||
{"N": 16, "dtype": np.int32, "cumop": cumprod_p}, | ||
) | ||
def test__tile_map__cumulative_op__jitting__proper_result(self, N, dtype, cumop): | ||
tiles = (0,) | ||
data = (np.random.randn(N)).astype(dtype) | ||
|
||
def compute_fn(data): | ||
data = tile_put_replicated(data, tiles) | ||
return tile_map(cumop, data, axis=0, reverse=False) | ||
|
||
cpu_compute_fn = partial(jax.jit, backend="cpu")(compute_fn) | ||
ipu_compute_fn = partial(jax.jit, backend="ipu")(compute_fn) | ||
|
||
cpu_output = cpu_compute_fn(data) | ||
ipu_output = ipu_compute_fn(data) | ||
assert ipu_output.tiles == tiles | ||
assert ipu_output.dtype == data.dtype | ||
npt.assert_array_almost_equal(ipu_output, cpu_output, decimal=5) |