From faa8eea72be6c7a905274b06f2a8d4f80f8624f5 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 11 Oct 2023 14:55:40 +0000 Subject: [PATCH] wip --- .../core/vertex/intrinsics_utils.hpp | 27 ++++++----- .../core/vertex/ipu_model_types.hpp | 4 ++ tessellate_ipu/core/vertex/tile_small_dot.cpp | 12 ++++- tessellate_ipu/core/vertex/tile_small_dot.hpp | 47 ++++++++++++++----- tessellate_ipu/lax/tile_lax_small_dot.py | 5 +- tests/lax/test_tile_lax_small_dot.py | 39 ++++++++++++++- 6 files changed, 105 insertions(+), 29 deletions(-) diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index 38fa8dc..177b116 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -27,19 +27,24 @@ /** * Tag dispatching, between IPU model and IPU hardware implementations. * - * Making it hopefully easier to maintain IPU hardware and model implementations, - * without #ifdef/#endif preprocessor spaghetti code. + * Making it hopefully easier to maintain IPU hardware and model + * implementations, without #ifdef/#endif preprocessor spaghetti code. */ -namespace ipu -{ +namespace ipu { /** IPU hardware tag. */ -struct HardwareTag{}; +struct HardwareTag {}; /** IPU model tag. */ -struct ModelTag{}; -} +struct ModelTag {}; +} // namespace ipu +// IPU dispatch tag preprocessor. #ifdef __IPU__ +#define IPU_DISPATCH_TAG (ipu::HardwareTag{}) +#else +#define IPU_DISPATCH_TAG (ipu::ModelTag{}) +#endif +#ifdef __IPU__ /** * @brief Efficient division by 6, on IPU hardware. Up to 98,304. */ @@ -57,7 +62,7 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { R"l( uput $TAS, %[sv] )l" : - : [sv] "r"(v) + : [ sv ] "r"(v) :); } @@ -69,7 +74,7 @@ ALWAYS_INLINE void __builtin_ipu_f32v2cmac(float2 x, float2 y) noexcept { R"l( f32v2mac %[x], %[y] )l" : - : [x] "r"(x), [y] "r"(y) + : [ x ] "r"(x), [ y ] "r"(y) :); } @@ -80,8 +85,8 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) { asm volatile( R"l( ld32 %[result], %[address], %[offset] )l" - : [result] "=r"(result) - : [address] "r"(address), [offset] "r"(offset) + : [ result ] "=r"(result) + : [ address ] "r"(address), [ offset ] "r"(offset) :); return result; } diff --git a/tessellate_ipu/core/vertex/ipu_model_types.hpp b/tessellate_ipu/core/vertex/ipu_model_types.hpp index d6c9810..20b6782 100644 --- a/tessellate_ipu/core/vertex/ipu_model_types.hpp +++ b/tessellate_ipu/core/vertex/ipu_model_types.hpp @@ -1,6 +1,7 @@ // Copyright (c) 2023 Graphcore Ltd. All rights reserved. #pragma once +// Only defined on IPU model. #ifndef __IPU__ #include #include @@ -77,4 +78,7 @@ using uint4 = IpuVector; using long2 = IpuVector; using long4 = IpuVector; +// rptsize_t alias. +using rptsize_t = uint16_t; + #endif diff --git a/tessellate_ipu/core/vertex/tile_small_dot.cpp b/tessellate_ipu/core/vertex/tile_small_dot.cpp index 8c128d1..a06bcff 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.cpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.cpp @@ -4,6 +4,8 @@ #include #include +using namespace poplar; + // class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dVertex // : public MultiVertex { class Rotation2dVertex : public MultiVertex { @@ -22,7 +24,6 @@ class Rotation2dVertex : public MultiVertex { Input> inrow1; // (N,) second input row vector - Input> worker_offsets; // (7,) number threads + 1. @@ -37,6 +38,15 @@ class Rotation2dVertex : public MultiVertex { const IndexType wend = worker_offsets[wid + 1]; const IndexType wsize = wend - wstart; + const T2* inrow0_ptr = reinterpret_cast(inrow0.data()) + wstart; + const T2* inrow1_ptr = reinterpret_cast(inrow1.data()) + wstart; + const T2* cs_ptr = reinterpret_cast(cs.data()); + + T2* outrow0_ptr = reinterpret_cast(outrow0.data()) + wstart; + T2* outrow1_ptr = reinterpret_cast(outrow1.data()) + wstart; + + rotation2_float(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr, + wsize, IPU_DISPATCH_TAG); return true; } diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index b69207b..d5eda09 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -9,15 +9,38 @@ R as(T x) { return *reinterpret_cast(&x); } - -inline void rotation2_float(const float2 *inrow0, const float2 *inrow1, - float2 cs, float2 *outrow0, float2 *outrow1, - rptsize_t nblocks, ipu::HardwareTag) {} - - -inline void rotation2_float(const float2 *inrow0, const float2 *inrow1, - float2 cs, float2 *outrow0, float2 *outrow1, - rptsize_t nblocks, ipu::ModelTag) { - - - } +/** + * @brief Apply 2d rotation transform. + */ +inline void rotation2_float(float2 cs, const float2 *inrow0, + const float2 *inrow1, float2 *outrow0, + float2 *outrow1, rptsize_t nblocks, + ipu::HardwareTag) { + using T2 = float2; + for (rptsize_t idx = 0; idx != nblocks; ++idx) { + const T2 invec0 = ipu::load_postinc(&inrow0, 1); + const T2 invec1 = ipu::load_postinc(&inrow1, 1); + // Vectorized apply 2d rotation. + const T2 outvec0 = cs[0] * invec0 - cs[1] * invec1; + const T2 outvec1 = cs[1] * invec0 + cs[0] * invec1; + ipu::store_postinc(&outrow0, outvec0, 1); + ipu::store_postinc(&outrow1, outvec1, 1); + } +} +inline void rotation2_float(float2 cs, const float2 *inrow0, + const float2 *inrow1, float2 *outrow0, + float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) { + using T2 = float2; + const T2 cvec = {cs[0], cs[0]}; + const T2 svec = {cs[1], cs[1]}; + // Note: supported on IPU model! + for (rptsize_t idx = 0; idx != nblocks; ++idx) { + const T2 invec0 = ipu::load_postinc(&inrow0, 1); + const T2 invec1 = ipu::load_postinc(&inrow1, 1); + // Vectorized apply 2d rotation. + const T2 outvec0 = cvec * invec0 - svec * invec1; + const T2 outvec1 = svec * invec0 + cvec * invec1; + ipu::store_postinc(&outrow0, outvec0, 1); + ipu::store_postinc(&outrow1, outvec1, 1); + } +} diff --git a/tessellate_ipu/lax/tile_lax_small_dot.py b/tessellate_ipu/lax/tile_lax_small_dot.py index a832d52..de96252 100644 --- a/tessellate_ipu/lax/tile_lax_small_dot.py +++ b/tessellate_ipu/lax/tile_lax_small_dot.py @@ -32,6 +32,7 @@ def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray): outrow1: Second output row (N,) """ N = inrow0.size + assert N % 4 == 0 assert inrow0 == inrow1 assert cs.dtype == inrow0.dtype assert cs.dtype == inrow1.dtype @@ -41,9 +42,7 @@ def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray): "outrow0": inrow0, "outrow1": inrow1, } - constants = { - "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16) - } + constants = {"worker_offsets": make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16)} temps: Dict[str, Any] = {} perf_estimate = 100 return outputs, constants, temps, perf_estimate diff --git a/tests/lax/test_tile_lax_small_dot.py b/tests/lax/test_tile_lax_small_dot.py index 650f709..cb3c453 100644 --- a/tests/lax/test_tile_lax_small_dot.py +++ b/tests/lax/test_tile_lax_small_dot.py @@ -5,8 +5,43 @@ import jax import numpy as np import numpy.testing as npt -from absl.testing import parameterized +import pytest -from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded +from tessellate_ipu import tile_map, tile_put_replicated from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p + +@pytest.mark.ipu_hardware +class IpuTileRotation2dHwTests(chex.TestCase): + def setUp(self): + super().setUp() + np.random.seed(42) + + def test__tile_map__rotation2d_primitive__proper_result(self): + N = 8 + tiles = (0,) + indata = np.random.randn(2, N).astype(np.float32) + cs = np.random.randn(2).astype(np.float32) + rot2d = np.array([[cs[0], -cs[1]], [cs[1], cs[0]]]).astype(np.float32) + + def compute_fn(cs, row0, row1): + cs = tile_put_replicated(cs, tiles) + row0 = tile_put_replicated(row0, tiles) + row1 = tile_put_replicated(row1, tiles) + return tile_map(rotation2d_p, cs, row0, row1) + + # compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn) + compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) + + # output_cpu = compute_fn_cpu(indata) + outrow0, outrow1 = compute_fn_ipu(cs, indata[0], indata[1]) + expected_out = rot2d @ indata + + print(outrow0) + print(outrow1) + + # npt.assert_array_almost_equal(np.ravel(outrow0), indata[0], decimal=5) + # npt.assert_array_almost_equal(np.ravel(outrow1), indata[1], decimal=5) + + npt.assert_array_almost_equal(np.ravel(outrow0), expected_out[0], decimal=5) + npt.assert_array_almost_equal(np.ravel(outrow1), expected_out[1], decimal=5)