From 7304ddccb00eaf3a42555f687b5a5470d9e11ac0 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 2 Oct 2023 13:17:17 +0100 Subject: [PATCH] wip --- .../core/tile_interpreter_vertex_utils.py | 1 + .../core/vertex/tile_jacobi_vertex.cpp | 18 ++++++++----- tessellate_ipu/linalg/tile_linalg_jacobi.py | 27 +++++++++++++------ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/tessellate_ipu/core/tile_interpreter_vertex_utils.py b/tessellate_ipu/core/tile_interpreter_vertex_utils.py index 2ba3c2a..f359ab7 100644 --- a/tessellate_ipu/core/tile_interpreter_vertex_utils.py +++ b/tessellate_ipu/core/tile_interpreter_vertex_utils.py @@ -39,6 +39,7 @@ def make_ipu_vector1d_worker_offsets( Returns: (6,) number of data vectors per thread. """ + print("SIZE:", size) def make_offsets_fn(sizes): sizes = [0] + sizes diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 6a9a1ec..cc82b17 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -149,6 +149,10 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep Input> rotset; // (2,) rotation index p and q. p < q + Input> + pindex; // (1,) rotation index p and q. p < q + Input> + qindex; // (1,) rotation index p and q. p < q Input> pcol; // (N,) p column Input> qcol; // (N,) q column @@ -170,8 +174,10 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep JacobiUpdateFirstStep(); bool compute(unsigned wid) { - const unsigned p = rotset[0]; - const unsigned q = rotset[1]; + const unsigned p = pindex[0]; + const unsigned q = qindex[0]; + + const IndexType wstart = worker_offsets[wid]; const IndexType wend = worker_offsets[wid + 1]; @@ -179,8 +185,8 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep // Proper ordering of p and q already. jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(), qcol_updated.data(), cs.data(), p, q, wstart, wend); - rotset_sorted[0] = rotset[0]; - rotset_sorted[1] = rotset[1]; + rotset_sorted[0] = p; + rotset_sorted[1] = q; } else { // Swap p and q columns as q < p @@ -188,8 +194,8 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep pcol_updated.data(), cs.data(), q, p, wstart, wend); // jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(), // qcol_updated.data(), cs.data(), q, p, wstart, wend); - rotset_sorted[0] = rotset[1]; - rotset_sorted[1] = rotset[0]; + rotset_sorted[0] = q; + rotset_sorted[1] = p; } return true; diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 00bf766..b98c54e 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -42,16 +42,16 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_first_step_p = create_ipu_tile_primitive( "jacobi_update_first_step", "JacobiUpdateFirstStep", - inputs=["rotset", "pcol", "qcol"], + inputs=["rotset", "pindex", "qindex", "pcol", "qcol"], outputs={ "rotset_sorted": ShapedArray((2,), dtype=np.uint32), "cs": ShapedArray((2,), dtype=np.float32), - "pcol_updated": 1, - "qcol_updated": 2, + "pcol_updated": 3, + "qcol_updated": 4, }, constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[1].size, vector_size=2, wdtype=np.uint16 + inavals[3].size, vector_size=2, wdtype=np.uint16 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -178,16 +178,25 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile rotset = jacobi_initial_rotation_set(N) rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles) + pindices_sharded = rotset_sharded[:, :1] + qindices_sharded = rotset_sharded[:, 1:] + # All different size 2 partitions on columns. for _ in range(1, N): # On tile constant rotation set tensor building. - with jax.named_scope("rotset"): - rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles) + # with jax.named_scope("rotset"): + # # rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles) + # pindices_sharded = tile_constant_sharded(np.copy(rotset[:, :1]), tiles=Atiles) + # qindices_sharded = tile_constant_sharded(np.copy(rotset[:, 1:]), tiles=Atiles) + + # pindices_sharded = rotset_sharded[:, :1] + # qindices_sharded = rotset_sharded[:, 1:] # Compute Schur decomposition + on-tile update of columns. rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore - jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N + jacobi_update_first_step_p, rotset_sharded, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N ) + # print(pindices_sharded.shape, qindices_sharded.shape) # Replicate Schur decomposition across all A tiles: (2*N//2) comms. with jax.named_scope("cs_replicated_sharded"): cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles) @@ -197,6 +206,7 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles) # Second Jacobi update step. + # Note: does not require sorting of pcols and qcols. cs_replicated, Apcols, Aqcols = tile_map( # type:ignore jacobi_update_second_step_p, cs_replicated, @@ -221,10 +231,11 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile with jax.named_scope("Apqcols_rotation"): # Apcols, Aqcols = tile_rotate_sorted_columns(Apcols, Aqcols, rotset) Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - with jax.named_scope("Vpqcols_rotation"): # Vpcols, Vqcols = tile_rotate_sorted_columns(Vpcols, Vqcols, rotset) Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + with jax.named_scope("pqindices_rotations"): + pindices_sharded, qindices_sharded = tile_rotate_columns(pindices_sharded, qindices_sharded) # Next rotation set. rotset = jacobi_next_rotation_set(rotset)