Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 17, 2023
1 parent 5063716 commit 1ce6eee
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 53 deletions.
10 changes: 10 additions & 0 deletions tessellate_ipu/core/tile_array.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import itertools
from dataclasses import dataclass
from typing import Any, Sequence, Tuple, Union

import chex
import jax.lax
import numpy as np
from jax.core import ShapedArray
from jax.interpreters.xla import DeviceArray
Expand Down Expand Up @@ -185,6 +187,14 @@ def __getitem__(self, key: Union[SliceType, MultiSliceType]) -> "TileShardedArra
check_tile_array_multi_slice(key, self.array.shape)
return TileShardedArray(array=self.array[key], tiles=self.tiles[key[0]]) # type:ignore

@classmethod
def concatenate(cls, arrays: Sequence["TileShardedArray"]) -> "TileShardedArray":
"""Concatenate tile sharded arrays along the first axis."""
assert all([isinstance(v, TileShardedArray) for v in arrays])
outarray = jax.lax.concatenate([v.array for v in arrays], dimension=0)
outtiles = tuple(itertools.chain(*[v.tiles for v in arrays]))
return TileShardedArray(array=outarray, tiles=outtiles)


def tile_put_sharded(array: DeviceArray, tiles: Sequence[int]) -> TileShardedArray:
"""Shard a JAX array over tiles on the first axis.
Expand Down
35 changes: 26 additions & 9 deletions tessellate_ipu/core/tile_interpreter_vertex_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import math
from typing import List
from typing import List, Optional

import numpy as np
from numpy.typing import DTypeLike, NDArray
Expand All @@ -26,36 +26,53 @@ def make_num_elements_per_worker(N: int, num_workers: int) -> NDArray[np.int32]:


def make_ipu_vector1d_worker_offsets(
size: int, vector_size: int = 2, num_workers: int = 6, wdtype: DTypeLike = np.uint16
size: int,
vector_size: int = 2,
num_workers: int = 6,
wdtype: DTypeLike = np.uint16,
allow_overlap: bool = False,
grain_size: Optional[int] = None,
) -> NDArray[np.int_]:
"""Make the QR householder row update worker sizes, i.e. how many
"""Make worker sizes/offsets for a 1D array workload, i.e. how many
data vectors per worker thread?
Args:
size: Size of the vector to divide.
vector_size: Vector size (2: float, 4: half).
num_workers: Number of workers.
wdtype: Worklists dtype.
allow_overlap: Allowing overlap between workers. Make it easier to deal with remainer term.
grain_size: Optional grain size. vector_size by default.
Returns:
(6,) number of data vectors per thread.
"""
grain_size = grain_size or vector_size
grain_scale = grain_size // vector_size

def make_offsets_fn(sizes):
sizes = [0] + sizes
offsets = np.cumsum(np.array(sizes, wdtype), dtype=wdtype)
offsets = np.cumsum(np.array(sizes, wdtype) * grain_scale, dtype=wdtype)
return offsets

assert size % vector_size == 0
# TODO: support properly odd size.
assert size % 2 == 0, "Not supporting odd sizing at the moment."
# Base checks!
assert grain_size % vector_size == 0
assert size >= grain_size, f"Requires at least a size of {grain_size}."
assert (
size % grain_size == 0 or allow_overlap
), f"Requires the size, {size}, divisible by the grain size {grain_size}, (or allowing overlap)."

# Base worksize on the first few workers.
base_worksize: int = math.ceil(size / (vector_size * num_workers))
num_base_workers = size // (vector_size * base_worksize)
base_worksize: int = math.ceil(size / (grain_size * num_workers))
num_base_workers = size // (grain_size * base_worksize)
worker_sizes: List[int] = [base_worksize] * num_base_workers
if num_base_workers == num_workers:
return make_offsets_fn(worker_sizes)

# Remainer term, for the next thread.
rem_worksize = size - base_worksize * vector_size * num_base_workers
rem_worksize = rem_worksize // vector_size
rem_worksize = size - base_worksize * grain_size * num_base_workers
rem_worksize = rem_worksize // grain_size
worker_sizes += [rem_worksize]
# Fill the rest with zeros.
unused_workers = num_workers - num_base_workers - 1
Expand Down
15 changes: 15 additions & 0 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ ALWAYS_INLINE T ipu_div_by_6(T n) noexcept {
*/
ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept {
// TAS register, used for __builtin_ipu_f32v2axpy.
// TODO: use `__builtin_ipu_uput`?
asm volatile(
R"l( uput $TAS, %[sv]
)l"
Expand All @@ -72,6 +73,20 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept {
:);
}

/**
* @brief Zero AACC registers.
*/
ALWAYS_INLINE void __builtin_ipu_aacc_zero() {
asm (R"(
setzi $a0, 0x8
uput $FP_CLR, $a0
)"
:
:
: "$a0");
}


/**
* @brief IPU cmac f32 instruction.
*/
Expand Down
125 changes: 125 additions & 0 deletions tessellate_ipu/core/vertex/ipu_amp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#pragma once
#include <type_traits>

#include "intrinsics_utils.hpp"
#include "ipu_model_types.hpp"

namespace ipu {

/**
* @brief Thin abstraction of the IPU AMP unit(s) and registers, allowing
* to write generic code compiling on IPU model and IPU hardware.
*
* NOTE: zero-cost abstraction on IPU hardware.
*
* The AMP class is modelling AACC registers as well as AMP unit instructions
* on the IPU model, reproducing the expected behaviour of the hardware.
*/
template <typename T>
class AMP {
public:
// TODO: support half as well.
static_assert(std::is_same_v<T, float>);
using FPType = T;
/** Number of AACC register available in hw. */
static constexpr unsigned NumAACC = 16;

// TODO: random initialization on IPU model of registers.
AMP() noexcept = default;
// No copy + no move allowed!
AMP(const AMP&) = delete;
AMP(AMP&&) = delete;

/**
* @brief Set the value of the TAS register, used in
* `axpy` operation.
*/
ALWAYS_INLINE void tas(FPType val) noexcept {
#ifdef __IPU__
__builtin_ipu_put_tas(val);
#else
m_tas = val;
#endif
}
/**
* @brief Zero AACC registers.
*/
ALWAYS_INLINE void aaccZero() noexcept {
#ifdef __IPU__
__builtin_ipu_aacc_zero();
#else
for (unsigned idx = 0; idx < NumAACC; ++idx) {
m_aacc[idx] = 0;
}
#endif
}

/**
* @brief Scaled-add `axpy` intrinsic. Only supported on FP32.
* NOTE: act as 1 stage pipeline, storing result in AACC[0...2]
*/
ALWAYS_INLINE float2 axpy(float2 x, float2 y) noexcept {
using T2 = float2;
#ifdef __IPU__
return __builtin_ipu_f32v2axpy(x, y);
#else
// Simulating pipeline with storing in AACC[0] and AACC[2].
const auto res = T2{m_aacc[0], m_aacc[2]};
// FIXME/TODO: understand ordering!?
m_aacc[0] = m_tas * y[0] + x[0];
m_aacc[2] = m_tas * y[1] + x[1];
return res;
#endif
}

/**
* @brief Outer-product `aop` intrinsic. Only supported on FP32.
* Storing results in AACC[0...6]
*/
void aop(float2 x, float2 y) noexcept {
#ifdef __IPU__
// Note: third argument not used by hw.
__builtin_ipu_f32v2aop(x, y, 0);
#else
// Multiply + accumulate.
m_aacc[0] += x[0] * y[0];
m_aacc[2] += x[1] * y[0];
m_aacc[4] += x[0] * y[1];
m_aacc[6] += x[1] * y[1];
#endif
}

/**
* @brief `gina` instruction: get AACC register + propagate.
* FIXME: support non-zero flag/index.
*/
template <unsigned int FLAG>
float2 gina(float2 val) noexcept {
using T2 = float2;
#ifdef __IPU__
return __builtin_ipu_f32v2gina(val, 0);
#else
// TODO: implement GINA_IMMFLAGS__SET__GET
const auto res = T2{m_aacc[0], m_aacc[2]};
// Propagate accumulator states.
for (unsigned idx = 4; idx < NumAACC; idx += 4) {
m_aacc[idx - 4] = m_aacc[idx];
m_aacc[idx - 2] = m_aacc[idx + 2];
}
m_aacc[NumAACC - 4] = val[0];
m_aacc[NumAACC - 2] = val[1];
return res;
#endif
}

private:
#ifndef __IPU__
// Simulating AACC registers on IPU model.
FPType m_aacc[NumAACC];
// Simulating TAS register on IPU model.
FPType m_tas;
#endif
};

} // namespace ipu
Loading

0 comments on commit 1ce6eee

Please sign in to comment.