Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 11, 2023
1 parent 702c51d commit faa8eea
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 29 deletions.
27 changes: 16 additions & 11 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,24 @@
/**
* Tag dispatching, between IPU model and IPU hardware implementations.
*
* Making it hopefully easier to maintain IPU hardware and model implementations,
* without #ifdef/#endif preprocessor spaghetti code.
* Making it hopefully easier to maintain IPU hardware and model
* implementations, without #ifdef/#endif preprocessor spaghetti code.
*/
namespace ipu
{
namespace ipu {
/** IPU hardware tag. */
struct HardwareTag{};
struct HardwareTag {};
/** IPU model tag. */
struct ModelTag{};
}
struct ModelTag {};
} // namespace ipu

// Compile-time selection of the tag value used for tag dispatching:
// `IPU_DISPATCH_TAG` expands to an `ipu::HardwareTag` temporary when `__IPU__`
// is defined, and to an `ipu::ModelTag` temporary otherwise.
#if defined(__IPU__)
#define IPU_DISPATCH_TAG (ipu::HardwareTag{})
#else  // !defined(__IPU__)
#define IPU_DISPATCH_TAG (ipu::ModelTag{})
#endif  // defined(__IPU__)

#ifdef __IPU__
/**
* @brief Efficient division by 6, on IPU hardware. Up to 98,304.
*/
Expand All @@ -57,7 +62,7 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept {
R"l( uput $TAS, %[sv]
)l"
:
: [sv] "r"(v)
: [ sv ] "r"(v)
:);
}

Expand All @@ -69,7 +74,7 @@ ALWAYS_INLINE void __builtin_ipu_f32v2cmac(float2 x, float2 y) noexcept {
R"l( f32v2mac %[x], %[y]
)l"
:
: [x] "r"(x), [y] "r"(y)
: [ x ] "r"(x), [ y ] "r"(y)
:);
}

Expand All @@ -80,8 +85,8 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) {
asm volatile(
R"l( ld32 %[result], %[address], %[offset]
)l"
: [result] "=r"(result)
: [address] "r"(address), [offset] "r"(offset)
: [ result ] "=r"(result)
: [ address ] "r"(address), [ offset ] "r"(offset)
:);
return result;
}
Expand Down
4 changes: 4 additions & 0 deletions tessellate_ipu/core/vertex/ipu_model_types.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#pragma once

// Only defined on IPU model.
#ifndef __IPU__
#include <array>
#include <cstddef>
Expand Down Expand Up @@ -77,4 +78,7 @@ using uint4 = IpuVector<unsigned int, 4>;
using long2 = IpuVector<long, 2>;
using long4 = IpuVector<long, 4>;

// rptsize_t alias: 16-bit unsigned integer type. Only declared in this header
// when `__IPU__` is not defined (IPU model build).
// NOTE(review): presumably mirrors a type supplied by the IPU hardware
// toolchain - confirm. Also confirm <cstdint> is included above for uint16_t.
using rptsize_t = uint16_t;

#endif
12 changes: 11 additions & 1 deletion tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

using namespace poplar;

// class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dVertex
// : public MultiVertex {
class Rotation2dVertex : public MultiVertex {
Expand All @@ -22,7 +24,6 @@ class Rotation2dVertex : public MultiVertex {
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
inrow1; // (N,) second input row vector


Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>>
worker_offsets; // (7,) number threads + 1.

Expand All @@ -37,6 +38,15 @@ class Rotation2dVertex : public MultiVertex {
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;

const T2* inrow0_ptr = reinterpret_cast<const T2*>(inrow0.data()) + wstart;
const T2* inrow1_ptr = reinterpret_cast<const T2*>(inrow1.data()) + wstart;
const T2* cs_ptr = reinterpret_cast<const T2*>(cs.data());

T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

rotation2_float(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr,
wsize, IPU_DISPATCH_TAG);

return true;
}
Expand Down
47 changes: 35 additions & 12 deletions tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,38 @@ R as(T x) {
return *reinterpret_cast<R *>(&x);
}


/**
 * @brief 2d rotation overload selected by `ipu::HardwareTag`, using the
 * signature with the row pointers first and `cs` as third argument.
 *
 * The body is empty: calling this overload performs no work.
 */
inline void rotation2_float(const float2 *inrow0, const float2 *inrow1,
float2 cs, float2 *outrow0, float2 *outrow1,
rptsize_t nblocks, ipu::HardwareTag) {}


/**
 * @brief 2d rotation overload selected by `ipu::ModelTag`, using the
 * signature with the row pointers first and `cs` as third argument.
 *
 * The body is empty: calling this overload performs no work.
 */
inline void rotation2_float(const float2 *inrow0, const float2 *inrow1,
float2 cs, float2 *outrow0, float2 *outrow1,
rptsize_t nblocks, ipu::ModelTag) {


}
/**
 * @brief Apply a 2d rotation to two rows of float2 blocks (IPU hardware
 * overload, selected by `ipu::HardwareTag`).
 *
 * For each of the `nblocks` iterations, one float2 is read from each input
 * row and the pair is combined as
 *   out0 = cs[0] * in0 - cs[1] * in1
 *   out1 = cs[1] * in0 + cs[0] * in1
 * NOTE(review): assumes `ipu::load_postinc` / `ipu::store_postinc` access the
 * current element and then advance the pointer by the given step - confirm.
 *
 * @param cs Pair of rotation coefficients (presumably cos/sin - confirm).
 * @param inrow0 First input row.
 * @param inrow1 Second input row.
 * @param outrow0 First output row.
 * @param outrow1 Second output row.
 * @param nblocks Number of float2 blocks to process.
 */
inline void rotation2_float(float2 cs, const float2 *inrow0,
                            const float2 *inrow1, float2 *outrow0,
                            float2 *outrow1, rptsize_t nblocks,
                            ipu::HardwareTag) {
  for (rptsize_t block = 0; block != nblocks; ++block) {
    const float2 row0 = ipu::load_postinc(&inrow0, 1);
    const float2 row1 = ipu::load_postinc(&inrow1, 1);
    // Scalar coefficients are applied directly to the float2 operands.
    const float2 rotated0 = cs[0] * row0 - cs[1] * row1;
    const float2 rotated1 = cs[1] * row0 + cs[0] * row1;
    ipu::store_postinc(&outrow0, rotated0, 1);
    ipu::store_postinc(&outrow1, rotated1, 1);
  }
}
/**
 * @brief Apply a 2d rotation to two rows of float2 blocks (IPU model
 * overload, selected by `ipu::ModelTag`).
 *
 * Same arithmetic as the hardware overload, but each coefficient of `cs` is
 * first duplicated into both lanes of a float2 so that only vector-vector
 * products are used.
 * NOTE(review): presumably because scalar * float2 arithmetic is not
 * available with the IPU model vector types - confirm.
 *
 * @param cs Pair of rotation coefficients (presumably cos/sin - confirm).
 * @param inrow0 First input row.
 * @param inrow1 Second input row.
 * @param outrow0 First output row.
 * @param outrow1 Second output row.
 * @param nblocks Number of float2 blocks to process.
 */
inline void rotation2_float(float2 cs, const float2 *inrow0,
                            const float2 *inrow1, float2 *outrow0,
                            float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) {
  // Broadcast each coefficient across both vector lanes.
  const float2 c_bcast = {cs[0], cs[0]};
  const float2 s_bcast = {cs[1], cs[1]};
  for (rptsize_t block = 0; block != nblocks; ++block) {
    const float2 row0 = ipu::load_postinc(&inrow0, 1);
    const float2 row1 = ipu::load_postinc(&inrow1, 1);
    // Lane-wise 2d rotation of the (row0, row1) pair.
    const float2 rotated0 = c_bcast * row0 - s_bcast * row1;
    const float2 rotated1 = s_bcast * row0 + c_bcast * row1;
    ipu::store_postinc(&outrow0, rotated0, 1);
    ipu::store_postinc(&outrow1, rotated1, 1);
  }
}
5 changes: 2 additions & 3 deletions tessellate_ipu/lax/tile_lax_small_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray):
outrow1: Second output row (N,)
"""
N = inrow0.size
assert N % 4 == 0
assert inrow0 == inrow1
assert cs.dtype == inrow0.dtype
assert cs.dtype == inrow1.dtype
Expand All @@ -41,9 +42,7 @@ def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray):
"outrow0": inrow0,
"outrow1": inrow1,
}
constants = {
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16)
}
constants = {"worker_offsets": make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16)}
temps: Dict[str, Any] = {}
perf_estimate = 100
return outputs, constants, temps, perf_estimate
39 changes: 37 additions & 2 deletions tests/lax/test_tile_lax_small_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,43 @@
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
import pytest

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu import tile_map, tile_put_replicated
from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p


@pytest.mark.ipu_hardware
class IpuTileRotation2dHwTests(chex.TestCase):
    """Tests for the `rotation2d_p` tile primitive, marked as requiring IPU hardware."""

    def setUp(self):
        super().setUp()
        # Fixed seed so the random inputs are reproducible between runs.
        np.random.seed(42)

    def test__tile_map__rotation2d_primitive__proper_result(self):
        """Compare `rotation2d_p` on a single tile with a NumPy 2x2 matrix product."""
        N = 8
        tiles = (0,)
        indata = np.random.randn(2, N).astype(np.float32)
        cs = np.random.randn(2).astype(np.float32)
        # Reference 2x2 matrix built from the two `cs` coefficients.
        rot2d = np.array([[cs[0], -cs[1]], [cs[1], cs[0]]]).astype(np.float32)

        def compute_fn(cs, row0, row1):
            cs = tile_put_replicated(cs, tiles)
            row0 = tile_put_replicated(row0, tiles)
            row1 = tile_put_replicated(row1, tiles)
            return tile_map(rotation2d_p, cs, row0, row1)

        compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)

        outrow0, outrow1 = compute_fn_ipu(cs, indata[0], indata[1])
        expected_out = rot2d @ indata

        npt.assert_array_almost_equal(np.ravel(outrow0), expected_out[0], decimal=5)
        npt.assert_array_almost_equal(np.ravel(outrow1), expected_out[1], decimal=5)

0 comments on commit faa8eea

Please sign in to comment.