Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 2, 2023
1 parent 7304ddc commit 49a2e49
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 70 deletions.
4 changes: 2 additions & 2 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
// Using `uint16` seems to be generating more efficient loops?
using IndexType = unsigned short;

Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
rotset; // (2,) rotation index p and q. p < q
// Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
// rotset; // (2,) rotation index p and q. p < q
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
pindex; // (1,) rotation index p and q. p < q
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
Expand Down
77 changes: 9 additions & 68 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def get_jacobi_vertex_gp_filename() -> str:
jacobi_update_first_step_p = create_ipu_tile_primitive(
"jacobi_update_first_step",
"JacobiUpdateFirstStep",
inputs=["rotset", "pindex", "qindex", "pcol", "qcol"],
inputs=["pindex", "qindex", "pcol", "qcol"],
outputs={
"rotset_sorted": ShapedArray((2,), dtype=np.uint32),
"cs": ShapedArray((2,), dtype=np.float32),
"pcol_updated": 3,
"qcol_updated": 4,
"pcol_updated": 2,
"qcol_updated": 3,
},
constants={
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
inavals[3].size, vector_size=2, wdtype=np.uint16
inavals[2].size, vector_size=2, wdtype=np.uint16
)
},
gp_filename=get_jacobi_vertex_gp_filename(),
Expand Down Expand Up @@ -177,30 +177,21 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
with jax.named_scope("rotset"):
rotset = jacobi_initial_rotation_set(N)
rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)

pindices_sharded = rotset_sharded[:, :1]
qindices_sharded = rotset_sharded[:, 1:]

# All different size 2 partitions on columns.
for _ in range(1, N):
# On tile constant rotation set tensor building.
# with jax.named_scope("rotset"):
# # rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
# pindices_sharded = tile_constant_sharded(np.copy(rotset[:, :1]), tiles=Atiles)
# qindices_sharded = tile_constant_sharded(np.copy(rotset[:, 1:]), tiles=Atiles)

# pindices_sharded = rotset_sharded[:, :1]
# qindices_sharded = rotset_sharded[:, 1:]

# Compute Schur decomposition + on-tile update of columns.
# Note: not expecting p < q. Sorted inside the vertex.
rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore
jacobi_update_first_step_p, rotset_sharded, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N
jacobi_update_first_step_p, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N
)
# print(pindices_sharded.shape, qindices_sharded.shape)
# Replicate Schur decomposition across all A tiles: (2*N//2) comms.
with jax.named_scope("rotset_sorted_replicated"):
rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles)
with jax.named_scope("cs_replicated_sharded"):
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles)
# Just copy Schur decomposition to associated V tiles.
cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)
Expand All @@ -226,20 +217,14 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile

# Barrier, to make sure we sync both sets of tiles, A and V.
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# Move columns between tiles. 2*N commns per tile.
# NOTE: this inter-tile comm is keeping the p < q property on A and V columns.
# Move columns between tiles following Jacobi rotation pattern. 2*N comms per tile.
with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_sorted_columns(Apcols, Aqcols, rotset)
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_sorted_columns(Vpcols, Vqcols, rotset)
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
with jax.named_scope("pqindices_rotations"):
pindices_sharded, qindices_sharded = tile_rotate_columns(pindices_sharded, qindices_sharded)

# Next rotation set.
rotset = jacobi_next_rotation_set(rotset)

return (Apcols.array, Aqcols.array, Vpcols.array, Vqcols.array)


Expand Down Expand Up @@ -301,50 +286,6 @@ def permute_pq_indices(
return (np.where(rotset_permute_mask, pindices, qindices), np.where(rotset_permute_mask, qindices, pindices))


def tile_rotate_sorted_columns(
    pcols: TileShardedArray, qcols: TileShardedArray, rotset: NDArray[np.uint32]
) -> Tuple[TileShardedArray, TileShardedArray]:
    """Rotate the p/q columns across tiles with a single static `tile_gather`.

    The rotation must keep the p < q property, so the gather indices account for
    two sorting permutations: the one derived from the current `rotset`, and the
    one derived from the next rotation set (`jacobi_next_rotation_set(rotset)`).

    Args:
        pcols: Tile-sharded p columns.
        qcols: Tile-sharded q columns; must have the same shape and tiles as `pcols`.
        rotset: Current rotation set; columns 0 and 1 are compared to build the
            permutation mask.

    Returns:
        The pair `(pcols, qcols)` after rotation: the first and second half of
        the gathered columns.
    """
    assert pcols.shape == qcols.shape
    assert pcols.tiles == qcols.tiles
    half = pcols.shape[0]

    # Stack p columns followed by q columns, so that one gather moves everything.
    stacked = TileShardedArray(
        jax.lax.concatenate([pcols.array, qcols.array], dimension=0), (*pcols.tiles, *qcols.tiles)
    )

    # Positions of the current p and q columns in the stacked representation.
    p_idx = np.arange(0, half, dtype=np.int32)
    q_idx = np.arange(half, half * 2, dtype=np.int32)
    # First correction: sorting permutation of the current rotation set.
    current_mask = rotset[:, 0] < rotset[:, 1]
    p_idx, q_idx = permute_pq_indices(p_idx, q_idx, current_mask)

    # Jacobi rotation of columns between tiles.
    # Roughly: p columns move to the right, q columns to the left.
    rotated_p = np.concatenate([p_idx[0:1], q_idx[0:1], p_idx[1:-1]])
    rotated_q = np.concatenate([q_idx[1:], p_idx[-1:]])
    assert len(rotated_p) == half
    assert len(rotated_q) == half

    # Second correction: sorting permutation of the next rotation set.
    next_rotset = jacobi_next_rotation_set(rotset)
    next_mask = next_rotset[:, 0] < next_rotset[:, 1]
    p_idx, q_idx = permute_pq_indices(rotated_p, rotated_q, next_mask)

    # Gather all columns at once, then split back into the p and q halves.
    gather_indices = np.concatenate([p_idx, q_idx])
    gathered = tile_gather(stacked, gather_indices.tolist(), stacked.tiles)
    return gathered[:half], gathered[half:]


def ipu_eigh(
x: Array, *, lower: bool = True, symmetrize_input: bool = False, sort_eigenvalues: bool = True, num_iters: int = 1
) -> Tuple[Array, Array]:
Expand Down

0 comments on commit 49a2e49

Please sign in to comment.