diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index 27ddb59..5ea61e8 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -24,8 +24,27 @@ // #define ALWAYS_INLINE __attribute__((always_inline)) #define ALWAYS_INLINE inline +/** + * Tag dispatching, between IPU model and IPU hardware implementations. + * + * Making it hopefully easier to maintain IPU hardware and model + * implementations, without #ifdef/#endif preprocessor spaghetti code. + */ +namespace ipu { +/** IPU hardware tag. */ +struct HardwareTag {}; +/** IPU model tag. */ +struct ModelTag {}; +} // namespace ipu + +// IPU dispatch tag preprocessor. #ifdef __IPU__ +#define IPU_DISPATCH_TAG (ipu::HardwareTag{}) +#else +#define IPU_DISPATCH_TAG (ipu::ModelTag{}) +#endif +#ifdef __IPU__ /** * @brief Efficient division by 6, on IPU hardware. Up to 98,304. */ @@ -43,7 +62,7 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { R"l( uput $TAS, %[sv] )l" : - : [sv] "r"(v) + : [ sv ] "r"(v) :); } @@ -55,7 +74,7 @@ ALWAYS_INLINE void __builtin_ipu_f32v2cmac(float2 x, float2 y) noexcept { R"l( f32v2mac %[x], %[y] )l" : - : [x] "r"(x), [y] "r"(y) + : [ x ] "r"(x), [ y ] "r"(y) :); } @@ -66,8 +85,8 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) { asm volatile( R"l( ld32 %[result], %[address], %[offset] )l" - : [result] "=r"(result) - : [address] "r"(address), [offset] "r"(offset) + : [ result ] "=r"(result) + : [ address ] "r"(address), [ offset ] "r"(offset) :); return result; } @@ -171,3 +190,11 @@ T __builtin_ipu_f32v2axpy(T const& x, T const& y) { // clang-format on #endif + +/** + * @brief Bitwise cast to a different type. + */ +template +R as(T x) { + return *reinterpret_cast(&x); +} diff --git a/tessellate_ipu/core/vertex/ipu_model_types.hpp b/tessellate_ipu/core/vertex/ipu_model_types.hpp index d6c9810..727cb5d 100644 --- a/tessellate_ipu/core/vertex/ipu_model_types.hpp +++ b/tessellate_ipu/core/vertex/ipu_model_types.hpp @@ -1,6 +1,7 @@ // Copyright (c) 2023 Graphcore Ltd. All rights reserved. #pragma once +// Only defined on IPU model. #ifndef __IPU__ #include #include @@ -77,4 +78,7 @@ using uint4 = IpuVector; using long2 = IpuVector; using long4 = IpuVector; +// rptsize_t alias on IPU model. +using rptsize_t = uint16_t; + #endif diff --git a/tessellate_ipu/core/vertex/tile_small_dot.cpp b/tessellate_ipu/core/vertex/tile_small_dot.cpp new file mode 100644 index 0000000..4678590 --- /dev/null +++ b/tessellate_ipu/core/vertex/tile_small_dot.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2023 Graphcore Ltd. All rights reserved. +#include "tile_small_dot.hpp" + +#include +#include + +using namespace poplar; + +/** + * @brief 2d rotation vertex. + */ +class Rotation2dVertex : public MultiVertex { + public: + using T = float; + using T2 = float2; + // Using `uint16` seems to be generating more efficient loops? + using IndexType = unsigned short; + + static constexpr size_t MIN_ALIGN = 8; + + Input> + cs; // (2,) rotation cosinus/sinus values + Input> + inrow0; // (N,) first input row vector + Input> + inrow1; // (N,) second input row vector + + Input> + worker_offsets; // (7,) number threads + 1. + + Output> + outrow0; // (N,) first input row vector + Output> + outrow1; // (N,) first input row vector + + bool compute(unsigned wid) { + // vectorized offsets. + const IndexType wstart = worker_offsets[wid]; + const IndexType wend = worker_offsets[wid + 1]; + const IndexType wsize = wend - wstart; + + // Vertex inputs/outputs assuring proper alignment. + const T2* inrow0_ptr = reinterpret_cast(inrow0.data()) + wstart; + const T2* inrow1_ptr = reinterpret_cast(inrow1.data()) + wstart; + const T2* cs_ptr = reinterpret_cast(cs.data()); + T2* outrow0_ptr = reinterpret_cast(outrow0.data()) + wstart; + T2* outrow1_ptr = reinterpret_cast(outrow1.data()) + wstart; + + rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr, + wsize, IPU_DISPATCH_TAG); + return true; + } +}; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp new file mode 100644 index 0000000..1c0ff3e --- /dev/null +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -0,0 +1,81 @@ +// Copyright (c) 2023 Graphcore Ltd. All rights reserved. +#include "intrinsics_utils.hpp" + +/** + * @brief z = a*x + b*y float32 implementation. + * + * where x, y, z are 1D arrays and a, b are scalars. + * Implementation compatible with IPU model and hardware. + * + * Requires input arrays with size % 2 == 0 + */ +inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y, + float2 *z, rptsize_t nblocks) { + using T2 = float2; + const T2 av = {a, a}; + const T2 bv = {b, b}; + // Sub-optimal vectorized implementation. + for (unsigned idx = 0; idx < nblocks; ++idx) { + const T2 xv = ipu::load_postinc(&x, 1); + const T2 yv = ipu::load_postinc(&y, 1); + const T2 zv = av * xv + bv * yv; + ipu::store_postinc(&z, zv, 1); + } +} + +void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, + float2 *z, rptsize_t nblocks) { + // Necessary if using unsigned `nblocks`. + // __builtin_assume(nblocks < 4096); + using T2 = float2; + const T2 av = {a, a}; + // Using TAS register for one of the scalar. + __ipu_and_ipumodel_tas tas; + tas.put(b); + + T2 res, xv, yv, zv, tmp; + + xv = ipu::load_postinc(&x, 1); + yv = ipu::load_postinc(&y, 1); + res = xv * av; + for (unsigned idx = 0; idx != nblocks; ++idx) { + // Pseudo dual-issuing of instructions. + // popc should be able to generate an optimal rpt loop. + { + xv = ipu::load_postinc(&x, 1); + // TODO: fix ordering of arguments in `f32v2axpy`. + tmp = tas.f32v2axpy(res, yv); + } + { + yv = ipu::load_postinc(&y, 1); + // TODO: fix ordering of arguments in `f32v2axpy`. + zv = tas.f32v2axpy(tmp, tmp); + } + { + ipu::store_postinc(&z, zv, 1); + res = xv * av; + } + } +} + +/** + * @brief Apply 2d rotation transform (float). + * + * Note: input rows are separated, allowing more flexibility + * for functions/vertices using this base compute method. + */ +inline void rotation2d_f32(float2 cs, const float2 *inrow0, + const float2 *inrow1, float2 *outrow0, + float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) { + axplusby_f32_v1(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); + axplusby_f32_v1(cs[1], cs[0], inrow0, inrow1, outrow1, nblocks); +} + +inline void rotation2d_f32(float2 cs, const float2 *inrow0, + const float2 *inrow1, float2 *outrow0, + float2 *outrow1, rptsize_t nblocks, + ipu::HardwareTag) { + // Using same implementation as IPU model for now. + rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks, + ipu::ModelTag{}); +} diff --git a/tessellate_ipu/lax/__init__.py b/tessellate_ipu/lax/__init__.py index 60b7bd8..bf4b15d 100644 --- a/tessellate_ipu/lax/__init__.py +++ b/tessellate_ipu/lax/__init__.py @@ -43,6 +43,7 @@ from .tile_lax_dot import IpuConvVertexType from .tile_lax_gather import gather_p from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p +from .tile_lax_small_dot import rotation2d_p from .tile_lax_unary import ( # tanh_inplace_p, abs_inplace_p, asin_inplace_p, diff --git a/tessellate_ipu/lax/tile_lax_small_dot.py b/tessellate_ipu/lax/tile_lax_small_dot.py new file mode 100644 index 0000000..0a4076a --- /dev/null +++ b/tessellate_ipu/lax/tile_lax_small_dot.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import os +from typing import Any, Dict + +import numpy as np +from jax.core import ShapedArray + +from tessellate_ipu.core import declare_ipu_tile_primitive +from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets + + +def get_small_dot_vertex_gp_filename() -> str: + return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_small_dot.cpp") + + +@declare_ipu_tile_primitive("Rotation2dVertex", gp_filename=get_small_dot_vertex_gp_filename()) +def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray): + """2d rotation apply primitive. + + Specific optimization on IPU backend compared to `dot_general_p` primitive. + In particular, allows passing the 2 rows of the (2, N) input as separate arrays (in some + applications, contiguous storage may not be possible). + + Args: + cs: Cos/sin 2d rotation entries. + inrow0: First row (N,) + inrow1: Second row (N,) + Returns: + outrow0: First output row (N,) + outrow1: Second output row (N,) + """ + N = inrow0.size + assert N % 2 == 0 + assert inrow0 == inrow1 + assert cs.dtype == inrow0.dtype + assert cs.dtype == inrow1.dtype + assert inrow0.dtype == np.float32 + + outputs = { + "outrow0": inrow0, + "outrow1": inrow1, + } + constants = {"worker_offsets": make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16)} + temps: Dict[str, Any] = {} + perf_estimate = 100 + return outputs, constants, temps, perf_estimate diff --git a/tests/lax/test_tile_lax_small_dot.py b/tests/lax/test_tile_lax_small_dot.py new file mode 100644 index 0000000..e21cdf3 --- /dev/null +++ b/tests/lax/test_tile_lax_small_dot.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import partial + +import chex +import jax +import numpy as np +import numpy.testing as npt +import pytest + +from tessellate_ipu import ipu_cycle_count, tile_map, tile_put_replicated +from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p + + +@pytest.mark.ipu_hardware +class IpuTileRotation2dHwTests(chex.TestCase): + def setUp(self): + super().setUp() + np.random.seed(42) + + def test__tile_map__rotation2d_primitive__proper_result_and_cycle_count(self): + N = 512 + tiles = (0,) + indata = np.random.randn(2, N).astype(np.float32) + cs = np.random.randn(2).astype(np.float32) + rot2d = np.array([[cs[0], -cs[1]], [cs[1], cs[0]]]).astype(np.float32) + + def compute_fn(cs, row0, row1): + cs = tile_put_replicated(cs, tiles) + row0 = tile_put_replicated(row0, tiles) + row1 = tile_put_replicated(row1, tiles) + # Benchmark the raw 2d rotation vertex. + cs, row0, row1, start = ipu_cycle_count(cs, row0, row1) + outrow0, outrow1 = tile_map(rotation2d_p, cs, row0, row1) # type:ignore + outrow0, outrow1, end = ipu_cycle_count(outrow0, outrow1) + + return outrow0, outrow1, start, end + + compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) + outrow0, outrow1, start, end = compute_fn_ipu(cs, indata[0], indata[1]) + + # Checking getting the proper result! + expected_out = rot2d @ indata + npt.assert_array_almost_equal(np.ravel(outrow0), expected_out[0], decimal=6) + npt.assert_array_almost_equal(np.ravel(outrow1), expected_out[1], decimal=6) + # Hardware cycle count bound. + start, end = np.asarray(start)[0], np.asarray(end)[0] + hw_cycle_count = end[0] - start[0] + # Observe on IPU Mk2 hw ~1916 cycles. + assert hw_cycle_count <= 2000