Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Hessenberg: Use fori_loop() instead of python for-loop #43

Merged
merged 7 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#include <cmath>

#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"

using namespace poplar;

/* popc -O2 -I tessellate/tile/vertex\
    tessellate/tile/vertex/tile_hessenberg_vertex.cpp \
    -o tessellate/tile/vertex/tile_hessenberg_vertex.gp
*/
// Minimum alignment (bytes) of vertex IO vectors: 8 bytes so the buffers can
// be safely reinterpreted as `float2` for vectorized loads/stores below.
static constexpr size_t MIN_ALIGN = 8;

/*
The code here is just a minor modification of tile_qr_vertex.cpp
*/

/**
* @brief Vertex computing the correction vector in the Hessenberg algorithm.
*/
class HessenbergCorrectionVectorVertex : public MultiVertex {
public:
using T = float;
Input<Vector<T, poplar::VectorLayout::SPAN, MIN_ALIGN>> Rcol; // (N,) R column.
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>> sdiag; // (N,) R diag. sign.
Input<Vector<int, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
cidx;

Output<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
v; // (N,) QR correction vector (not normalized)
Output<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
vrescale; // (1,) QR correction vector rescaling (2 / norm)


// Use static variables as easy sync. mechanisms between worker threads.
// see: https://graphcore.slack.com/archives/C013LPHPX61/p1647443852989259
static T shared_partial_sqnorms[6];

HessenbergCorrectionVectorVertex();

bool compute(unsigned wid) {
const unsigned num_workers = 6;
const unsigned step = num_workers * 2;
const unsigned size = Rcol.size();

const unsigned col_idx = cidx[0];

const unsigned col_idx_rem = col_idx % 2;
const unsigned col_idx_mrem = col_idx - col_idx_rem;
const unsigned col_idx_prem = col_idx + col_idx_rem;

const T initial_rcol_val = Rcol[col_idx];
const T initial_rcol_val_sq = initial_rcol_val * initial_rcol_val;

// FIRST: SET SHARED STATE between workers.
shared_partial_sqnorms[wid] = -1;

const float2 zeros_f2{0, 0};
float2* ptr_outdata_f2 = reinterpret_cast<float2*>(v.data()) + wid;
// Push to col_idx_prem, may write one zero too much, but does not matter!
float2* ptr_outdata_end_f2 = reinterpret_cast<float2*>(&v[col_idx_prem]);
// First chunk of v initialized with zeros.
while (ptr_outdata_f2 < ptr_outdata_end_f2) {
ipu::store_postinc(&ptr_outdata_f2, zeros_f2, num_workers);
}

float2 partials_f2{0, 0};
const float2* ptr_indata_f2 =
reinterpret_cast<const float2*>(&Rcol[col_idx_prem]) + wid;
ptr_outdata_f2 = reinterpret_cast<float2*>(&v[col_idx_prem]) + wid;
ptr_outdata_end_f2 = reinterpret_cast<float2*>(&v[size]);
// Copy Rcol data and accumulate squared norm.
while (ptr_outdata_f2 < ptr_outdata_end_f2) {
const float2 v = ipu::load_postinc(&ptr_indata_f2, num_workers);
partials_f2 += v * v;
ipu::store_postinc(&ptr_outdata_f2, v, num_workers);
}
T partial = partials_f2[0] + partials_f2[1];

// GLOBAL STATE shared by all workers.
shared_partial_sqnorms[wid] = partial;

if (wid == 0) {
// Special case of odd R column index.
// On thread 0: correction to squared normed depending on `col_idx_rem`
T norm_squared = partial + col_idx_rem * initial_rcol_val_sq;
// Accumulate & wait.
for (unsigned w = 1; w < num_workers; ++w) {
// Avoid compiler optimizer with volatile pointer.
volatile T* ptr_partial = &shared_partial_sqnorms[w];
while (*ptr_partial < 0) {
}
norm_squared += shared_partial_sqnorms[w];
}

// Compute the norm.
const T norm = std::sqrt(norm_squared);
// Change the entry of v that corresponds to the diagonal element of R.
const auto update_vidx_val = initial_rcol_val - norm * sdiag[col_idx];
// Re-writing the full new value is faster than updating.
v[col_idx] = update_vidx_val;

// Update the squared norm of v.
norm_squared -= initial_rcol_val_sq;
norm_squared += update_vidx_val * update_vidx_val;

// Vector rescaling for QR householder update.
vrescale[0] = T(2) / norm_squared;
}
return true;
}
};

// Definition of the per-worker shared buffer (one slot per worker thread).
// NOTE(review): `{-1}` arms only slot 0 with the -1 sentinel; slots 1..5 are
// zero-initialized. Each worker re-arms its own slot to -1 at the start of
// compute(), which worker 0's busy-wait relies on — presumably all workers
// are launched together before worker 0 reaches its wait loop; confirm.
float HessenbergCorrectionVectorVertex::shared_partial_sqnorms[6] = {-1};

/**
 * @brief Vertex implementing the inplace householder (row) update in the QR
 * algorithm. NOTE: the vertex is only updating the sub-slice of x corresponding
 * to v.
 *
 * More specifically: x[end-len(v)+i] -= scale1[0] * scale2[0] * v[i]
 * (in this Hessenberg variant the offset into x is taken from start_idx_[0]).
 *
 * NOTE: poplar::constraint here to make sure x and v are not part of the same
 * memory bank, allowing simultaneous loads (see `ld2x64pace` instruction).
 */
class [[poplar::constraint(
    "elem(*x) != elem(*v)")]] HessenbergHouseholderRowUpdateVertex
    : public MultiVertex {
 public:
  using T = float;
  // Using `uint16` seems to be generating more efficient loops?
  using IndexType = unsigned short;

  InOut<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> x;  // (N,) row of Q or R
  Input<Vector<T, poplar::VectorLayout::SPAN, 8>>
      v;  // (M,) v correction vector

  // Passing 2 scaling factors is more efficient for the QR implementation.
  // Avoids another full pass on the v vector in the vertex it is constructed.
  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      scale1;  // (1,) first scaling factor.
  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      scale2;  // (1,) 2nd scaling factor.
  Input<Vector<int, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      start_idx_;  // Start offset into x; only start_idx_[0] is read.

  Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      worker_offsets;  // (7,) threads work size + 1.

  bool compute(unsigned wid) {
    // Always assuming size % 2 == 0
    constexpr unsigned ptr_step = 1;
    // Worker `wid` processes the float2 elements [wstart, wend) of v.
    const IndexType wstart = worker_offsets[wid];
    const IndexType wend = worker_offsets[wid + 1];
    const IndexType wsize = wend - wstart;

    IndexType start_idx = start_idx_[0];  // X start idx. Must be a multiple of 4 (for bank
                                          // alignment aspects).

    // Set the $TAS register with the proper scale.
    // Negated so the axpy accumulation subtracts the scaled v from x.
    const T s = -scale1[0] * scale2[0];
    // __builtin_ipu_put_tas(s);
    __ipu_and_ipumodel_tas tas;
    tas.put(s);

    // Nothing to do in this worker thread.
    if (wstart == wend) {
      return true;
    }
    // X and v IO pointers.
    const float2* ptr_inxdata_f2 =
        reinterpret_cast<const float2*>(&x[start_idx]) + wstart;
    float2* ptr_outxdata_f2 = reinterpret_cast<float2*>(&x[start_idx]) + wstart;
    const float2* ptr_vdata_f2 =
        reinterpret_cast<const float2*>(&v[0]) + wstart;

    float2 xin, vin, rtmp, rout;
    // First vectors loading.
    xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step);
    vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
    // Software-pipelined loop: each iteration issues two f32v2axpy calls —
    // presumably the second call drains the result of the first from the TAS
    // accumulation pipeline (mirroring the HW instruction latency; see
    // intrinsics_utils.hpp — TODO confirm). Do not reorder these statements.
    // TODO: use ld2x64pace + tapack instructions.
    for (IndexType idx = 1; idx != wsize; ++idx) {
      rtmp = tas.f32v2axpy(xin, vin);
      // rtmp = __builtin_ipu_f32v2axpy(xin, vin);
      // Grouping here seems to help the compiler optimising loads?
      xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step);
      vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
      rout = tas.f32v2axpy(rtmp, rtmp);
      // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
      ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);
    }
    // Finish the loop, getting the last computation (loop epilogue: the final
    // pair of elements loaded above still has to flow through the pipeline).
    // rtmp = __builtin_ipu_f32v2axpy(xin, vin);
    // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
    rtmp = tas.f32v2axpy(xin, vin);
    rout = tas.f32v2axpy(rtmp, rtmp);
    ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);

    return true;
  }
};
Loading