From 055805c67ae615c91e851bddf1e32f8b88761354 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 12 Oct 2023 19:10:33 +0000 Subject: [PATCH] Additional optimization to `axplusby_f32` using inline assembly. The function `axplusby_f32_v2` implements a fully optimized `rpt` loop in inline assembly, taking advantage of the instruction `ldst64pace`. --- .../core/vertex/intrinsics_utils.hpp | 10 +- tessellate_ipu/core/vertex/tile_small_dot.cpp | 8 +- tessellate_ipu/core/vertex/tile_small_dot.hpp | 94 ++++++++++++++++--- tests/lax/test_tile_lax_small_dot.py | 4 +- 4 files changed, 94 insertions(+), 22 deletions(-) diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index 5ea61e8..a254777 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -32,15 +32,21 @@ */ namespace ipu { /** IPU hardware tag. */ -struct HardwareTag {}; +struct HardwareTag { + static constexpr bool hardware = true; +}; /** IPU model tag. */ -struct ModelTag {}; +struct ModelTag { + static constexpr bool model = true; +}; } // namespace ipu // IPU dispatch tag preprocessor. #ifdef __IPU__ +#define IPU_TAG_TYPE ipu::HardwareTag #define IPU_DISPATCH_TAG (ipu::HardwareTag{}) #else +#define IPU_TAG_TYPE ipu::ModelTag #define IPU_DISPATCH_TAG (ipu::ModelTag{}) #endif diff --git a/tessellate_ipu/core/vertex/tile_small_dot.cpp b/tessellate_ipu/core/vertex/tile_small_dot.cpp index 4678590..2391ca1 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.cpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.cpp @@ -9,7 +9,9 @@ using namespace poplar; /** * @brief 2d rotation vertex. */ -class Rotation2dVertex : public MultiVertex { +class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)", + "elem(*inrow1) != elem(*outrow1)")]] Rotation2dVertex + : public MultiVertex { public: using T = float; using T2 = float2; @@ -46,8 +48,8 @@ class Rotation2dVertex : public MultiVertex { T2* outrow0_ptr = reinterpret_cast(outrow0.data()) + wstart; T2* outrow1_ptr = reinterpret_cast(outrow1.data()) + wstart; - rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr, - wsize, IPU_DISPATCH_TAG); + rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, + outrow1_ptr, wsize); return true; } }; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index a42dcac..26bd338 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -22,9 +22,15 @@ inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y, ipu::store_postinc(&z, zv, 1); } } - -inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, - float2 *z, rptsize_t nblocks) { +/** + * @brief z = a*x + b*y float32 implementation using `rpt` loop and `f32v2axpy` + * + * Compatible with IPU hardware and IPU model. + * 30% slower than inline assembly implementation. + */ +template = true> +inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, + float2 *z, rptsize_t nblocks) { // Necessary if using unsigned `nblocks`. // __builtin_assume(nblocks < 4096); using T2 = float2; @@ -57,6 +63,71 @@ inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, } } } +/** + * @brief z = a*x + b*y float32 implementation fully optimized in inline + * assembly. + */ +template = true> +inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, + float2 *z, rptsize_t nblocks) { + // Necessary if using unsigned `nblocks`. + // __builtin_assume(nblocks < 4096); + using T2 = float2; + // Using TAS register for the scalar `b`. + __ipu_and_ipumodel_tas tas; + tas.put(b); + + T2 av = {a, a}; + // Explicit variables passed to inline assembly. + // Easier to read + compiling on IPU model. + T2 xv, yv, zv; + uint2 tapaddr; + // Inline assembly loop in order to use `ldst64pace` instruction. + // Note: requires "unrolling" the beginning of the `f32v2axpy` pipeline. + // TODO: investigate issue with inputs register re-use. + asm volatile( + R"( + ld64step %[xv], $m15, %[xptr]+=, 1 + ld64step %[yv], $m15, %[yptr]+=, 1 + { + ld64step %[xv], $m15, %[xptr]+=, 1 + f32v2mul %[zv], %[xv], %[av] + } + { + ld64step %[yv], $m15, %[yptr]+=, 1 + f32v2axpy %[zv], %[yv], %[zv] + } + { + ld64step %[xv], $m15, %[xptr]+=, 1 + f32v2mul %[zv], %[xv], %[av] + } + { + ld64step %[yv], $m15, %[yptr]+=, 1 + f32v2axpy %[zv], %[yv], %[zv] + } + tapack %[tapaddr], %[xptr], $mzero, %[zptr] + .align 8 + { + rpt %[nb], 1 + fnop + } + { + ldst64pace %[xv], %[zv], %[tapaddr]+=, $mzero, 0 + f32v2mul %[zv], %[xv], %[av] + } + { + ld64step %[yv], $m15, %[yptr]+=, 1 + f32v2axpy %[zv], %[yv], %[zv] + } + )" + : [ xptr ] "+r"(x), [ yptr ] "+r"(y), [ av ] "+r"(av), [ xv ] "=r"(xv), + [ yv ] "=r"(yv), [ zv ] "=r"(zv), [ tapaddr ] "+r"(tapaddr), + [ nb ] "+r"(nblocks) + : [ zptr ] "r"(z) + :); + // Note: explicit list of used registers not compiling on IPU model. + // : "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5" +} /** * @brief Apply 2d rotation transform (float). @@ -64,18 +135,11 @@ inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, * Note: input rows are separated, allowing more flexibility * for functions/vertices using this base compute method. */ +template inline void rotation2d_f32(float2 cs, const float2 *inrow0, const float2 *inrow1, float2 *outrow0, - float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) { - axplusby_f32_v1(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); - axplusby_f32_v1(cs[1], cs[0], inrow0, inrow1, outrow1, nblocks); -} - -inline void rotation2d_f32(float2 cs, const float2 *inrow0, - const float2 *inrow1, float2 *outrow0, - float2 *outrow1, rptsize_t nblocks, - ipu::HardwareTag) { - // Using same implementation as IPU model for now. - rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks, - ipu::ModelTag{}); + float2 *outrow1, rptsize_t nblocks) { + // TODO: investigate using IPU AMP unit? + axplusby_f32(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); + axplusby_f32(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks); } diff --git a/tests/lax/test_tile_lax_small_dot.py b/tests/lax/test_tile_lax_small_dot.py index e21cdf3..5076d30 100644 --- a/tests/lax/test_tile_lax_small_dot.py +++ b/tests/lax/test_tile_lax_small_dot.py @@ -45,5 +45,5 @@ def compute_fn(cs, row0, row1): # Hardware cycle count bound. start, end = np.asarray(start)[0], np.asarray(end)[0] hw_cycle_count = end[0] - start[0] - # Observe on IPU Mk2 hw ~1916 cycles. - assert hw_cycle_count <= 2000 + # Observe on IPU Mk2 hw ~1436 cycles. + assert hw_cycle_count <= 1500