From f011bee37b7a4aae66f66100b04f7b8224463bd9 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 19 Oct 2023 15:23:20 +0000 Subject: [PATCH] Optimized the IPU eigh vertex `JacobiUpdateEigenvectors`. Very simple optimization, taking advantage of previously optimized kernel `rotation2d_f32`. 2.5 reduction on vertex cycle counts. --- .../core/vertex/tile_jacobi_vertex.cpp | 32 +++++++------------ tests/linalg/test_tile_linalg_jacobi.py | 4 +-- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 4719a2f..b7aced0 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -363,35 +363,26 @@ class JacobiUpdateSecondStep : public MultiVertex { } }; -template +template void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, T* vqcol_updated, T c, T s, unsigned short wstart, unsigned short wend) noexcept { - using T2 = float2; // Using `uint16` seems to be generating more efficient loops? using IndexType = unsigned short; - - const T2 cvec = T2{c, c}; - const T2 svec = T2{s, s}; const IndexType wsize = wend - wstart; + using T2 = float2; + const T2 cs_vec = T2{c, s}; + // pcol, qcol and results pointers. const T2* ptr_pcol = reinterpret_cast(vpcol) + wstart; const T2* ptr_qcol = reinterpret_cast(vqcol) + wstart; T2* ptr_pcol_updated = reinterpret_cast(vpcol_updated) + wstart; T2* ptr_qcol_updated = reinterpret_cast(vqcol_updated) + wstart; - - for (IndexType idx = 0; idx != wsize; ++idx) { - const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1); - const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1); - - const T2 vpvec_updated = cvec * vpvec - svec * vqvec; - const T2 vqvec_updated = svec * vpvec + cvec * vqvec; - - ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1); - ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1); - } + // Apply Schur2 cs rotation to p/q columns (optimized kernel). + rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated, + ptr_qcol_updated, wsize); } /** @@ -400,8 +391,9 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, * See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition, * Johns Hopkins Chapter 8. */ -class [[poplar::constraint( - "elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors +class [[poplar::constraint( + "elem(*vpcol) != elem(*vpcol_out)", + "elem(*vqcol) != elem(*vqcol_out)")]] JacobiUpdateEigenvectors : public MultiVertex { public: using T = float; @@ -439,12 +431,12 @@ class [[poplar::constraint( vqcol_out[0] = vqcol[0]; // Swapping pointers if necessary. if (p <= q) { - jacob_update_eigenvectors( + jacob_update_eigenvectors( vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c, s, wstart, wend); } else { - jacob_update_eigenvectors( + jacob_update_eigenvectors( vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c, s, wstart, wend); diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index 15e0e21..e3eb87a 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -128,7 +128,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol): # assert False def test__jacobi_update_eigenvectors_vertex__benchmark_performance(self): - N = 256 + N = 512 tiles = (0,) cs = np.array([0.2, 0.5], dtype=np.float32) pcol = np.random.randn(1, N).astype(np.float32) @@ -158,7 +158,7 @@ def jacobi_update_eigenvectors_fn(cs, pcol, qcol): # Cycle count reference for scale_add: 64(375), 128(467), 256(665), 512(1043) start, end = np.asarray(start)[0], np.asarray(end)[0] qr_correction_cycle_count = end[0] - start[0] - assert qr_correction_cycle_count <= 2200 + assert qr_correction_cycle_count <= 1550 # print("CYCLE count:", qr_correction_cycle_count) # assert False