From 491275aea6f8d67f5ff396867fc785c2fb12125f Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Thu, 5 Oct 2023 21:14:13 +0100
Subject: [PATCH] IPU Jacobi eigh `fori_loop` transition (#42)

This PR switches the IPU Jacobi `eigh` implementation from a Python loop
to a JAX `fori_loop`, massively reducing code size and allowing eigen
decomposition of larger matrices.

Note: some performance issues remain, as the Poplar compiler adds some
extra on-tile copies which could in theory be elided.
---
 .../core/vertex/tile_jacobi_vertex.cpp      | 264 +++++++++++-------
 tessellate_ipu/linalg/tile_linalg_jacobi.py | 259 ++++++++++-------
 tests/linalg/test_tile_linalg_jacobi.py     |  29 +-
 3 files changed, 334 insertions(+), 218 deletions(-)

diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
index 4c33707..0cb74c0 100644
--- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
+++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
@@ -78,6 +78,59 @@ class JacobiSymSchur2 : public Vertex {
   }
 };
 
+template <typename T>
+void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
+                             T* qcol_updated, T* cs, unsigned p, unsigned q,
+                             unsigned short wstart,
+                             unsigned short wend) noexcept {
+  using T2 = float2;
+  using IndexType = unsigned short;
+
+  const T Apq = pcol[q];
+  const T App = pcol[p];
+  const T Aqq = qcol[q];
+
+  // Schur2 decomposition.
+  const T2 cs_vec = sym_schur2(App, Aqq, Apq);
+  const T& c = cs_vec[0];
+  const T& s = cs_vec[1];
+  cs[0] = c;
+  cs[1] = s;
+
+  // Worker load: start + end vectorized indexes.
+  constexpr unsigned ptr_step = 1;
+  const IndexType wsize = wend - wstart;
+
+  // pcol, qcol and results pointers.
+  const float2* ptr_pcol = reinterpret_cast<const float2*>(pcol) + wstart;
+  const float2* ptr_qcol = reinterpret_cast<const float2*>(qcol) + wstart;
+  float2* ptr_pcol_updated = reinterpret_cast<float2*>(pcol_updated) + wstart;
+  float2* ptr_qcol_updated = reinterpret_cast<float2*>(qcol_updated) + wstart;
+
+  const T2 cvec = T2{c, c};
+  const T2 svec = T2{s, s};
+
+  // Easier to vectorize + parallelize if we start with the normal update first.
+  for (IndexType idx = 0; idx != wsize; ++idx) {
+    // TODO: investigate assembly?
+    const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
+    const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);
+
+    const T2 pvec_updated = cvec * pvec - svec * qvec;
+    const T2 qvec_updated = svec * pvec + cvec * qvec;
+
+    ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
+    ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
+  }
+
+  // Update main values App, Apq, Aqq.
+  pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
+  qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
+  // Zero on purpose with Schur decomposition!
+  pcol_updated[q] = 0;
+  qcol_updated[p] = 0;
+}
+
 /**
  * @brief Jacobi algorithm, update first step: schur + column update.
  *
@@ -92,78 +145,55 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
   // Using `uint16` seems to be generating more efficient loops?
   using IndexType = unsigned short;
 
-  Input<Vector<unsigned>>
-      rotset;  // (2,) rotation index p and q. p < q
-  Input<Vector<T>> pcol;  // (N,) p column
-  Input<Vector<T>> qcol;  // (N,) q column
+  // p/q cols + index prefix (2 x uint32).
+  Input<Vector<T>> pcol;  // (N + 2,) p column
+  Input<Vector<T>> qcol;  // (N + 2,) q column
 
   Input<Vector<IndexType>>
       worker_offsets;  // (7,) threads work size + 1.
 
+  Output<Vector<unsigned>>
+      rotset_sorted;  // (3,) sorted rotset indices + was-sorted flag?
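+  // Note on the buffer layout (derived from the INDEX_PREFIX handling in
+  // `compute` below): every p/q column is (N + 2,) floats, i.e.
+  //   [p_idx, q_idx, col[0], ..., col[N-1]]
+  // with the two rotation indices stored as uint32 bit-cast into the first
+  // two float slots, keeping the column payload 8-byte aligned for float2
+  // vectorized loads.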
   Output<Vector<T>> cs;  // (2,) (c, s) Schur decomposition values
   Output<Vector<T>>
-      pcol_updated;  // (N,) p column updated
+      pcol_updated;  // (N + 2,) p column updated
   Output<Vector<T>>
-      qcol_updated;  // (N,) q column updated
-
-  const IndexType N;  // size
+      qcol_updated;  // (N + 2,) q column updated
 
   JacobiUpdateFirstStep();
 
   bool compute(unsigned wid) {
-    const unsigned p = rotset[0];
-    const unsigned q = rotset[1];
-    const T Apq = pcol[q];
-    const T App = pcol[p];
-    const T Aqq = qcol[q];
+    // Size of the index prefix in pcol and qcol.
+    constexpr int INDEX_PREFIX = 2;
+    const unsigned p = *((unsigned*)pcol.data());
+    const unsigned q = *((unsigned*)qcol.data());
 
-    // Schur2 decomposition.
-    const T2 cs_vec = sym_schur2(App, Aqq, Apq);
-    const T& c = cs_vec[0];
-    const T& s = cs_vec[1];
-    cs[0] = c;
-    cs[1] = s;
-
-    // Worker load: start + end vectorized indexes.
-    constexpr unsigned ptr_step = 1;
     const IndexType wstart = worker_offsets[wid];
     const IndexType wend = worker_offsets[wid + 1];
-    const IndexType wsize = wend - wstart;
 
-    // pcol, qcol and results pointers.
-    const float2* ptr_pcol =
-        reinterpret_cast<const float2*>(pcol.data()) + wstart;
-    const float2* ptr_qcol =
-        reinterpret_cast<const float2*>(qcol.data()) + wstart;
-    float2* ptr_pcol_updated =
-        reinterpret_cast<float2*>(pcol_updated.data()) + wstart;
-    float2* ptr_qcol_updated =
-        reinterpret_cast<float2*>(qcol_updated.data()) + wstart;
-
-    const T2 cvec = T2{c, c};
-    const T2 svec = T2{s, s};
-
-    // Easier to vectorized + parallelize if start with normal update first.
-    for (IndexType idx = 0; idx != wsize; ++idx) {
-      // TODO: investigate assembly?
-      const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
-      const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);
-
-      const T2 pvec_updated = cvec * pvec - svec * qvec;
-      const T2 qvec_updated = svec * pvec + cvec * qvec;
-
-      ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
-      ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
+    // Forward p/q indices.
+    pcol_updated[0] = pcol[0];
+    qcol_updated[0] = qcol[0];
+
+    if (p <= q) {
+      // Proper ordering of p and q already.
+      jacob_update_first_step(
+          pcol.data() + INDEX_PREFIX, qcol.data() + INDEX_PREFIX,
+          pcol_updated.data() + INDEX_PREFIX,
+          qcol_updated.data() + INDEX_PREFIX, cs.data(), p, q, wstart, wend);
+      rotset_sorted[0] = p;
+      rotset_sorted[1] = q;
+    } else {
+      // Swap p and q columns as q < p.
+      jacob_update_first_step(
+          qcol.data() + INDEX_PREFIX, pcol.data() + INDEX_PREFIX,
+          qcol_updated.data() + INDEX_PREFIX,
+          pcol_updated.data() + INDEX_PREFIX, cs.data(), q, p, wstart, wend);
+      rotset_sorted[0] = q;
+      rotset_sorted[1] = p;
     }
-
-    // Update main values App, Apq, Aqq
-    pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
-    qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
-    // Zero on purpose with Schur decomposition!
-    pcol_updated[q] = 0;
-    qcol_updated[p] = 0;
     return true;
   }
 };
@@ -178,65 +208,104 @@ class JacobiUpdateSecondStep : public MultiVertex {
   InOut<Vector<T>> cs_arr;  // (N/2, 2) (c, s) values
   Input<Vector<unsigned>>
-      rotset_arr;  // (N/2, 2) (p, q) array values. p < q
+      rotset_sorted_arr;  // (N/2, 2) (p, q) array values. p < q
   Input<Vector<unsigned>>
       rotset_idx_ignored;  // (1,) index in rotset to ignore.
   Input<Vector<IndexType>>
       worker_offsets;  // (7,) threads work size + 1.
-  Input<Vector<T>> pcol;  // (N,) p column
-  Input<Vector<T>> qcol;  // (N,) q column
+  Input<Vector<T>> pcol;  // (N+2,) p column
+  Input<Vector<T>> qcol;  // (N+2,) q column
   Output<Vector<T>>
-      pcol_updated;  // (N,) p column updated
+      pcol_updated;  // (N+2,) p column updated
   Output<Vector<T>>
-      qcol_updated;  // (N,) q column updated
-
-  // const unsigned ignore_idx;  // cs/pq index to ignore.
-  const IndexType halfN;  // N / 2
+      qcol_updated;  // (N+2,) q column updated
 
   JacobiUpdateSecondStep();
 
   bool compute(unsigned wid) {
-    // Use (p, q) = (1, 0) for ignore idx.
-    const unsigned ignore_idx = 2 * rotset_idx_ignored[0];
-    cs_arr[ignore_idx] = 1;
-    cs_arr[ignore_idx + 1] = 0;
-
+    // Size of the index prefix in pcol and qcol.
+    constexpr int INDEX_PREFIX = 2;
     // Worker load: start + end vectorized indexes.
     constexpr unsigned ptr_step = 1;
     const IndexType wstart = worker_offsets[wid];
     const IndexType wend = worker_offsets[wid + 1];
     const IndexType wsize = wend - wstart;
 
+    // Use (p, q) = (1, 0) for ignore idx.
+    const unsigned ignore_idx = 2 * rotset_idx_ignored[0];
+    cs_arr[ignore_idx] = 1;
+    cs_arr[ignore_idx + 1] = 0;
+
+    auto pcol_ptr = pcol.data() + INDEX_PREFIX;
+    auto qcol_ptr = qcol.data() + INDEX_PREFIX;
+    auto pcol_updated_ptr = pcol_updated.data() + INDEX_PREFIX;
+    auto qcol_updated_ptr = qcol_updated.data() + INDEX_PREFIX;
+
+    // Forward pq indices.
+    pcol_updated[0] = pcol[0];
+    qcol_updated[0] = qcol[0];
+
     // Parallelized loop on update using other columns coefficients.
-    // for (IndexType half_idx = 0; half_idx != halfN; ++half_idx) {
     for (IndexType half_idx = 0; half_idx != wsize; ++half_idx) {
-      const unsigned k = rotset_arr[2 * half_idx + 2 * wstart];
-      const unsigned l = rotset_arr[2 * half_idx + 1 + 2 * wstart];
+      // TODO: clean up pq indices offset.
+      const unsigned k = rotset_sorted_arr[2 * half_idx + 2 * wstart];
+      const unsigned l = rotset_sorted_arr[2 * half_idx + 1 + 2 * wstart];
 
       const T c = cs_arr[2 * half_idx + 2 * wstart];
       const T s = cs_arr[2 * half_idx + 1 + 2 * wstart];
 
       // 4 coefficients updates!
       // TODO: vectorization?!
-      const T Spk = pcol[k];
-      const T Spl = pcol[l];
+      const T Spk = pcol_ptr[k];
+      const T Spl = pcol_ptr[l];
 
-      const T Sqk = qcol[k];
-      const T Sql = qcol[l];
+      const T Sqk = qcol_ptr[k];
+      const T Sql = qcol_ptr[l];
 
-      pcol_updated[k] = c * Spk - s * Spl;
-      pcol_updated[l] = s * Spk + c * Spl;
+      pcol_updated_ptr[k] = c * Spk - s * Spl;
+      pcol_updated_ptr[l] = s * Spk + c * Spl;
 
-      qcol_updated[k] = c * Sqk - s * Sql;
-      qcol_updated[l] = s * Sqk + c * Sql;
+      qcol_updated_ptr[k] = c * Sqk - s * Sql;
+      qcol_updated_ptr[l] = s * Sqk + c * Sql;
     }
     return true;
  }
 };
 
+template <typename T>
+void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
+                               T* vqcol_updated, T c, T s,
+                               unsigned short wstart,
+                               unsigned short wend) noexcept {
+  using T2 = float2;
+  // Using `uint16` seems to be generating more efficient loops?
+  using IndexType = unsigned short;
+
+  const T2 cvec = T2{c, c};
+  const T2 svec = T2{s, s};
+  const IndexType wsize = wend - wstart;
+
+  // pcol, qcol and results pointers.
+  const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol) + wstart;
+  const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol) + wstart;
+  T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_updated) + wstart;
+  T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_updated) + wstart;
+
+  for (IndexType idx = 0; idx != wsize; ++idx) {
+    const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
+    const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);
+
+    const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
+    const T2 vqvec_updated = svec * vpvec + cvec * vqvec;
+
+    ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
+    ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
+  }
+}
+
 /**
  * @brief Jacobi algorithm, update of eigen vectors matrix.
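 *
 * Applies the same Givens rotation as the A-columns update, using the
 * (c, s) Schur values computed in the first step:
 *   vpcol' = c * vpcol - s * vqcol
 *   vqcol' = s * vpcol + c * vqcol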
 *
@@ -268,32 +337,29 @@ class [[poplar::constraint(
   JacobiUpdateEigenvectors();
 
   bool compute(unsigned wid) {
+    constexpr int INDEX_PREFIX = 2;
+    const unsigned p = *((unsigned*)vpcol.data());
+    const unsigned q = *((unsigned*)vqcol.data());
+
     const T c = cs[0];
     const T s = cs[1];
-    const T2 cvec = T2{c, c};
-    const T2 svec = T2{s, s};
-
-    // Worker load: start + end vectorized indexes.
-    constexpr unsigned ptr_step = 1;
     const IndexType wstart = worker_offsets[wid];
     const IndexType wend = worker_offsets[wid + 1];
-    const IndexType wsize = wend - wstart;
-
-    // pcol, qcol and results pointers.
-    const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol.data()) + wstart;
-    const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol.data()) + wstart;
-    T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_out.data()) + wstart;
-    T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_out.data()) + wstart;
 
-    for (IndexType idx = 0; idx != wsize; ++idx) {
-      const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
-      const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);
-
-      const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
-      const T2 vqvec_updated = svec * vpvec + cvec * vqvec;
-
-      ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
-      ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
+    // Forwarding p/q (prefix) indices.
+    vpcol_out[0] = vpcol[0];
+    vqcol_out[0] = vqcol[0];
+    // Swapping pointers if necessary.
+    if (p <= q) {
+      jacob_update_eigenvectors(
+          vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX,
+          vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c,
+          s, wstart, wend);
+    } else {
+      jacob_update_eigenvectors(
+          vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX,
+          vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c,
+          s, wstart, wend);
     }
     return true;
   }
diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py
index 1e489ee..3cbfa5f 100644
--- a/tessellate_ipu/linalg/tile_linalg_jacobi.py
+++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py
@@ -7,10 +7,10 @@
 import numpy as np
 from jax.core import ShapedArray
 
+# import tessellate_ipu
 from tessellate_ipu import (
     TileShardedArray,
     create_ipu_tile_primitive,
-    tile_constant_replicated,
     tile_constant_sharded,
     tile_data_barrier,
     tile_gather,
@@ -24,6 +24,11 @@
 Array = Any
 
+INDEX_PREFIX = 2
+"""Index prefix size in p/q columns.
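+
+Each p/q column buffer is laid out as [p_idx, q_idx, data_0, ..., data_{N-1}]:
+the two uint32 indices are bit-cast to float32 (see `jacobi_initial_pqindices`
+below), keeping the actual column payload 64-bit aligned.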
+""" + + def get_jacobi_vertex_gp_filename() -> str: return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_jacobi_vertex.cpp") @@ -41,11 +46,16 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_first_step_p = create_ipu_tile_primitive( "jacobi_update_first_step", "JacobiUpdateFirstStep", - inputs=["rotset", "pcol", "qcol"], - outputs={"cs": ShapedArray((2,), dtype=np.float32), "pcol_updated": 1, "qcol_updated": 2}, + inputs=["pcol", "qcol"], + outputs={ + "rotset_sorted": ShapedArray((2,), dtype=np.uint32), + "cs": ShapedArray((2,), dtype=np.float32), + "pcol_updated": 0, + "qcol_updated": 1, + }, constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[1].size, vector_size=2, wdtype=np.uint16 + inavals[0].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -56,11 +66,11 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_second_step_p = create_ipu_tile_primitive( "jacobi_update_second_step", "JacobiUpdateSecondStep", - inputs=["cs_arr", "rotset_arr", "rotset_idx_ignored", "pcol", "qcol"], + inputs=["cs_arr", "rotset_sorted_arr", "rotset_idx_ignored", "pcol", "qcol"], outputs={"cs_arr": 0, "pcol_updated": 3, "qcol_updated": 4}, constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size, vector_size=2, wdtype=np.uint16 + inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -74,7 +84,10 @@ def get_jacobi_vertex_gp_filename() -> str: outputs={"vpcol_out": 1, "vqcol_out": 2}, # Bug when inplace update? constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[1].size, vector_size=2, wdtype=np.uint16 + # Remove 2 for pq indices prefix. + inavals[1].size - INDEX_PREFIX, + vector_size=2, + wdtype=np.uint16, ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -88,6 +101,49 @@ def jacobi_initial_rotation_set(N: int) -> NDArray[np.uint32]: return rot +def jacobi_initial_pqindices(N: int) -> Tuple[NDArray[np.uint32], NDArray[np.uint32]]: + """Jacobi initial p/q indices arrays. + Padded to (N/2, 2) for 64bits alignment. + + Returns: + A tuple of p/q indices arrays. + """ + rotset = jacobi_initial_rotation_set(N) + pindices = rotset[:, :1] + qindices = rotset[:, 1:] + pindices = np.concatenate([pindices, pindices], axis=1) + qindices = np.concatenate([qindices, qindices], axis=1) + return (pindices, qindices) + + +def tile_sharded_pq_columns( + pcols: Array, qcols: Array, tiles: Tuple[int, ...] +) -> Tuple[TileShardedArray, TileShardedArray]: + """Tile sharding of p/q columns arrays + adding indexing prefix. + + Args: + pcols/qcols: (M, N) arrays. + tiles: Collection of tiles to shard on. + Returns: + Pair of tile sharded array (M, N+2), with indexing prefix. + """ + assert pcols.shape == qcols.shape + assert len(pcols.shape) == 2 + N = pcols.shape[0] * 2 + # N = pcols.shape[-1] + + pindices, qindices = jacobi_initial_pqindices(N) + pindices_prefix = tile_constant_sharded(pindices.view(np.float32), tiles=tiles) + qindices_prefix = tile_constant_sharded(qindices.view(np.float32), tiles=tiles) + # Prepend the p/q indices. Note: keeping 64bits alignment with 2 uint32s. + pcols = jax.lax.concatenate([pindices_prefix.array, pcols], dimension=1) + qcols = jax.lax.concatenate([qindices_prefix.array, qcols], dimension=1) + # Shard between tiles. TODO: single call with tuple. 
+    pcols = tile_put_sharded(pcols, tiles=tiles)
+    qcols = tile_put_sharded(qcols, tiles=tiles)
+    return pcols, qcols
+
+
 def jacobi_next_rotation_set(rot: NDArray[np.uint32]) -> NDArray[np.uint32]:
     """Jacobi next rotation set (N/2, 2).
@@ -112,82 +168,116 @@ def jacobi_sort_rotation_set(rotset: NDArray[np.uint32]) -> NDArray[np.uint32]:
     return np.stack([pindices, qindices], axis=-1)
 
 
-def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtiles: Any) -> Tuple[Array, ...]:
-    """IPU Eigen decomposition: single iteration of the Jacobi algorithm.
+def tile_rotate_columns(pcols: TileShardedArray, qcols: TileShardedArray) -> Tuple[TileShardedArray, TileShardedArray]:
+    """Rotate columns between tiles using a static `tile_gather`.
 
-    NOTE: the goal is to have a function which can be easily combined with `fori_loop`.
+    We follow the Jacobi rotation pattern between tiles. In short:
+    - moving `pcols` to the "right";
+    - moving `qcols` to the "left".
+    """
+    assert pcols.shape == qcols.shape
+    assert pcols.tiles == qcols.tiles
+    halfN = pcols.shape[0]
+    N = halfN * 2
+    # Concat all columns, in order to perform a single gather.
+    all_cols = TileShardedArray(
+        jax.lax.concatenate([pcols.array, qcols.array], dimension=0), (*pcols.tiles, *qcols.tiles)
+    )
+
+    pcols_indices = np.arange(0, halfN, dtype=np.int32)
+    qcols_indices = np.arange(halfN, N, dtype=np.int32)
+    # Rotation of columns between tiles (see Jacobi alg.)
+    # Roughly: pcols move to the right, qcols to the left.
+    pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]])
+    qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]])
+    # Move columns around!
+    all_indices = np.concatenate([pcols_indices_new, qcols_indices_new])
+    all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
+    return all_cols_updated[:halfN], all_cols_updated[halfN:]
+
+
+def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tuple[TileShardedArray, ...]:
+    """IPU Jacobi eigen-decomposition algorithm main body.
 
     Args:
-        all_AV_cols: A and V matrices p/q columns.
-        Atiles: A matrix tiles.
-        Vtiles: V matrix tiles.
+        idx: Loop index.
+        inputs: Tile sharded Apcols, Aqcols, Vpcols, Vqcols.
     Returns:
-        Tuple of updated A and V matrices p/q columns.
+        Apcols, Aqcols, Vpcols, Vqcols after a main Jacobi update.
     """
-    Apcols, Aqcols, Vpcols, Vqcols = all_AV_cols
-    N = Apcols.shape[-1]
-    halfN = N // 2
-    # TODO: check compatibility of TileShardedArray with fori_loop
-    # Shard arrays across tiles.
-    Apcols = tile_put_sharded(Apcols, tiles=Atiles)
-    Aqcols = tile_put_sharded(Aqcols, tiles=Atiles)
-    # Initial eigenvectors (identity matrix).
-    Vpcols = tile_put_sharded(Vpcols, tiles=Vtiles)
-    Vqcols = tile_put_sharded(Vqcols, tiles=Vtiles)
-    # Constant tensor of index to ignored at every iteration.
-    rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles)
-    rotset = jacobi_initial_rotation_set(N)
+    Apcols, Aqcols, Vpcols, Vqcols = inputs
+    Atiles = Apcols.tiles
+    Vtiles = Vpcols.tiles
+    halfN = Apcols.shape[0]
 
-    # All different size 2 partitions on columns.
-    for _ in range(1, N):
-        # Sorted rotation set: p < q indices.
-        rotset_sorted = jacobi_sort_rotation_set(rotset)
-        # On tile constant rotation set tensor building.
-        with jax.named_scope("rotset"):
-            rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
-            rotset_sharded = tile_constant_sharded(rotset_sorted, tiles=Atiles)
+    with jax.named_scope("jacobi_eigh"):
+        # Sharded constant with p/q indices to ignore in second update stage.
+        with jax.named_scope("rotset_index_ignored"):
+            rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles)
 
         # Compute Schur decomposition + on-tile update of columns.
-        cs_per_tile, Apcols, Aqcols = tile_map(  # type:ignore
-            jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N
+        # Note: p < q is not expected here; input pcols/qcols are sorted inside the vertex.
+        rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map(  # type:ignore
+            jacobi_update_first_step_p, Apcols, Aqcols
         )
-        # Replicate Schur decomposition across all A tiles: (2*N//2) comms.
+        # Replicate Schur decomposition + rotset across all A tiles: (2*N//2) comms.
+        with jax.named_scope("rotset_sorted_replicated"):
+            rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles)
        with jax.named_scope("cs_replicated_sharded"):
            cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
-        # Just copy Schur decomposition to associated V tiles.
-        cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
-        cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)
+        # Just copy Schur decomposition to associated V tiles (no need to replicate).
+        cs_sharded_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
+        # Barrier to force all communications to be fused.
+        cs_replicated, cs_sharded_Vtiles, rotset_sorted_replicated = tile_data_barrier(
+            cs_replicated, cs_sharded_Vtiles, rotset_sorted_replicated
+        )
         # Second Jacobi update step.
+        # Note: does not require sorting of pcols and qcols.
         cs_replicated, Apcols, Aqcols = tile_map(  # type:ignore
             jacobi_update_second_step_p,
             cs_replicated,
-            rotset_replicated,
+            rotset_sorted_replicated,
             rotset_index_ignored,
             Apcols,
             Aqcols,
-            halfN=halfN,
         )
         # Jacobi eigenvectors update step.
         Vpcols, Vqcols = tile_map(  # type:ignore
             jacobi_update_eigenvectors_p,
-            cs_Vtiles,
+            cs_sharded_Vtiles,
             Vpcols,
             Vqcols,
         )
         # Barrier, to make sure we sync both sets of tiles A and V.
         Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
 
-        # Move columns between tiles. 2*N commns per tile.
-        # NOTE: this inter-tile comm is keeping the p < q property on A and V columns.
+        # Move columns between tiles following Jacobi rotation pattern. 2*N comms per tile.
         with jax.named_scope("Apqcols_rotation"):
-            Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
+            Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
         with jax.named_scope("Vpqcols_rotation"):
-            Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols, rotset)
-        # Next rotation set.
-        rotset = jacobi_next_rotation_set(rotset)
+            Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
+    return Apcols, Aqcols, Vpcols, Vqcols
 
-    return (Apcols.array, Aqcols.array, Vpcols.array, Vqcols.array)
+
+def ipu_jacobi_eigh_iteration(idx: Array, all_AV_cols: Tuple[TileShardedArray, ...]) -> Tuple[TileShardedArray, ...]:
+    """IPU Eigen decomposition: single iteration of the Jacobi algorithm.
+
+    NOTE: the goal is to have a function which can be easily combined with `fori_loop`.
+
+    Args:
+        idx: Loop index.
+        all_AV_cols: A and V sharded p/q columns + index prefix.
+    Returns:
+        Tuple of updated A and V matrices p/q columns.
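+
+    A single "iteration" here is a full sweep of (N - 1) Jacobi rotation
+    rounds, expressed below as a `jax.lax.fori_loop` over `ipu_jacobi_eigh_body`.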
+ """ + Apcols, Aqcols, Vpcols, Vqcols = all_AV_cols + shape = Apcols.shape + assert len(shape) == 2 + assert shape[0] * 2 + INDEX_PREFIX == shape[1] + N = shape[0] * 2 + # Jacobi eigh iteration as a single fori_loop. + Apcols, Aqcols, Vpcols, Vqcols = jax.lax.fori_loop(1, N, ipu_jacobi_eigh_body, (Apcols, Aqcols, Vpcols, Vqcols)) + return (Apcols, Aqcols, Vpcols, Vqcols) def ipu_jacobi_eigh(x: Array, num_iters: int = 1) -> Tuple[Array, Array]: @@ -205,22 +295,27 @@ def ipu_jacobi_eigh(x: Array, num_iters: int = 1) -> Tuple[Array, Array]: assert N <= 1024 halfN = N // 2 - Atiles = tuple(range(0, halfN)) - Vtiles = tuple(range(halfN, 2 * halfN)) + tile_offset = 1 + Atiles = tuple(range(tile_offset, halfN + tile_offset)) + Vtiles = tuple(range(halfN + tile_offset, 2 * halfN + tile_offset)) # Initial "eigenvalues" matrix. - Apcols = jax.lax.slice_in_dim(x, 0, N, stride=2) - Aqcols = jax.lax.slice_in_dim(x, 1, N, stride=2) + Apcols_init = jax.lax.slice_in_dim(x, 0, N, stride=2) + Aqcols_init = jax.lax.slice_in_dim(x, 1, N, stride=2) # Initial eigenvectors (identity matrix). - Vpcols = np.identity(N)[0::2] - Vqcols = np.identity(N)[1::2] + Vpcols_init = np.identity(N)[0::2] + Vqcols_init = np.identity(N)[1::2] - # Set A and V tiling static. - eigh_iteration_fn = lambda _, x: ipu_jacobi_eigh_iteration(x, Atiles, Vtiles) + # Shard p/q columns + adding index prefix. + Apcols, Aqcols = tile_sharded_pq_columns(Apcols_init, Aqcols_init, tiles=Atiles) + Vpcols, Vqcols = tile_sharded_pq_columns(Vpcols_init, Vqcols_init, tiles=Vtiles) # JAX fori_loop => no Python unrolling and code bloating! Apcols, Aqcols, Vpcols, Vqcols = jax.lax.fori_loop( - 0, num_iters, eigh_iteration_fn, (Apcols, Aqcols, Vpcols, Vqcols) + 0, num_iters, ipu_jacobi_eigh_iteration, (Apcols, Aqcols, Vpcols, Vqcols) ) + # Back to JAX arrays, removing indexing prefix. + (Apcols, Aqcols, Vpcols, Vqcols) = map(lambda x: x.array[:, INDEX_PREFIX:], (Apcols, Aqcols, Vpcols, Vqcols)) + # Expect the output to follow the initial rotation set columns split. rotset = jacobi_initial_rotation_set(N) # Re-organize pcols and qcols into the result matrix. @@ -248,50 +343,6 @@ def permute_pq_indices( return (np.where(rotset_permute_mask, pindices, qindices), np.where(rotset_permute_mask, qindices, pindices)) -def tile_rotate_columns( - pcols: TileShardedArray, qcols: TileShardedArray, rotset: NDArray[np.uint32] -) -> Tuple[TileShardedArray, TileShardedArray]: - """Rotate columns between tiles using a static `tile_gather`. - - The tricky part of this function is to rotate the columns between tiles, but - keep the property p < q, which means taking care of the present sorting permutation applied - as well the next sorting permutation. - """ - assert pcols.shape == qcols.shape - assert pcols.tiles == qcols.tiles - halfN = pcols.shape[0] - N = halfN * 2 - # Concat all columns, in order to perform a single gather. - all_cols = TileShardedArray( - jax.lax.concatenate([pcols.array, qcols.array], dimension=0), (*pcols.tiles, *qcols.tiles) - ) - - # Start with current indices, in the concat representation of columns - pcols_indices = np.arange(0, halfN, dtype=np.int32) - qcols_indices = np.arange(halfN, N, dtype=np.int32) - # First sorting permutation correction. - rotset_permute_mask = rotset[:, 0] < rotset[:, 1] - pcols_indices, qcols_indices = permute_pq_indices(pcols_indices, qcols_indices, rotset_permute_mask) - - # Rotation of columns between tiles (see Jacobi alg.) - # Roughtly: pcols move to the right, qcols to the left. 
-    pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]])
-    qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]])
-    pcols_indices, qcols_indices = pcols_indices_new, qcols_indices_new
-    assert len(pcols_indices_new) == halfN
-    assert len(qcols_indices_new) == halfN
-
-    # Second sorting permutation correction, using the next rotation set.
-    rotset = jacobi_next_rotation_set(rotset)
-    rotset_permute_mask = rotset[:, 0] < rotset[:, 1]
-    pcols_indices, qcols_indices = permute_pq_indices(pcols_indices, qcols_indices, rotset_permute_mask)
-
-    # Move columns around + re-split between pcols and qcols.
-    all_indices = np.concatenate([pcols_indices, qcols_indices])
-    all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
-    return all_cols_updated[:halfN], all_cols_updated[halfN:]
-
-
 def ipu_eigh(
     x: Array, *, lower: bool = True, symmetrize_input: bool = False, sort_eigenvalues: bool = True, num_iters: int = 1
 ) -> Tuple[Array, Array]:
diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py
index 92fe4ad..01f43b1 100644
--- a/tests/linalg/test_tile_linalg_jacobi.py
+++ b/tests/linalg/test_tile_linalg_jacobi.py
@@ -19,6 +19,7 @@
     jacobi_sym_schur2_p,
     jacobi_update_eigenvectors_p,
     jacobi_update_first_step_p,
+    tile_sharded_pq_columns,
 )
 from tessellate_ipu.utils import IpuTargetType
 
@@ -102,18 +103,17 @@ def test__jacobi_update_first_step_vertex__benchmark_performance(self):
         N = 128
         tiles = (0,)
         pq = np.array([3, N // 2], dtype=np.uint32)
-        pcol = np.random.randn(N).astype(np.float32)
-        qcol = np.random.randn(N).astype(np.float32)
+        pcol = np.random.randn(1, N).astype(np.float32)
+        qcol = np.random.randn(1, N).astype(np.float32)
 
         def jacobi_update_first_step_fn(pq, pcol, qcol):
             pq = tile_put_replicated(pq, tiles)
-            pcol = tile_put_replicated(pcol, tiles)
-            qcol = tile_put_replicated(qcol, tiles)
+            pcol, qcol = tile_sharded_pq_columns(pcol, qcol, tiles)
             # Force synchronization at this point, before cycle count.
             pq, pcol, qcol = tile_data_barrier(pq, pcol, qcol)
             pcol, start = ipu_cycle_count(pcol)
-            cs, _, _ = tile_map(  # type:ignore
-                jacobi_update_first_step_p, pq, pcol, qcol, N=N
+            _, cs, _, _ = tile_map(  # type:ignore
+                jacobi_update_first_step_p, pcol, qcol
             )
             cs, end = ipu_cycle_count(cs)
             return cs, start, end
@@ -123,7 +123,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol):
         start, end = np.asarray(start)[0], np.asarray(end)[0]
         qr_correction_cycle_count = end[0] - start[0]
-        assert qr_correction_cycle_count <= 1600
+        assert qr_correction_cycle_count <= 1700
         # print("CYCLE count:", qr_correction_cycle_count)
         # assert False
 
     def test__jacobi_update_eigenvectors_vertex__benchmark_performance(self):
         N = 256
         tiles = (0,)
         cs = np.array([0.2, 0.5], dtype=np.float32)
-        pcol = np.random.randn(N).astype(np.float32)
-        qcol = np.random.randn(N).astype(np.float32)
+        pcol = np.random.randn(1, N).astype(np.float32)
+        qcol = np.random.randn(1, N).astype(np.float32)
 
         def jacobi_update_eigenvectors_fn(cs, pcol, qcol):
             cs = tile_put_replicated(cs, tiles)
-            pcol = tile_put_replicated(pcol, tiles)
-            qcol = tile_put_replicated(qcol, tiles)
+            pcol, qcol = tile_sharded_pq_columns(pcol, qcol, tiles)
             # Force synchronization at this point, before cycle count.
cs, pcol, qcol = tile_data_barrier(cs, pcol, qcol) pcol, start = ipu_cycle_count(pcol) @@ -152,14 +151,14 @@ def jacobi_update_eigenvectors_fn(cs, pcol, qcol): pcol_updated = np.asarray(pcol_updated) qcol_updated = np.asarray(qcol_updated) - # Make sure we have the right result! - npt.assert_array_almost_equal(pcol_updated[0], pcol * cs[0] - qcol * cs[1]) - npt.assert_array_almost_equal(qcol_updated[0], pcol * cs[1] + qcol * cs[0]) + # Make sure we have the right result! NOTE: discarding indexing prefix! + npt.assert_array_almost_equal(pcol_updated[:, 2:], pcol * cs[0] - qcol * cs[1]) + npt.assert_array_almost_equal(qcol_updated[:, 2:], pcol * cs[1] + qcol * cs[0]) # Cycle count reference for scale_add: 64(375), 128(467), 256(665), 512(1043) start, end = np.asarray(start)[0], np.asarray(end)[0] qr_correction_cycle_count = end[0] - start[0] - assert qr_correction_cycle_count <= 2000 + assert qr_correction_cycle_count <= 2200 # print("CYCLE count:", qr_correction_cycle_count) # assert False
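
--
Reviewer note, not part of the patch: below is a minimal NumPy sketch of the
math implemented by `JacobiUpdateFirstStep` (the 2x2 symmetric Schur
decomposition followed by the p/q column rotation). Names and edge-case
handling are illustrative only; the on-tile `sym_schur2` vertex code may
differ in details.

import numpy as np

def sym_schur2(App, Aqq, Apq):
    """Return (c, s) such that the Givens rotation zeroes the A[p, q] entry."""
    if Apq == 0.0:
        return 1.0, 0.0
    tau = (Aqq - App) / (2.0 * Apq)
    sign = 1.0 if tau >= 0 else -1.0
    # Smaller-magnitude tangent root, for numerical stability.
    t = sign / (abs(tau) + np.hypot(1.0, tau))
    c = 1.0 / np.hypot(1.0, t)
    return c, t * c

def jacobi_update_first_step(pcol, qcol, p, q):
    App, Aqq, Apq = pcol[p], qcol[q], pcol[q]
    c, s = sym_schur2(App, Aqq, Apq)
    # Rotate the two columns (the vertex vectorizes this with float2 ops,
    # splitting the index range across the worker threads).
    pcol_new = c * pcol - s * qcol
    qcol_new = s * pcol + c * qcol
    # Patch the 2x2 diagonal block; A[p, q] / A[q, p] are zero by construction.
    pcol_new[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq
    qcol_new[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq
    pcol_new[q] = 0.0
    qcol_new[p] = 0.0
    return pcol_new, qcol_new, (c, s)

Repeatedly applying this update over all (p, q) pairs (one sweep = N - 1
rounds of the round-robin rotation pattern) drives the off-diagonal mass of A
to zero, which is what the `fori_loop` in `ipu_jacobi_eigh_body` /
`ipu_jacobi_eigh_iteration` expresses.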