From f60fcac503569099a1b37102234d87d6bee506a6 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 12 Oct 2023 20:46:36 +0000 Subject: [PATCH] wip --- .../core/vertex/intrinsics_utils.hpp | 10 ++- tessellate_ipu/core/vertex/tile_small_dot.cpp | 4 +- tessellate_ipu/core/vertex/tile_small_dot.hpp | 80 +++++++++++-------- tests/lax/test_tile_lax_small_dot.py | 2 +- 4 files changed, 58 insertions(+), 38 deletions(-) diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index 5ea61e8..a254777 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -32,15 +32,21 @@ */ namespace ipu { /** IPU hardware tag. */ -struct HardwareTag {}; +struct HardwareTag { + static constexpr bool hardware = true; +}; /** IPU model tag. */ -struct ModelTag {}; +struct ModelTag { + static constexpr bool model = true; +}; } // namespace ipu // IPU dispatch tag preprocessor. #ifdef __IPU__ +#define IPU_TAG_TYPE ipu::HardwareTag #define IPU_DISPATCH_TAG (ipu::HardwareTag{}) #else +#define IPU_TAG_TYPE ipu::ModelTag #define IPU_DISPATCH_TAG (ipu::ModelTag{}) #endif diff --git a/tessellate_ipu/core/vertex/tile_small_dot.cpp b/tessellate_ipu/core/vertex/tile_small_dot.cpp index fc69af6..d50d2f5 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.cpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.cpp @@ -48,8 +48,8 @@ class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)", T2* outrow0_ptr = reinterpret_cast(outrow0.data()) + wstart; T2* outrow1_ptr = reinterpret_cast(outrow1.data()) + wstart; - rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr, - wsize, IPU_DISPATCH_TAG); + rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr, + wsize); return true; } }; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index 156a2e3..1ac622e 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -25,7 +25,8 @@ inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y, /** * @brief z = a*x + b*y float32 implementation using `rpt` loop and `f32v2axpy` */ -inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, +template = true> +inline void axplusby_f32_v2(float a, float b, const float2 *x, const float2 *y, float2 *z, rptsize_t nblocks) { // Necessary if using unsigned `nblocks`. // __builtin_assume(nblocks < 4096); @@ -63,56 +64,65 @@ inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, * @brief z = a*x + b*y float32 implementation fully optimized in inline * assembly. */ +template = true> inline void axplusby_f32_v2(float a, float b, const float2 *x, const float2 *y, float2 *z, rptsize_t nblocks) { // Necessary if using unsigned `nblocks`. // __builtin_assume(nblocks < 4096); using T2 = float2; - const T2 av = {a, a}; + T2 av = {a, a}; // Using TAS register for the scalar `b`. __ipu_and_ipumodel_tas tas; tas.put(b); + T2 xv, yv, zv; + uint2 tapaddr; + // Inline assembly loop in order to use `ldst64pace` instruction. // Note: requires "unrolling" the beginning of the `f32v2axpy` pipeline. asm volatile( R"( - ld64step $a0:1, $m15, %[xptr]+=, 1 - ld64step $a4:5, $m15, %[yptr]+=, 1 + ld64step %[xv], $m15, %[xptr]+=, 1 + ld64step %[yv], $m15, %[yptr]+=, 1 { - ld64step $a0:1, $m15, %[xptr]+=, 1 - f32v2mul $a2:3, $a0:1, %[avec] + ld64step %[xv], $m15, %[xptr]+=, 1 + f32v2mul %[zv], %[xv], %[av] } { - ld64step $a4:5, $m15, %[yptr]+=, 1 - f32v2axpy $a2:3, $a4:5, $a2:3 + ld64step %[yv], $m15, %[yptr]+=, 1 + f32v2axpy %[zv], %[yv], %[zv] } { - ld64step $a0:1, $m15, %[xptr]+=, 1 - f32v2mul $a2:3, $a0:1, %[avec] + ld64step %[xv], $m15, %[xptr]+=, 1 + f32v2mul %[zv], %[xv], %[av] } { - ld64step $a4:5, $m15, %[yptr]+=, 1 - f32v2axpy $a2:3, $a4:5, $a2:3 + ld64step %[yv], $m15, %[yptr]+=, 1 + f32v2axpy %[zv], %[yv], %[zv] } - tapack $m4:5, %[xptr], $mzero, %[zptr] + tapack %[tapaddr], %[xptr], $mzero, %[zptr] .align 8 { rpt %[nb], 1 fnop } { - ldst64pace $a0:1, $a2:3, $m4:5+=, $mzero, 0 - f32v2mul $a2:3, $a0:1, %[avec] + ldst64pace %[xv], %[zv], %[tapaddr]+=, $mzero, 0 + f32v2mul %[zv], %[xv], %[av] } { - ld64step $a4:5, $m15, %[yptr]+=, 1 - f32v2axpy $a2:3, $a4:5, $a2:3 + ld64step %[yv], $m15, %[yptr]+=, 1 + f32v2axpy %[zv], %[yv], %[zv] } )" - : [ xptr ] "+r"(x), [ yptr ] "+r"(y) - : [ nb ] "r"(nblocks), [ zptr ] "r"(z), [ avec ] "r"(av) - : "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5"); + : [ xptr ] "+r"(x), [ yptr ] "+r"(y), [ av ] "+r"(av), [ xv ] "=r"(xv), + [ yv ] "=r"(yv), [ zv ] "=r"(zv), [ tapaddr ] "+r"(tapaddr), + [ nb ] "+r"(nblocks) + : [ zptr ] "r"(z) + // : "r0"); + // : "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5"); + // : "$m4", "$m5"); + :); } /** @@ -121,19 +131,23 @@ inline void axplusby_f32_v2(float a, float b, const float2 *x, const float2 *y, * Note: input rows are separated, allowing more flexibility * for functions/vertices using this base compute method. */ -inline void rotation2d_f32(float2 cs, const float2 *inrow0, - const float2 *inrow1, float2 *outrow0, - float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) { +template +void rotation2d_f32(float2 cs, const float2 *inrow0, const float2 *inrow1, + float2 *outrow0, float2 *outrow1, rptsize_t nblocks) { // TODO: investigate using IPU AMP unit? - axplusby_f32_v2(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); - axplusby_f32_v2(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks); + axplusby_f32_v2(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); + axplusby_f32_v2(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks); } -inline void rotation2d_f32(float2 cs, const float2 *inrow0, - const float2 *inrow1, float2 *outrow0, - float2 *outrow1, rptsize_t nblocks, - ipu::HardwareTag) { - // Using same implementation as IPU model for now. - rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks, - ipu::ModelTag{}); -} +// void rotation2d_f32(float2 cs, const float2 *inrow0, +// const float2 *inrow1, float2 *outrow0, +// float2 *outrow1, rptsize_t nblocks, +// ipu::HardwareTag) { +// // Using same implementation as IPU model for now. +// axplusby_f32_v2(cs[0], -cs[1], inrow0, inrow1, outrow0, +// nblocks); axplusby_f32_v2(cs[0], cs[1], inrow1, inrow0, +// outrow1, nblocks); + +// // rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks, +// // ipu::ModelTag{}); +// } diff --git a/tests/lax/test_tile_lax_small_dot.py b/tests/lax/test_tile_lax_small_dot.py index 4cda4ff..dcbf419 100644 --- a/tests/lax/test_tile_lax_small_dot.py +++ b/tests/lax/test_tile_lax_small_dot.py @@ -46,4 +46,4 @@ def compute_fn(cs, row0, row1): start, end = np.asarray(start)[0], np.asarray(end)[0] hw_cycle_count = end[0] - start[0] # Observe on IPU Mk2 hw ~1418 cycles. - assert hw_cycle_count <= 1500 + assert hw_cycle_count <= 150