Skip to content

Commit

Permalink
IPU Jacobi eigh fori_loop transition
Browse files Browse the repository at this point in the history
This PR switches the IPU Jacobi `eigh` implementation from a
Python loop to a JAX `fori_loop`, massively reducing code size and
allowing eigendecomposition of larger matrices.

Note: in terms of performance, some issues still remain, as the Poplar
compiler adds some extra on-tile copies which could, in theory,
be elided.
  • Loading branch information
balancap committed Oct 5, 2023
1 parent d5b7c46 commit 736e9b0
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 197 deletions.
264 changes: 165 additions & 99 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,59 @@ class JacobiSymSchur2 : public Vertex {
}
};

template <typename T>
void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
T* qcol_updated, T* cs, unsigned p, unsigned q,
unsigned short wstart,
unsigned short wend) noexcept {
using T2 = float2;
using IndexType = unsigned short;

const T Apq = pcol[q];
const T App = pcol[p];
const T Aqq = qcol[q];

// Schur2 decomposition.
const T2 cs_vec = sym_schur2(App, Aqq, Apq);
const T& c = cs_vec[0];
const T& s = cs_vec[1];
cs[0] = c;
cs[1] = s;

// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const float2* ptr_pcol = reinterpret_cast<const float2*>(pcol) + wstart;
const float2* ptr_qcol = reinterpret_cast<const float2*>(qcol) + wstart;
float2* ptr_pcol_updated = reinterpret_cast<float2*>(pcol_updated) + wstart;
float2* ptr_qcol_updated = reinterpret_cast<float2*>(qcol_updated) + wstart;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};

// Easier to vectorized + parallelize if start with normal update first.
for (IndexType idx = 0; idx != wsize; ++idx) {
// TODO: investigate assembly?
const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 pvec_updated = cvec * pvec - svec * qvec;
const T2 qvec_updated = svec * pvec + cvec * qvec;

ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
}

// Update main values App, Apq, Aqq
pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
// Zero on purpose with Schur decomposition!
pcol_updated[q] = 0;
qcol_updated[p] = 0;
}

/**
* @brief Jacobi algorithm, update first step: schur + column update.
*
Expand All @@ -92,78 +145,55 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
// Using `uint16` seems to be generating more efficient loops?
using IndexType = unsigned short;

Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
rotset; // (2,) rotation index p and q. p < q
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> pcol; // (N,) p column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> qcol; // (N,) q column
// p/q cols + index prefix (2 x uint32).
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> pcol; // (N + 2,) p column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> qcol; // (N + 2,) q column

Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>>
worker_offsets; // (7,) threads work size + 1.

Output<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
rotset_sorted; // (3,) rotset index sorted + was sorted?
Output<Vector<T, poplar::VectorLayout::ONE_PTR, 8>>
cs; // (2,) (c, s) Schur decomposition values

Output<Vector<T, poplar::VectorLayout::ONE_PTR, 8>>
pcol_updated; // (N,) p column updated
pcol_updated; // (N + 2,) p column updated
Output<Vector<T, poplar::VectorLayout::ONE_PTR, 8>>
qcol_updated; // (N,) q column updated

const IndexType N; // size
qcol_updated; // (N + 2,) q column updated

JacobiUpdateFirstStep();

// Per-worker entry point: `wid` selects this worker's [wstart, wend) slice
// through `worker_offsets[wid]` and `worker_offsets[wid + 1]`.
//
// NOTE(review): this listing is rendered from a diff without +/- markers, so
// superseded and replacement statements are interleaved (e.g. `p` and `q` are
// declared twice below and the braces do not balance). Presumably the
// `rotset`-based reads and the inline Schur/loop/main-value code are the
// removed version, replaced by the calls to `jacob_update_first_step` --
// confirm against the repository before relying on this text.
bool compute(unsigned wid) {
const unsigned p = rotset[0];
const unsigned q = rotset[1];
const T Apq = pcol[q];
const T App = pcol[p];
const T Aqq = qcol[q];
// Size of the index prefix in pcol and qcol.
constexpr int INDEX_PREFIX = 2;
// p and q are read from the first 32-bit word of each column buffer,
// reinterpreted as `unsigned`.
const unsigned p = *((unsigned*)pcol.data());
const unsigned q = *((unsigned*)qcol.data());

// Schur2 decomposition.
const T2 cs_vec = sym_schur2(App, Aqq, Apq);
const T& c = cs_vec[0];
const T& s = cs_vec[1];
cs[0] = c;
cs[1] = s;

// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const float2* ptr_pcol =
reinterpret_cast<const float2*>(pcol.data()) + wstart;
const float2* ptr_qcol =
reinterpret_cast<const float2*>(qcol.data()) + wstart;
float2* ptr_pcol_updated =
reinterpret_cast<float2*>(pcol_updated.data()) + wstart;
float2* ptr_qcol_updated =
reinterpret_cast<float2*>(qcol_updated.data()) + wstart;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};

// Easier to vectorize + parallelize if we start with the normal update first.
for (IndexType idx = 0; idx != wsize; ++idx) {
// TODO: investigate assembly?
const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 pvec_updated = cvec * pvec - svec * qvec;
const T2 qvec_updated = svec * pvec + cvec * qvec;

ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
// Forward p/q indices.
// NOTE(review): only element 0 of the INDEX_PREFIX-sized prefix is copied;
// element 1 is presumably padding -- confirm.
pcol_updated[0] = pcol[0];
qcol_updated[0] = qcol[0];

if (p <= q) {
// Proper ordering of p and q already.
// All column pointers skip the index prefix before being handed over.
jacob_update_first_step(
pcol.data() + INDEX_PREFIX, qcol.data() + INDEX_PREFIX,
pcol_updated.data() + INDEX_PREFIX,
qcol_updated.data() + INDEX_PREFIX, cs.data(), p, q, wstart, wend);
rotset_sorted[0] = p;
rotset_sorted[1] = q;
} else {
// Swap p and q columns as q < p
// Inputs, outputs and indices are all passed in swapped order, so the
// helper receives the smaller index first.
jacob_update_first_step(
qcol.data() + INDEX_PREFIX, pcol.data() + INDEX_PREFIX,
qcol_updated.data() + INDEX_PREFIX,
pcol_updated.data() + INDEX_PREFIX, cs.data(), q, p, wstart, wend);
rotset_sorted[0] = q;
rotset_sorted[1] = p;
}

// Update main values App, Apq, Aqq
pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
// Zero on purpose with Schur decomposition!
pcol_updated[q] = 0;
qcol_updated[p] = 0;
return true;
}
};
Expand All @@ -178,65 +208,104 @@ class JacobiUpdateSecondStep : public MultiVertex {
InOut<Vector<T, poplar::VectorLayout::ONE_PTR, 8>>
cs_arr; // (N/2, 2) (c, s) values
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
rotset_arr; // (N/2, 2) (p, q) array values. p < q
rotset_sorted_arr; // (N/2, 2) (p, q) array values. p < q
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
rotset_idx_ignored; // (1,) index in rotset to ignore.

Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>>
worker_offsets; // (7,) threads work size + 1.

Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> pcol; // (N,) p column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> qcol; // (N,) q column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> pcol; // (N+2,) p column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> qcol; // (N+2,) q column

Output<Vector<T, poplar::VectorLayout::ONE_PTR, 8>>
pcol_updated; // (N,) p column updated
pcol_updated; // (N+2,) p column updated
Output<Vector<T, poplar::VectorLayout::ONE_PTR, 8>>
qcol_updated; // (N,) q column updated

// const unsigned ignore_idx; // cs/pq index to ignore.
const IndexType halfN; // N / 2
qcol_updated; // (N+2,) q column updated

JacobiUpdateSecondStep();

// Per-worker entry point: applies the (c, s) rotations of a slice of the
// rotation set to the p and q columns. `wid` selects the slice through
// `worker_offsets[wid]` and `worker_offsets[wid + 1]`.
//
// NOTE(review): this listing is rendered from a diff without +/- markers, so
// superseded and replacement statements are interleaved (the ignore-index
// block and the k/l/S*/store statements each appear twice). Presumably the
// `rotset_arr` / un-offset `pcol[...]` variants are the removed lines --
// confirm against the repository.
bool compute(unsigned wid) {
// Use (p, q) = (1, 0) for ignore idx.
const unsigned ignore_idx = 2 * rotset_idx_ignored[0];
cs_arr[ignore_idx] = 1;
cs_arr[ignore_idx + 1] = 0;

// Size of the index prefix in pcol and qcol.
constexpr int INDEX_PREFIX = 2;
// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;

// Use (p, q) = (1, 0) for ignore idx.
// The (c, s) entry at the ignored rotation index is set to (1, 0); with these
// values the update formulas below reduce to a plain copy of the k and l
// entries.
const unsigned ignore_idx = 2 * rotset_idx_ignored[0];
cs_arr[ignore_idx] = 1;
cs_arr[ignore_idx + 1] = 0;

// Data pointers positioned just after the index prefix of each column.
auto pcol_ptr = pcol.data() + INDEX_PREFIX;
auto qcol_ptr = qcol.data() + INDEX_PREFIX;
auto pcol_updated_ptr = pcol_updated.data() + INDEX_PREFIX;
auto qcol_updated_ptr = qcol_updated.data() + INDEX_PREFIX;

// Forward pq indices.
// NOTE(review): only element 0 of the INDEX_PREFIX-sized prefix is copied;
// element 1 is presumably padding -- confirm.
pcol_updated[0] = pcol[0];
qcol_updated[0] = qcol[0];

// Parallelized loop on update using other columns coefficients
// for (IndexType half_idx = 0; half_idx != halfN; ++half_idx) {
for (IndexType half_idx = 0; half_idx != wsize; ++half_idx) {
// Entry (wstart + half_idx) of the flattened (pair, 2) arrays: index k at the
// even slot, index l at the odd slot.
const unsigned k = rotset_arr[2 * half_idx + 2 * wstart];
const unsigned l = rotset_arr[2 * half_idx + 1 + 2 * wstart];
// TODO: cleaning pq indices offset.
const unsigned k = rotset_sorted_arr[2 * half_idx + 2 * wstart];
const unsigned l = rotset_sorted_arr[2 * half_idx + 1 + 2 * wstart];

// Matching (c, s) pair for this rotation.
const T c = cs_arr[2 * half_idx + 2 * wstart];
const T s = cs_arr[2 * half_idx + 1 + 2 * wstart];

// 4 coefficients updates!
// TODO: vectorization?!
const T Spk = pcol[k];
const T Spl = pcol[l];
const T Spk = pcol_ptr[k];
const T Spl = pcol_ptr[l];

const T Sqk = qcol[k];
const T Sql = qcol[l];
const T Sqk = qcol_ptr[k];
const T Sql = qcol_ptr[l];

pcol_updated[k] = c * Spk - s * Spl;
pcol_updated[l] = s * Spk + c * Spl;
pcol_updated_ptr[k] = c * Spk - s * Spl;
pcol_updated_ptr[l] = s * Spk + c * Spl;

qcol_updated[k] = c * Sqk - s * Sql;
qcol_updated[l] = s * Sqk + c * Sql;
qcol_updated_ptr[k] = c * Sqk - s * Sql;
qcol_updated_ptr[l] = s * Sqk + c * Sql;
}
return true;
}
};

template <typename T>
void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
T* vqcol_updated, T c, T s,
unsigned short wstart,
unsigned short wend) noexcept {
using T2 = float2;
// Using `uint16` seems to be generating more efficient loops?
using IndexType = unsigned short;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol) + wstart;
const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol) + wstart;
T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_updated) + wstart;
T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_updated) + wstart;

for (IndexType idx = 0; idx != wsize; ++idx) {
const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
const T2 vqvec_updated = svec * vpvec + cvec * vqvec;

ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
}
}

/**
* @brief Jacobi algorithm, update of eigen vectors matrix.
*
Expand Down Expand Up @@ -268,32 +337,29 @@ class [[poplar::constraint(
JacobiUpdateEigenvectors();

// Per-worker entry point: rotates this worker's slice of the two eigenvector
// columns with the (c, s) pair read from `cs`. `wid` selects the slice through
// `worker_offsets[wid]` and `worker_offsets[wid + 1]`.
//
// NOTE(review): this listing is rendered from a diff without +/- markers, so
// superseded and replacement statements are interleaved and the braces do not
// balance. Presumably the inline pointer setup and loop are the removed
// version, replaced by the calls to `jacob_update_eigenvectors` -- confirm
// against the repository.
bool compute(unsigned wid) {
constexpr int INDEX_PREFIX = 2;
// p and q are read from the first 32-bit word of each column buffer,
// reinterpreted as `unsigned`.
const unsigned p = *((unsigned*)vpcol.data());
const unsigned q = *((unsigned*)vqcol.data());

const T c = cs[0];
const T s = cs[1];
const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};

// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol.data()) + wstart;
const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol.data()) + wstart;
T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_out.data()) + wstart;
T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_out.data()) + wstart;

for (IndexType idx = 0; idx != wsize; ++idx) {
const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
const T2 vqvec_updated = svec * vpvec + cvec * vqvec;

ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
// Forwarding p/q (prefix) indices.
// NOTE(review): only element 0 of the INDEX_PREFIX-sized prefix is copied;
// element 1 is presumably padding -- confirm.
vpcol_out[0] = vpcol[0];
vqcol_out[0] = vqcol[0];
// Swapping pointers if necessary.
// When q < p, inputs and outputs are passed in swapped order; (c, s) and the
// worker slice are passed unchanged in both branches.
if (p <= q) {
jacob_update_eigenvectors(
vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX,
vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c,
s, wstart, wend);
} else {
jacob_update_eigenvectors(
vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX,
vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c,
s, wstart, wend);
}
return true;
}
Expand Down
Loading

0 comments on commit 736e9b0

Please sign in to comment.