diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index cc82b17..d9a73a6 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -147,8 +147,8 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep // Using `uint16` seems to be generating more efficient loops? using IndexType = unsigned short; - Input> - rotset; // (2,) rotation index p and q. p < q + // Input> + // rotset; // (2,) rotation index p and q. p < q Input> pindex; // (1,) rotation index p and q. p < q Input> diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index b98c54e..141e8f7 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -42,16 +42,16 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_first_step_p = create_ipu_tile_primitive( "jacobi_update_first_step", "JacobiUpdateFirstStep", - inputs=["rotset", "pindex", "qindex", "pcol", "qcol"], + inputs=["pindex", "qindex", "pcol", "qcol"], outputs={ "rotset_sorted": ShapedArray((2,), dtype=np.uint32), "cs": ShapedArray((2,), dtype=np.float32), - "pcol_updated": 3, - "qcol_updated": 4, + "pcol_updated": 2, + "qcol_updated": 3, }, constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size, vector_size=2, wdtype=np.uint16 + inavals[2].size, vector_size=2, wdtype=np.uint16 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -177,30 +177,21 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile with jax.named_scope("rotset"): rotset = jacobi_initial_rotation_set(N) rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles) - pindices_sharded = rotset_sharded[:, :1] qindices_sharded = rotset_sharded[:, 1:] # All different size 2 partitions on columns. for _ in range(1, N): - # On tile constant rotation set tensor building. - # with jax.named_scope("rotset"): - # # rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles) - # pindices_sharded = tile_constant_sharded(np.copy(rotset[:, :1]), tiles=Atiles) - # qindices_sharded = tile_constant_sharded(np.copy(rotset[:, 1:]), tiles=Atiles) - - # pindices_sharded = rotset_sharded[:, :1] - # qindices_sharded = rotset_sharded[:, 1:] - # Compute Schur decomposition + on-tile update of columns. + # Note: not expecting p < q. Sorted inside the vertex. rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore - jacobi_update_first_step_p, rotset_sharded, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N + jacobi_update_first_step_p, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N ) - # print(pindices_sharded.shape, qindices_sharded.shape) # Replicate Schur decomposition across all A tiles: (2*N//2) comms. + with jax.named_scope("rotset_sorted_replicated"): + rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles) with jax.named_scope("cs_replicated_sharded"): cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles) - rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles) # Just copy Schur decomposition to associated V tiles. cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles) cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles) @@ -226,20 +217,14 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile # Barrier, to make we sync. both set of tiles A and V Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) - # Move columns between tiles. 2*N commns per tile. - # NOTE: this inter-tile comm is keeping the p < q property on A and V columns. + # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. with jax.named_scope("Apqcols_rotation"): - # Apcols, Aqcols = tile_rotate_sorted_columns(Apcols, Aqcols, rotset) Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) with jax.named_scope("Vpqcols_rotation"): - # Vpcols, Vqcols = tile_rotate_sorted_columns(Vpcols, Vqcols, rotset) Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) with jax.named_scope("pqindices_rotations"): pindices_sharded, qindices_sharded = tile_rotate_columns(pindices_sharded, qindices_sharded) - # Next rotation set. - rotset = jacobi_next_rotation_set(rotset) - return (Apcols.array, Aqcols.array, Vpcols.array, Vqcols.array) @@ -301,50 +286,6 @@ def permute_pq_indices( return (np.where(rotset_permute_mask, pindices, qindices), np.where(rotset_permute_mask, qindices, pindices)) -def tile_rotate_sorted_columns( - pcols: TileShardedArray, qcols: TileShardedArray, rotset: NDArray[np.uint32] -) -> Tuple[TileShardedArray, TileShardedArray]: - """Rotate columns between tiles using a static `tile_gather`. - - The tricky part of this function is to rotate the columns between tiles, but - keep the property p < q, which means taking care of the present sorting permutation applied - as well the next sorting permutation. - """ - assert pcols.shape == qcols.shape - assert pcols.tiles == qcols.tiles - halfN = pcols.shape[0] - N = halfN * 2 - # Concat all columns, in order to perform a single gather. - all_cols = TileShardedArray( - jax.lax.concatenate([pcols.array, qcols.array], dimension=0), (*pcols.tiles, *qcols.tiles) - ) - - # Start with current indices, in the concat representation of columns - pcols_indices = np.arange(0, halfN, dtype=np.int32) - qcols_indices = np.arange(halfN, N, dtype=np.int32) - # First sorting permutation correction. - rotset_permute_mask = rotset[:, 0] < rotset[:, 1] - pcols_indices, qcols_indices = permute_pq_indices(pcols_indices, qcols_indices, rotset_permute_mask) - - # Rotation of columns between tiles (see Jacobi alg.) - # Roughtly: pcols move to the right, qcols to the left. - pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]]) - qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]]) - pcols_indices, qcols_indices = pcols_indices_new, qcols_indices_new - assert len(pcols_indices_new) == halfN - assert len(qcols_indices_new) == halfN - - # Second sorting permutation correction, using the next rotation set. - rotset = jacobi_next_rotation_set(rotset) - rotset_permute_mask = rotset[:, 0] < rotset[:, 1] - pcols_indices, qcols_indices = permute_pq_indices(pcols_indices, qcols_indices, rotset_permute_mask) - - # Move columns around + re-split between pcols and qcols. - all_indices = np.concatenate([pcols_indices, qcols_indices]) - all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles) - return all_cols_updated[:halfN], all_cols_updated[halfN:] - - def ipu_eigh( x: Array, *, lower: bool = True, symmetrize_input: bool = False, sort_eigenvalues: bool = True, num_iters: int = 1 ) -> Tuple[Array, Array]: