Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 2, 2023
1 parent a5f9004 commit 5971b65
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 29 deletions.
1 change: 1 addition & 0 deletions tessellate_ipu/core/tile_interpreter_vertex_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def make_ipu_vector1d_worker_offsets(
Returns:
(6,) number of data vectors per thread.
"""

def make_offsets_fn(sizes):
sizes = [0] + sizes
offsets = np.cumsum(np.array(sizes, wdtype), dtype=wdtype)
Expand Down
89 changes: 63 additions & 26 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,24 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
const unsigned p = pindex[0];
const unsigned q = qindex[0];


const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];

if (p <= q) {
// Proper ordering of p and q already.
jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(),
qcol_updated.data(), cs.data(), p, q, wstart, wend);
qcol_updated.data(), cs.data(), p, q, wstart,
wend);
rotset_sorted[0] = p;
rotset_sorted[1] = q;
}
else {
} else {
// Swap p and q columns as q < p
jacob_update_first_step(qcol.data(), pcol.data(), qcol_updated.data(),
pcol_updated.data(), cs.data(), q, p, wstart, wend);
pcol_updated.data(), cs.data(), q, p, wstart,
wend);
// jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(),
// qcol_updated.data(), cs.data(), q, p, wstart, wend);
// qcol_updated.data(), cs.data(), q, p, wstart,
// wend);
rotset_sorted[0] = q;
rotset_sorted[1] = p;
}
Expand Down Expand Up @@ -271,6 +272,37 @@ class JacobiUpdateSecondStep : public MultiVertex {
}
};

template <typename T>
void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
T* vqcol_updated, T c, T s,
unsigned short wstart,
unsigned short wend) noexcept {
using T2 = float2;
// Using `uint16` seems to be generating more efficient loops?
using IndexType = unsigned short;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol) + wstart;
const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol) + wstart;
T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_updated) + wstart;
T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_updated) + wstart;

for (IndexType idx = 0; idx != wsize; ++idx) {
const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
const T2 vqvec_updated = svec * vpvec + cvec * vqvec;

ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
}
}

/**
* @brief Jacobi algorithm, update of eigen vectors matrix.
*
Expand Down Expand Up @@ -304,31 +336,36 @@ class [[poplar::constraint(
bool compute(unsigned wid) {
const T c = cs[0];
const T s = cs[1];
const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};

// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol.data()) + wstart;
const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol.data()) + wstart;
T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_out.data()) + wstart;
T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_out.data()) + wstart;
jacob_update_eigenvectors(vpcol.data(), vqcol.data(), vpcol_out.data(),
vqcol_out.data(), c, s, wstart, wend);
return true;

for (IndexType idx = 0; idx != wsize; ++idx) {
const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);
// const T2 cvec = T2{c, c};
// const T2 svec = T2{s, s};

const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
const T2 vqvec_updated = svec * vpvec + cvec * vqvec;
// // Worker load: start + end vectorized indexes.
// constexpr unsigned ptr_step = 1;
// const IndexType wsize = wend - wstart;

ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
}
return true;
// // pcol, qcol and results pointers.
// const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol.data()) + wstart;
// const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol.data()) + wstart;
// T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_out.data()) + wstart;
// T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_out.data()) + wstart;

// for (IndexType idx = 0; idx != wsize; ++idx) {
// const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
// const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);

// const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
// const T2 vqvec_updated = svec * vpvec + cvec * vqvec;

// ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
// ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
// }
// return true;
}
};
6 changes: 3 additions & 3 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
with jax.named_scope("cs_replicated_sharded"):
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
# Just copy Schur decomposition to associated V tiles.
cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)
cs_sharded_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
cs_replicated, cs_sharded_Vtiles = tile_data_barrier(cs_replicated, cs_sharded_Vtiles)

# Second Jacobi update step.
# Note: does not require sorting of pcols and qcols.
Expand All @@ -210,7 +210,7 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
# Jacobi eigenvectors update step.
Vpcols, Vqcols = tile_map( # type:ignore
jacobi_update_eigenvectors_p,
cs_Vtiles,
cs_sharded_Vtiles,
Vpcols,
Vqcols,
)
Expand Down

0 comments on commit 5971b65

Please sign in to comment.