Skip to content

Commit

Permalink
Small matmuls IPU vertices.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 12, 2023
1 parent da44969 commit b266fb1
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 4 deletions.
35 changes: 31 additions & 4 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,27 @@
// #define ALWAYS_INLINE __attribute__((always_inline))
#define ALWAYS_INLINE inline

/**
* Tag dispatching, between IPU model and IPU hardware implementations.
*
* Making it hopefully easier to maintain IPU hardware and model
* implementations, without #ifdef/#endif preprocessor spaghetti code.
*/
namespace ipu {
/** IPU hardware tag. */
struct HardwareTag {};
/** IPU model tag. */
struct ModelTag {};
} // namespace ipu

// IPU dispatch tag preprocessor.
#ifdef __IPU__
#define IPU_DISPATCH_TAG (ipu::HardwareTag{})
#else
#define IPU_DISPATCH_TAG (ipu::ModelTag{})
#endif

#ifdef __IPU__
/**
* @brief Efficient division by 6, on IPU hardware. Up to 98,304.
*/
Expand All @@ -43,7 +62,7 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept {
R"l( uput $TAS, %[sv]
)l"
:
: [sv] "r"(v)
: [ sv ] "r"(v)
:);
}

Expand All @@ -55,7 +74,7 @@ ALWAYS_INLINE void __builtin_ipu_f32v2cmac(float2 x, float2 y) noexcept {
R"l( f32v2mac %[x], %[y]
)l"
:
: [x] "r"(x), [y] "r"(y)
: [ x ] "r"(x), [ y ] "r"(y)
:);
}

Expand All @@ -66,8 +85,8 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) {
asm volatile(
R"l( ld32 %[result], %[address], %[offset]
)l"
: [result] "=r"(result)
: [address] "r"(address), [offset] "r"(offset)
: [ result ] "=r"(result)
: [ address ] "r"(address), [ offset ] "r"(offset)
:);
return result;
}
Expand Down Expand Up @@ -171,3 +190,11 @@ T __builtin_ipu_f32v2axpy(T const& x, T const& y) {
// clang-format on

#endif

/**
* @brief Bitwise cast to a different type.
*/
template <class R, class T>
R as(T x) {
return *reinterpret_cast<R*>(&x);
}
4 changes: 4 additions & 0 deletions tessellate_ipu/core/vertex/ipu_model_types.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#pragma once

// Only defined on IPU model.
#ifndef __IPU__
#include <array>
#include <cstddef>
Expand Down Expand Up @@ -77,4 +78,7 @@ using uint4 = IpuVector<unsigned int, 4>;
using long2 = IpuVector<long, 2>;
using long4 = IpuVector<long, 4>;

// rptsize_t alias on IPU model.
using rptsize_t = uint16_t;

#endif
53 changes: 53 additions & 0 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#include "tile_small_dot.hpp"

#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

using namespace poplar;

/**
* @brief 2d rotation vertex.
*/
class Rotation2dVertex : public MultiVertex {
public:
using T = float;
using T2 = float2;
// Using `uint16` seems to be generating more efficient loops?
using IndexType = unsigned short;

static constexpr size_t MIN_ALIGN = 8;

Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
cs; // (2,) rotation cosinus/sinus values
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
inrow0; // (N,) first input row vector
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
inrow1; // (N,) second input row vector

Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>>
worker_offsets; // (7,) number threads + 1.

Output<Vector<T, poplar::VectorLayout::ONE_PTR>>
outrow0; // (N,) first input row vector
Output<Vector<T, poplar::VectorLayout::ONE_PTR>>
outrow1; // (N,) first input row vector

bool compute(unsigned wid) {
// vectorized offsets.
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;

// Vertex inputs/outputs assuring proper alignment.
const T2* inrow0_ptr = reinterpret_cast<const T2*>(inrow0.data()) + wstart;
const T2* inrow1_ptr = reinterpret_cast<const T2*>(inrow1.data()) + wstart;
const T2* cs_ptr = reinterpret_cast<const T2*>(cs.data());
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr,
wsize, IPU_DISPATCH_TAG);
return true;
}
};
81 changes: 81 additions & 0 deletions tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#include "intrinsics_utils.hpp"

/**
* @brief z = a*x + b*y float32 implementation.
*
* where x, y, z are 1D arrays and a, b are scalars.
* Implementation compatible with IPU model and hardware.
*
* Requires input arrays with size % 2 == 0
*/
inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
using T2 = float2;
const T2 av = {a, a};
const T2 bv = {b, b};
// Sub-optimal vectorized implementation.
for (unsigned idx = 0; idx < nblocks; ++idx) {
const T2 xv = ipu::load_postinc(&x, 1);
const T2 yv = ipu::load_postinc(&y, 1);
const T2 zv = av * xv + bv * yv;
ipu::store_postinc(&z, zv, 1);
}
}

void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
using T2 = float2;
const T2 av = {a, a};
// Using TAS register for one of the scalar.
__ipu_and_ipumodel_tas tas;
tas.put(b);

T2 res, xv, yv, zv, tmp;

xv = ipu::load_postinc(&x, 1);
yv = ipu::load_postinc(&y, 1);
res = xv * av;
for (unsigned idx = 0; idx != nblocks; ++idx) {
// Pseudo dual-issuing of instructions.
// popc should be able to generate an optimal rpt loop.
{
xv = ipu::load_postinc(&x, 1);
// TODO: fix ordering of arguments in `f32v2axpy`.
tmp = tas.f32v2axpy(res, yv);
}
{
yv = ipu::load_postinc(&y, 1);
// TODO: fix ordering of arguments in `f32v2axpy`.
zv = tas.f32v2axpy(tmp, tmp);
}
{
ipu::store_postinc(&z, zv, 1);
res = xv * av;
}
}
}

/**
* @brief Apply 2d rotation transform (float).
*
* Note: input rows are separated, allowing more flexibility
* for functions/vertices using this base compute method.
*/
inline void rotation2d_f32(float2 cs, const float2 *inrow0,
const float2 *inrow1, float2 *outrow0,
float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) {
axplusby_f32_v1(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks);
axplusby_f32_v1(cs[1], cs[0], inrow0, inrow1, outrow1, nblocks);
}

inline void rotation2d_f32(float2 cs, const float2 *inrow0,
const float2 *inrow1, float2 *outrow0,
float2 *outrow1, rptsize_t nblocks,
ipu::HardwareTag) {
// Using same implementation as IPU model for now.
rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks,
ipu::ModelTag{});
}
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .tile_lax_dot import IpuConvVertexType
from .tile_lax_gather import gather_p
from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p
from .tile_lax_small_dot import rotation2d_p
from .tile_lax_unary import ( # tanh_inplace_p,
abs_inplace_p,
asin_inplace_p,
Expand Down
46 changes: 46 additions & 0 deletions tessellate_ipu/lax/tile_lax_small_dot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import os
from typing import Any, Dict

import numpy as np
from jax.core import ShapedArray

from tessellate_ipu.core import declare_ipu_tile_primitive
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets


def get_small_dot_vertex_gp_filename() -> str:
return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_small_dot.cpp")


@declare_ipu_tile_primitive("Rotation2dVertex", gp_filename=get_small_dot_vertex_gp_filename())
def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray):
"""2d rotation apply primitive.
Specific optimization on IPU backend compared to `dot_general_p` primitive.
In particular, allows passing the 2 rows of the (2, N) input as separate arrays (in some
applications, contiguous storage may not be possible).
Args:
cs: Cos/sin 2d rotation entries.
inrow0: First row (N,)
inrow1: Second row (N,)
Returns:
outrow0: First output row (N,)
outrow1: Second output row (N,)
"""
N = inrow0.size
assert N % 2 == 0
assert inrow0 == inrow1
assert cs.dtype == inrow0.dtype
assert cs.dtype == inrow1.dtype
assert inrow0.dtype == np.float32

outputs = {
"outrow0": inrow0,
"outrow1": inrow1,
}
constants = {"worker_offsets": make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16)}
temps: Dict[str, Any] = {}
perf_estimate = 100
return outputs, constants, temps, perf_estimate
49 changes: 49 additions & 0 deletions tests/lax/test_tile_lax_small_dot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt
import pytest

from tessellate_ipu import ipu_cycle_count, tile_map, tile_put_replicated
from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p


@pytest.mark.ipu_hardware
class IpuTileRotation2dHwTests(chex.TestCase):
def setUp(self):
super().setUp()
np.random.seed(42)

def test__tile_map__rotation2d_primitive__proper_result_and_cycle_count(self):
N = 512
tiles = (0,)
indata = np.random.randn(2, N).astype(np.float32)
cs = np.random.randn(2).astype(np.float32)
rot2d = np.array([[cs[0], -cs[1]], [cs[1], cs[0]]]).astype(np.float32)

def compute_fn(cs, row0, row1):
cs = tile_put_replicated(cs, tiles)
row0 = tile_put_replicated(row0, tiles)
row1 = tile_put_replicated(row1, tiles)
# Benchmark the raw 2d rotation vertex.
cs, row0, row1, start = ipu_cycle_count(cs, row0, row1)
outrow0, outrow1 = tile_map(rotation2d_p, cs, row0, row1) # type:ignore
outrow0, outrow1, end = ipu_cycle_count(outrow0, outrow1)

return outrow0, outrow1, start, end

compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)
outrow0, outrow1, start, end = compute_fn_ipu(cs, indata[0], indata[1])

# Checking getting the proper result!
expected_out = rot2d @ indata
npt.assert_array_almost_equal(np.ravel(outrow0), expected_out[0], decimal=6)
npt.assert_array_almost_equal(np.ravel(outrow1), expected_out[1], decimal=6)
# Hardware cycle count bound.
start, end = np.asarray(start)[0], np.asarray(end)[0]
hw_cycle_count = end[0] - start[0]
# Observe on IPU Mk2 hw ~1916 cycles.
assert hw_cycle_count <= 2000

0 comments on commit b266fb1

Please sign in to comment.