Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional optimization to axplusby_f32 using inline assembly. #47

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,21 @@
*/
namespace ipu {
/** IPU hardware tag. */
struct HardwareTag {};
struct HardwareTag {
static constexpr bool hardware = true;
};
/** IPU model tag. */
struct ModelTag {};
struct ModelTag {
static constexpr bool model = true;
};
} // namespace ipu

// IPU dispatch tag preprocessor.
#ifdef __IPU__
#define IPU_TAG_TYPE ipu::HardwareTag
#define IPU_DISPATCH_TAG (ipu::HardwareTag{})
#else
#define IPU_TAG_TYPE ipu::ModelTag
#define IPU_DISPATCH_TAG (ipu::ModelTag{})
#endif

Expand Down
8 changes: 5 additions & 3 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ using namespace poplar;
/**
* @brief 2d rotation vertex.
*/
class Rotation2dVertex : public MultiVertex {
class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)",
"elem(*inrow1) != elem(*outrow1)")]] Rotation2dVertex
: public MultiVertex {
public:
using T = float;
using T2 = float2;
Expand Down Expand Up @@ -46,8 +48,8 @@ class Rotation2dVertex : public MultiVertex {
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr,
wsize, IPU_DISPATCH_TAG);
rotation2d_f32<IPU_TAG_TYPE>(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr,
outrow1_ptr, wsize);
return true;
}
};
94 changes: 79 additions & 15 deletions tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y,
ipu::store_postinc(&z, zv, 1);
}
}

inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
/**
* @brief z = a*x + b*y float32 implementation using `rpt` loop and `f32v2axpy`
*
* Compatible with IPU hardware and IPU model.
* 30% slower than inline assembly implementation.
*/
template <class IpuTag, std::enable_if_t<IpuTag::model, bool> = true>
inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
using T2 = float2;
Expand Down Expand Up @@ -57,25 +63,83 @@ inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
}
}
}
/**
* @brief z = a*x + b*y float32 implementation fully optimized in inline
* assembly.
*/
template <class IpuTag, std::enable_if_t<IpuTag::hardware, bool> = true>
inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
using T2 = float2;
// Using TAS register for the scalar `b`.
__ipu_and_ipumodel_tas tas;
tas.put(b);

T2 av = {a, a};
// Explicit variables passed to inline assembly.
// Easier to read + compiling on IPU model.
T2 xv, yv, zv;
uint2 tapaddr;
// Inline assembly loop in order to use `ldst64pace` instruction.
// Note: requires "unrolling" the beginning of the `f32v2axpy` pipeline.
// TODO: investigate issue with inputs register re-use.
asm volatile(
R"(
ld64step %[xv], $m15, %[xptr]+=, 1
ld64step %[yv], $m15, %[yptr]+=, 1
{
ld64step %[xv], $m15, %[xptr]+=, 1
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
{
ld64step %[xv], $m15, %[xptr]+=, 1
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
tapack %[tapaddr], %[xptr], $mzero, %[zptr]
.align 8
{
rpt %[nb], 1
fnop
}
{
ldst64pace %[xv], %[zv], %[tapaddr]+=, $mzero, 0
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
)"
: [ xptr ] "+r"(x), [ yptr ] "+r"(y), [ av ] "+r"(av), [ xv ] "=r"(xv),
[ yv ] "=r"(yv), [ zv ] "=r"(zv), [ tapaddr ] "+r"(tapaddr),
[ nb ] "+r"(nblocks)
: [ zptr ] "r"(z)
:);
// Note: explicit list of used registers not compiling on IPU model.
// : "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5"
}

/**
* @brief Apply 2d rotation transform (float).
*
* Note: input rows are separated, allowing more flexibility
* for functions/vertices using this base compute method.
*/
template <class IpuTag>
inline void rotation2d_f32(float2 cs, const float2 *inrow0,
const float2 *inrow1, float2 *outrow0,
float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) {
axplusby_f32_v1(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks);
axplusby_f32_v1(cs[1], cs[0], inrow0, inrow1, outrow1, nblocks);
}

inline void rotation2d_f32(float2 cs, const float2 *inrow0,
const float2 *inrow1, float2 *outrow0,
float2 *outrow1, rptsize_t nblocks,
ipu::HardwareTag) {
// Using same implementation as IPU model for now.
rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks,
ipu::ModelTag{});
float2 *outrow1, rptsize_t nblocks) {
// TODO: investigate using IPU AMP unit?
axplusby_f32<IpuTag>(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks);
axplusby_f32<IpuTag>(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks);
}
4 changes: 2 additions & 2 deletions tests/lax/test_tile_lax_small_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def compute_fn(cs, row0, row1):
# Hardware cycle count bound.
start, end = np.asarray(start)[0], np.asarray(end)[0]
hw_cycle_count = end[0] - start[0]
# Observe on IPU Mk2 hw ~1916 cycles.
assert hw_cycle_count <= 2000
# Observe on IPU Mk2 hw ~1436 cycles.
assert hw_cycle_count <= 1500