[WIP] Hessenberg: Use fori_loop() instead of python for-loop (#43)
* Using fori_loop() instead of python for-loop
* MIN_ALIGN on all vertex Vectors fixes crash
* now up to size 1472, experimental corr vec vertex
* Using gather_p primitive
* Tile sharding handles size > 1472
* tile mapping for N<736 and clean-up
* Docstring fix
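The headline change is host-side and is not part of the vertex file shown in this diff: the per-column Hessenberg step is driven by jax.lax.fori_loop instead of a Python for-loop, so the body is traced once rather than unrolled N times. A minimal, self-contained JAX sketch of that pattern, using plain jax.numpy in place of the repository's IPU vertices (the names hessenberg and _hessenberg_step are illustrative, not taken from this commit):

import jax
import jax.numpy as jnp

def _hessenberg_step(k, carry):
    # One Householder step: zero A[k+2:, k] by reflecting rows k+1..n-1.
    A, Q = carry
    n = A.shape[0]
    rows = jnp.arange(n)
    x = jnp.where(rows > k, A[:, k], 0.0)        # sub-column to reduce
    sign = jnp.where(x[k + 1] >= 0, 1.0, -1.0)   # numerically stable sign choice
    v = x.at[k + 1].add(sign * jnp.linalg.norm(x))  # correction vector
    vv = v @ v
    vrescale = jnp.where(vv > 0, 2.0 / vv, 0.0)  # cf. `vrescale` = 2 / ||v||^2 below
    # Two-sided Householder update: A <- H A H and Q <- Q H, with H = I - vrescale * v v^T.
    A = A - vrescale * jnp.outer(v, v @ A)
    A = A - vrescale * jnp.outer(A @ v, v)
    Q = Q - vrescale * jnp.outer(Q @ v, v)
    return A, Q

def hessenberg(A):
    n = A.shape[0]
    # fori_loop keeps a single traced body instead of n - 2 unrolled copies.
    return jax.lax.fori_loop(0, n - 2, _hessenberg_step,
                             (A, jnp.eye(n, dtype=A.dtype)))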
Showing 3 changed files with 411 additions and 68 deletions.
@@ -0,0 +1,204 @@
// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#include <cmath>  // std::sqrt
#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"

using namespace poplar;

/* popc -O2 -I tessellate/tile/vertex\
    tessellate/tile/vertex/tile_qr_vertex.cpp \
    -o tessellate/tile/vertex/tile_qr_vertex.gp
*/
static constexpr size_t MIN_ALIGN = 8;

/*
The code here is just a minor modification of tile_qr_vertex.cpp
*/

/**
 * @brief Vertex computing the correction vector in the Hessenberg algorithm.
 */
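// Summary sketch of the computation below (with k = cidx[0]):
//   v[0:k] = 0,   v[k:] = Rcol[k:],   then   v[k] = Rcol[k] - sdiag[k] * ||Rcol[k:]||
//   vrescale[0] = 2 / ||v||^2
// The row-update vertex further below then applies
//   x[i] -= scale1[0] * scale2[0] * v[i].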
class HessenbergCorrectionVectorVertex : public MultiVertex {
 public:
  using T = float;
  Input<Vector<T, poplar::VectorLayout::SPAN, MIN_ALIGN>> Rcol;      // (N,) R column.
  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>> sdiag;  // (N,) R diag. sign.
  Input<Vector<int, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      cidx;  // (1,) column index.

  Output<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      v;  // (N,) QR correction vector (not normalized)
  Output<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      vrescale;  // (1,) QR correction vector rescaling (2 / norm)

  // Use static variables as easy sync. mechanisms between worker threads.
  // see: https://graphcore.slack.com/archives/C013LPHPX61/p1647443852989259
  static T shared_partial_sqnorms[6];

  HessenbergCorrectionVectorVertex();

  bool compute(unsigned wid) {
    const unsigned num_workers = 6;
    const unsigned step = num_workers * 2;
    const unsigned size = Rcol.size();

    const unsigned col_idx = cidx[0];

    const unsigned col_idx_rem = col_idx % 2;
    const unsigned col_idx_mrem = col_idx - col_idx_rem;
    const unsigned col_idx_prem = col_idx + col_idx_rem;

    const T initial_rcol_val = Rcol[col_idx];
    const T initial_rcol_val_sq = initial_rcol_val * initial_rcol_val;

    // FIRST: SET SHARED STATE between workers.
    shared_partial_sqnorms[wid] = -1;

    const float2 zeros_f2{0, 0};
    float2* ptr_outdata_f2 = reinterpret_cast<float2*>(v.data()) + wid;
    // Push to col_idx_prem, may write one zero too much, but does not matter!
    float2* ptr_outdata_end_f2 = reinterpret_cast<float2*>(&v[col_idx_prem]);
    // First chunk of v initialized with zeros.
    while (ptr_outdata_f2 < ptr_outdata_end_f2) {
      ipu::store_postinc(&ptr_outdata_f2, zeros_f2, num_workers);
    }

    float2 partials_f2{0, 0};
    const float2* ptr_indata_f2 =
        reinterpret_cast<const float2*>(&Rcol[col_idx_prem]) + wid;
    ptr_outdata_f2 = reinterpret_cast<float2*>(&v[col_idx_prem]) + wid;
    ptr_outdata_end_f2 = reinterpret_cast<float2*>(&v[size]);
    // Copy Rcol data and accumulate squared norm.
    while (ptr_outdata_f2 < ptr_outdata_end_f2) {
      const float2 v = ipu::load_postinc(&ptr_indata_f2, num_workers);
      partials_f2 += v * v;
      ipu::store_postinc(&ptr_outdata_f2, v, num_workers);
    }
    T partial = partials_f2[0] + partials_f2[1];

    // GLOBAL STATE shared by all workers.
    shared_partial_sqnorms[wid] = partial;

    if (wid == 0) {
      // Special case of odd R column index.
      // On thread 0: correction to the squared norm depending on `col_idx_rem`.
      T norm_squared = partial + col_idx_rem * initial_rcol_val_sq;
      // Accumulate & wait.
      for (unsigned w = 1; w < num_workers; ++w) {
        // Volatile pointer prevents the compiler from optimizing away the spin-wait.
        volatile T* ptr_partial = &shared_partial_sqnorms[w];
        while (*ptr_partial < 0) {
        }
        norm_squared += shared_partial_sqnorms[w];
      }

      // Compute the norm.
      const T norm = std::sqrt(norm_squared);
      // Change the entry of v that corresponds to the diagonal element of R.
      const auto update_vidx_val = initial_rcol_val - norm * sdiag[col_idx];
      // Re-writing the full new value is faster than updating.
      v[col_idx] = update_vidx_val;

      // Update the squared norm of v.
      norm_squared -= initial_rcol_val_sq;
      norm_squared += update_vidx_val * update_vidx_val;

      // Vector rescaling for QR householder update.
      vrescale[0] = T(2) / norm_squared;
    }
    return true;
  }
};

float HessenbergCorrectionVectorVertex::shared_partial_sqnorms[6] = {-1};

/**
 * @brief Vertex implementing the inplace householder (row) update in the QR
 * algorithm. NOTE: the vertex is only updating the sub-slice of x corresponding
 * to v.
 *
 * More specifically: x[end-len(v)+i] -= scale1[0] * scale2[0] * v[i]
 *
 * NOTE: poplar::constraint here to make sure x and v are not part of the same
 * memory bank, allowing simultaneous loads (see `ld2x64pace` instruction).
 */
class [[poplar::constraint(
    "elem(*x) != elem(*v)")]] HessenbergHouseholderRowUpdateVertex
    : public MultiVertex {
 public:
  using T = float;
  // Using `uint16` seems to be generating more efficient loops?
  using IndexType = unsigned short;

  InOut<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> x;  // (N,) row of Q or R
  Input<Vector<T, poplar::VectorLayout::SPAN, 8>>
      v;  // (M,) v correction vector

  // Passing 2 scaling factors is more efficient for the QR implementation.
  // It avoids another full pass over the v vector in the vertex where it is constructed.
  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      scale1;  // (1,) first scaling factor.
  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      scale2;  // (1,) 2nd scaling factor.
  Input<Vector<int, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      start_idx_;  // (1,) start index of the slice of x updated by v.

  Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      worker_offsets;  // (7,) per-worker float2 offsets (num_workers + 1 entries).

  bool compute(unsigned wid) {
    // Always assuming size % 2 == 0
    constexpr unsigned ptr_step = 1;
    const IndexType wstart = worker_offsets[wid];
    const IndexType wend = worker_offsets[wid + 1];
    const IndexType wsize = wend - wstart;

    // X start idx. Must be a multiple of 4 (for bank alignment aspects).
    IndexType start_idx = start_idx_[0];

    // Set the $TAS register with the proper scale.
    const T s = -scale1[0] * scale2[0];
    // __builtin_ipu_put_tas(s);
    __ipu_and_ipumodel_tas tas;
    tas.put(s);

    // Nothing to do in this worker thread.
    if (wstart == wend) {
      return true;
    }
    // X and v IO pointers.
    const float2* ptr_inxdata_f2 =
        reinterpret_cast<const float2*>(&x[start_idx]) + wstart;
    float2* ptr_outxdata_f2 = reinterpret_cast<float2*>(&x[start_idx]) + wstart;
    const float2* ptr_vdata_f2 =
        reinterpret_cast<const float2*>(&v[0]) + wstart;

    float2 xin, vin, rtmp, rout;
    // First vector loads.
    xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step);
    vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
    // TODO: use ld2x64pace + tapack instructions.
    for (IndexType idx = 1; idx != wsize; ++idx) {
      rtmp = tas.f32v2axpy(xin, vin);
      // rtmp = __builtin_ipu_f32v2axpy(xin, vin);
      // Grouping here seems to help the compiler optimising loads?
      xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step);
      vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
      rout = tas.f32v2axpy(rtmp, rtmp);
      // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
      ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);
    }
    // Finish the loop, getting the last computation.
    // rtmp = __builtin_ipu_f32v2axpy(xin, vin);
    // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
    rtmp = tas.f32v2axpy(xin, vin);
    rout = tas.f32v2axpy(rtmp, rtmp);
    ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);

    return true;
  }
};
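For reference, worker_offsets above is expected to be prepared on the host: seven uint16 entries in float2 units, one start offset per worker plus the end offset. A hypothetical helper showing one way such offsets could be built; the function name and the even-split policy are assumptions, not taken from this commit:

import numpy as np

def make_worker_offsets(num_floats, num_workers=6):
    # Split `num_floats` float elements (assumed even, i.e. num_floats // 2
    # float2 pairs) across worker threads, returning num_workers + 1
    # cumulative offsets in float2 units, dtype uint16 to match IndexType.
    num_f2 = num_floats // 2
    base, rem = divmod(num_f2, num_workers)
    sizes = [base + (1 if w < rem else 0) for w in range(num_workers)]
    return np.cumsum([0] + sizes).astype(np.uint16)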