Implement and integrate cumsum/max/min/prod TessellateIPU vertex. #37

Merged · 1 commit · Sep 28, 2023
14 changes: 7 additions & 7 deletions docs/operations.md
@@ -32,19 +32,19 @@
| `ceil` | :white_check_mark: | :white_check_mark: | |
| `clamp` | :white_check_mark: | :x: | |
| `collapse` | :white_check_mark: | :white_check_mark: | |
- | `complex` | :x: | n/a | Complex not supported in IPU XLA backend |
+ | `complex` | :x: | n/a | `complex` dtype not supported in IPU XLA backend |
| `concatenate` | :white_check_mark: | n/a | |
- | `conj` | :x: | n/a | Complex not supported in IPU XLA backend |
+ | `conj` | :x: | n/a | `complex` dtype not supported in IPU XLA backend |
| `conv` | :x: | n/a | |
| `convert_element_type` | :white_check_mark: | :x: | |
| `conv_general_dilated` | :x: | n/a | |
| `conv_transpose` | :x: | n/a | |
| `cos` | :white_check_mark: | :white_check_mark: | |
| `cosh` | :white_check_mark: | :white_check_mark: | |
- | `cummax` | :x: | :x: | |
- | `cummin` | :x: | :x: | |
- | `cumprod` | :x: | :x: | |
- | `cumsum` | :x: | :x: | |
+ | `cummax` | :white_check_mark: | :x: | Only 1D input tensors supported |
+ | `cummin` | :white_check_mark: | :x: | Only 1D input tensors supported |
+ | `cumprod` | :white_check_mark: | :x: | Only 1D input tensors supported |
+ | `cumsum` | :white_check_mark: | :x: | Only 1D input tensors supported |
| `digamma` | :x: | :x: | |
| `div` | :white_check_mark: | :white_check_mark: | |
| `dot` | :white_check_mark: | n/a | |
@@ -67,7 +67,7 @@
| `gt` | :white_check_mark: | n/a | |
| `igamma` | :x: | :x: | |
| `igammac` | :x: | :x: | |
- | `imag` | :x: | :x: | |
+ | `imag` | :x: | :x: | `complex` dtype not supported in IPU XLA backend |
| `index_in_dim` | :x: | n/a | |
| `index_take` | :x: | n/a | |
| `iota` | :white_check_mark: | n/a | |
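The newly supported ops are exercised through `tile_map` on 1D tensors. A minimal usage sketch, following the pattern of the tests added in this PR (assumes an `ipu` JAX backend/device is available):

```python
import jax
import numpy as np

from tessellate_ipu import tile_map, tile_put_replicated
from tessellate_ipu.lax import cumsum_p


def cumsum_on_tile(data):
    # Replicate the 1D input onto tile 0, then run the custom cumulative vertex.
    t = tile_put_replicated(data, (0,))
    return tile_map(cumsum_p, t, axis=0, reverse=False)


data = np.arange(8, dtype=np.float32)
out = jax.jit(cumsum_on_tile, backend="ipu")(data)  # [0, 1, 3, 6, 10, 15, 21, 28]
```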
72 changes: 72 additions & 0 deletions tessellate_ipu/core/vertex/tile_cumulative_ops_vertex.cpp
@@ -0,0 +1,72 @@
// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#include <algorithm>
#include <limits>
#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
using namespace poplar;

namespace tl {

enum OpType : int { SUM = 0, MIN = 1, MAX = 2, PROD = 3 };

// Neutral (identity) element of the cumulative op OP for type T.
template <typename T, int OP>
T initial_accumulator_value() {
  if constexpr (OP == SUM) {
    return T(0);
  } else if constexpr (OP == MIN) {
    return std::numeric_limits<T>::max();
  } else if constexpr (OP == MAX) {
    return std::numeric_limits<T>::lowest();
  } else {
    return T(1);
  }
}

// Binary update applied at each step of the scan.
template <typename T, int OP>
T cumulative_op(T acc, T rhs) {
  if constexpr (OP == SUM) {
    return acc + rhs;
  } else if constexpr (OP == MIN) {
    return std::min(acc, rhs);
  } else if constexpr (OP == MAX) {
    return std::max(acc, rhs);
  } else {
    return acc * rhs;
  }
}

/**
 * @brief Cumulative op vertex.
 * Simple sequential implementation for now; not optimized.
 */
template <typename T, int OP>
class CumulativeOp : public Vertex {
 public:
  Input<Vector<T, poplar::VectorLayout::SPAN>> in;
  Output<Vector<T, poplar::VectorLayout::ONE_PTR>> out;

  bool compute() {
    T accumulator = initial_accumulator_value<T, OP>();
    const int32_t size = in.size();
    for (int32_t idx = 0; idx < size; ++idx) {
      accumulator = cumulative_op<T, OP>(accumulator, in[idx]);
      out[idx] = accumulator;
    }
    return true;
  }
};

// explicit instantiations
template class CumulativeOp<int, SUM>;
template class CumulativeOp<float, SUM>;

template class CumulativeOp<int, MIN>;
template class CumulativeOp<float, MIN>;

template class CumulativeOp<int, MAX>;
template class CumulativeOp<float, MAX>;

template class CumulativeOp<int, PROD>;
template class CumulativeOp<float, PROD>;

} // namespace tl
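For reference, the sequential loop above follows NumPy `accumulate` semantics. A minimal host-side sketch of the same four opcodes (the helper name is illustrative, not part of this PR):

```python
import numpy as np

# Opcodes mirror tl::OpType in the vertex: SUM=0, MIN=1, MAX=2, PROD=3.
_REFERENCE_OPS = {
    0: np.cumsum,
    1: np.minimum.accumulate,
    2: np.maximum.accumulate,
    3: np.cumprod,
}


def cumulative_op_reference(x: np.ndarray, opcode: int) -> np.ndarray:
    """Host-side reference of the CumulativeOp vertex on a 1D array."""
    return _REFERENCE_OPS[opcode](x)
```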
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
@@ -39,6 +39,7 @@
    scaled_sub_p,
    sub_inplace_p,
)
+ from .tile_lax_cumulative_ops import cummax_p, cummin_p, cumprod_p, cumsum_p
from .tile_lax_dot import IpuConvVertexType
from .tile_lax_gather import gather_p
from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p
78 changes: 78 additions & 0 deletions tessellate_ipu/lax/tile_lax_cumulative_ops.py
@@ -0,0 +1,78 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import os
from typing import Any, Dict, List, Optional, Tuple

from jax.core import Primitive, ShapedArray
from jax.lax import cummax_p, cummin_p, cumprod_p, cumsum_p

from tessellate_ipu.core import (
    IpuTileMapEquation,
    make_ipu_vertex_in_info,
    make_ipu_vertex_name_templated,
    make_ipu_vertex_out_info,
    register_ipu_tile_primitive,
)

_cumop_primitive_to_opcode: Dict[Primitive, int] = {
    cumsum_p: 0,
    cummin_p: 1,
    cummax_p: 2,
    cumprod_p: 3,
}


def get_cumulative_ops_gp_filename() -> str:
    return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_cumulative_ops_vertex.cpp")


def ipu_cumop_primitive_translation(
    p: Primitive,
    tiles: Tuple[int, ...],
    inavals: List[ShapedArray],
    attributes: Optional[Dict[str, Any]] = None,
) -> IpuTileMapEquation:
    """IPU cumulative ops LAX primitive translation rule to IPU vertex.

    Args:
        p: JAX primitive.
        tiles: Collection of tiles.
        inavals: Input shaped arrays.
        attributes: Primitive attributes (`axis` and `reverse`).
    Returns:
        IPU tile map primitive structure.
    """
    assert len(inavals) == 1
    assert attributes is not None
    inaval = inavals[0]
    axis = attributes.get("axis", 0)
    reverse = attributes.get("reverse", False)

    # Only a subset of the configuration is supported: forward scan on axis 0 of a 1D tensor.
    assert axis == 0
    assert not reverse
    assert len(inaval.shape) == 1
    outaval = inaval
    # TessellateIPU custom cumulative vertices.
    opcode = _cumop_primitive_to_opcode[p]
    vname = make_ipu_vertex_name_templated("tl::CumulativeOp", inaval.dtype, opcode)
    inputs_info = [make_ipu_vertex_in_info("in", inaval)]
    outputs_info = [make_ipu_vertex_out_info("out", outaval)]
    ipu_prim_info = IpuTileMapEquation(
        vname=vname,
        pname=p.name,
        tiles=tiles,
        inputs_info=inputs_info,
        outputs_info=outputs_info,
        attributes_i32=[],
        attributes_f32=[],
        gp_filename=get_cumulative_ops_gp_filename(),
        perf_estimate=inaval.size * 6,
    )
    return ipu_prim_info


# Register JAX LAX cumulative-ops primitives.
register_ipu_tile_primitive(cumsum_p, ipu_cumop_primitive_translation)
register_ipu_tile_primitive(cummax_p, ipu_cumop_primitive_translation)
register_ipu_tile_primitive(cummin_p, ipu_cumop_primitive_translation)
register_ipu_tile_primitive(cumprod_p, ipu_cumop_primitive_translation)
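As an illustration, calling the translation rule directly on a 1D `float32` aval yields a tile equation naming the templated C++ vertex; the exact `vname` string below is an assumption based on `make_ipu_vertex_name_templated`, not verified output:

```python
import numpy as np
from jax.core import ShapedArray
from jax.lax import cumsum_p

from tessellate_ipu.lax.tile_lax_cumulative_ops import ipu_cumop_primitive_translation

eqn = ipu_cumop_primitive_translation(
    cumsum_p,
    tiles=(0,),
    inavals=[ShapedArray((16,), np.float32)],
    attributes={"axis": 0, "reverse": False},
)
# `eqn.vname` should reference the templated vertex, e.g. "tl::CumulativeOp<float,0>".
```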
47 changes: 47 additions & 0 deletions tests/lax/test_tile_lax_cumulative_ops.py
@@ -0,0 +1,47 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from tessellate_ipu import tile_map, tile_put_replicated
from tessellate_ipu.lax import cummax_p, cummin_p, cumprod_p, cumsum_p


class IpuTilePrimitivesLaxCumsumTests(chex.TestCase, parameterized.TestCase):
    def setUp(self):
        super().setUp()
        self.device = jax.devices("ipu")[0]
        self.num_tiles = self.device.num_tiles
        # Fixed seed, for test reproducibility.
        np.random.seed(123)

    @parameterized.parameters(
        {"N": 16, "dtype": np.float32, "cumop": cumsum_p},
        {"N": 16, "dtype": np.int32, "cumop": cumsum_p},
        {"N": 16, "dtype": np.float32, "cumop": cummax_p},
        {"N": 16, "dtype": np.int32, "cumop": cummax_p},
        {"N": 16, "dtype": np.float32, "cumop": cummin_p},
        {"N": 16, "dtype": np.int32, "cumop": cummin_p},
        {"N": 16, "dtype": np.float32, "cumop": cumprod_p},
        {"N": 16, "dtype": np.int32, "cumop": cumprod_p},
    )
    def test__tile_map__cumulative_op__jitting__proper_result(self, N, dtype, cumop):
        tiles = (0,)
        data = np.random.randn(N).astype(dtype)

        def compute_fn(data):
            data = tile_put_replicated(data, tiles)
            return tile_map(cumop, data, axis=0, reverse=False)

        cpu_compute_fn = partial(jax.jit, backend="cpu")(compute_fn)
        ipu_compute_fn = partial(jax.jit, backend="ipu")(compute_fn)

        cpu_output = cpu_compute_fn(data)
        ipu_output = ipu_compute_fn(data)
        assert ipu_output.tiles == tiles
        assert ipu_output.dtype == data.dtype
        npt.assert_array_almost_equal(ipu_output, cpu_output, decimal=5)