Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement and integrate cumsum/max/min/prod TessellateIPU vertex.
Browse files Browse the repository at this point in the history
Only supporting 1D input arrays at the moment.
Note: performance is not yet optimal, as using a single thread + no
vectorized instructions.
balancap committed Sep 28, 2023
1 parent 2e9475d commit d5dfd3c
Showing 5 changed files with 205 additions and 7 deletions.
14 changes: 7 additions & 7 deletions docs/operations.md
Original file line number Diff line number Diff line change
@@ -32,19 +32,19 @@
| `ceil` | :white_check_mark: | :white_check_mark: | |
| `clamp` | :white_check_mark: | :x: | |
| `collapse` | :white_check_mark: | :white_check_mark: | |
| `complex` | :x: | n/a | Complex not supported in IPU XLA backend |
| `complex` | :x: | n/a | `complex` dtype not supported in IPU XLA backend |
| `concatenate` | :white_check_mark: | n/a | |
| `conj` | :x: | n/a | Complex not supported in IPU XLA backend |
| `conj` | :x: | n/a | `complex` dtype not supported in IPU XLA backend |
| `conv` | :x: | n/a | |
| `convert_element_type` | :white_check_mark: | :x: | |
| `conv_general_dilated` | :x: | n/a | |
| `conv_transpose` | :x: | n/a | |
| `cos` | :white_check_mark: | :white_check_mark: | |
| `cosh` | :white_check_mark: | :white_check_mark: | |
| `cummax` | :x: | :x: | |
| `cummin` | :x: | :x: | |
| `cumprod` | :x: | :x: | |
| `cumsum` | :x: | :x: | |
| `cummax` | :white_check_mark: | :x: | Only supporting 1D input tensor |
| `cummin` | :white_check_mark: | :x: | Only supporting 1D input tensor |
| `cumprod` | :white_check_mark: | :x: | Only supporting 1D input tensor |
| `cumsum` | :white_check_mark: | :x: | Only supporting 1D input tensor |
| `digamma` | :x: | :x: | |
| `div` | :white_check_mark: | :white_check_mark: | |
| `dot` | :white_check_mark: | n/a | |
@@ -67,7 +67,7 @@
| `gt` | :white_check_mark: | n/a | |
| `igamma` | :x: | :x: | |
| `igammac` | :x: | :x: | |
| `imag` | :x: | :x: | |
| `imag` | :x: | :x: | `complex` dtype not supported in IPU XLA backend |
| `index_in_dim` | :x: | n/a | |
| `index_take` | :x: | n/a | |
| `iota` | :white_check_mark: | n/a | |
72 changes: 72 additions & 0 deletions tessellate_ipu/core/vertex/tile_cumulative_ops_vertex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#include <algorithm>
#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
using namespace poplar;

namespace tl {

enum OpType : int { SUM = 0, MIN = 1, MAX = 2, PROD = 3 };

template <typename T, int OP>
T initial_accumulator_value() {
if constexpr (OP == SUM) {
return T(0);
} else if constexpr (OP == MIN) {
return std::numeric_limits<T>::max();
} else if constexpr (OP == MAX) {
return std::numeric_limits<T>::lowest();
} else {
return T(1);
}
}
template <typename T, int OP>
T cumulative_op(T acc, T rhs) {
if constexpr (OP == SUM) {
return acc + rhs;
} else if constexpr (OP == MIN) {
return std::min(acc, rhs);
} else if constexpr (OP == MAX) {
return std::max(acc, rhs);
} else {
return acc * rhs;
}
}

/**
* @brief Cumulative op vertex.
* Very simple implementation at first, no big optimization!
*/
template <typename T, int OP>
class CumulativeOp : public Vertex {
public:
Input<Vector<T, poplar::VectorLayout::SPAN>> in;
Output<Vector<T, poplar::VectorLayout::ONE_PTR>> out;

bool compute() {
T accumulator = initial_accumulator_value<T, OP>();
const int32_t size = in.size();
for (int32_t idx = 0; idx < size; ++idx) {
accumulator = cumulative_op<T, OP>(accumulator, in[idx]);
out[idx] = accumulator;
}
return true;
}
};

// explicit instantiations
template class CumulativeOp<int, SUM>;
template class CumulativeOp<float, SUM>;

template class CumulativeOp<int, MIN>;
template class CumulativeOp<float, MIN>;

template class CumulativeOp<int, MAX>;
template class CumulativeOp<float, MAX>;

template class CumulativeOp<int, PROD>;
template class CumulativeOp<float, PROD>;

} // namespace tl
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@
scaled_sub_p,
sub_inplace_p,
)
from .tile_lax_cumulative_ops import cummax_p, cummin_p, cumprod_p, cumsum_p
from .tile_lax_dot import IpuConvVertexType
from .tile_lax_gather import gather_p
from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p
78 changes: 78 additions & 0 deletions tessellate_ipu/lax/tile_lax_cumulative_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import os
from typing import Any, Dict, List, Tuple

from jax.core import Primitive, ShapedArray
from jax.lax import cummax_p, cummin_p, cumprod_p, cumsum_p

from tessellate_ipu.core import (
IpuTileMapEquation,
make_ipu_vertex_in_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
register_ipu_tile_primitive,
)

_cumop_primitive_to_opcode: Dict[Primitive, int] = {
cumsum_p: 0,
cummin_p: 1,
cummax_p: 2,
cumprod_p: 3,
}


def get_cumulative_ops_gp_filename() -> str:
return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_cumulative_ops_vertex.cpp")


def ipu_cumop_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU cumulative ops LAX primitive translation rule to IPU vertex.
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 1
assert attributes is not None
inaval = inavals[0]
axis = attributes.get("axis", 0)
reverse = attributes.get("reverse", False)

# Subset of configuration supported.
assert axis == 0
assert not reverse
assert len(inaval.shape) == 1
outaval = inaval
# TessellateIPU custom cumulative vertices.
opcode = _cumop_primitive_to_opcode[p]
vname = make_ipu_vertex_name_templated("tl::CumulativeOp", inaval.dtype, opcode)
inputs_info = [make_ipu_vertex_in_info("in", inaval)]
outputs_info = [make_ipu_vertex_out_info("out", outaval)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
gp_filename=get_cumulative_ops_gp_filename(),
perf_estimate=inaval.size * 6,
)
return ipu_prim_info


# Register JAX LAX cumulative-ops primitive.
register_ipu_tile_primitive(cumsum_p, ipu_cumop_primitive_translation)
register_ipu_tile_primitive(cummax_p, ipu_cumop_primitive_translation)
register_ipu_tile_primitive(cummin_p, ipu_cumop_primitive_translation)
register_ipu_tile_primitive(cumprod_p, ipu_cumop_primitive_translation)
47 changes: 47 additions & 0 deletions tests/lax/test_tile_lax_cumulative_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from tessellate_ipu import tile_map, tile_put_replicated
from tessellate_ipu.lax import cummax_p, cummin_p, cumprod_p, cumsum_p


class IpuTilePrimitivesLaxCumsumTests(chex.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.device = jax.devices("ipu")[0]
self.num_tiles = self.device.num_tiles
# Not very clean, but for better reproducibility.
np.random.seed(123)

@parameterized.parameters(
{"N": 16, "dtype": np.float32, "cumop": cumsum_p},
{"N": 16, "dtype": np.int32, "cumop": cumsum_p},
{"N": 16, "dtype": np.float32, "cumop": cummax_p},
{"N": 16, "dtype": np.int32, "cumop": cummax_p},
{"N": 16, "dtype": np.float32, "cumop": cummin_p},
{"N": 16, "dtype": np.int32, "cumop": cummin_p},
{"N": 16, "dtype": np.float32, "cumop": cumprod_p},
{"N": 16, "dtype": np.int32, "cumop": cumprod_p},
)
def test__tile_map__cumulative_op__jitting__proper_result(self, N, dtype, cumop):
tiles = (0,)
data = (np.random.randn(N)).astype(dtype)

def compute_fn(data):
data = tile_put_replicated(data, tiles)
return tile_map(cumop, data, axis=0, reverse=False)

cpu_compute_fn = partial(jax.jit, backend="cpu")(compute_fn)
ipu_compute_fn = partial(jax.jit, backend="ipu")(compute_fn)

cpu_output = cpu_compute_fn(data)
ipu_output = ipu_compute_fn(data)
assert ipu_output.tiles == tiles
assert ipu_output.dtype == data.dtype
npt.assert_array_almost_equal(ipu_output, cpu_output, decimal=5)

0 comments on commit d5dfd3c

Please sign in to comment.