Skip to content

Commit

Permalink
Using gather_p primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 5, 2023
1 parent 2cf4fe1 commit a695c5e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 42 deletions.
12 changes: 6 additions & 6 deletions examples/hessenberg_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

Q, R = jax.jit(ipu_hessenberg, backend="ipu")(A)

Q_ = Q.array.copy()
R_ = R.array.copy()
print("R matrix")
Q_ = np.array(Q.array)
R_ = np.array(R.array)
print("\nR matrix")
print(R_)
print("Q matrix (top left 6-by-6 corner)")
print("\nQ matrix")
print(Q_)
print(f"\nDelta: {np.max(np.abs(Q_ @ R_ @ Q_.T - A))}")

print(f"\nReconstruction Delta: {np.max(np.abs(Q_ @ R_ @ Q_.T - A))}")
print("\nQ.T @ Q")
print(Q_.T @ Q_)
89 changes: 53 additions & 36 deletions tessellate_ipu/linalg/tile_linalg_hessenberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr
x: X array.
xsdiag: X diagonal sign.
Returns:
Tile sharded Q, RT, sdiag.
Tile sharded Q, R, sdiag.
"""
assert x.shape[0] == x.shape[1]
N = x.shape[0]
Expand All @@ -86,10 +86,10 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr

# TODO: on-device construction of identity
Q = tile_put_sharded(np.identity(N, dtype=x.dtype), Q_tiles)
RT = tile_put_sharded(x.T, R_tiles)
R = tile_put_sharded(x, R_tiles)
# Replicate once on all tiles. Faster than looping.
sdiag_full = tile_put_replicated(xsdiag, R_tiles)
return Q, RT, sdiag_full
sdiag_full = tile_put_replicated(xsdiag.T, R_tiles)
return Q, R, sdiag_full


# Heavily based on ipu_qr_iterations in tile_linalg_qr.py
Expand All @@ -104,48 +104,65 @@ def ipu_hessenberg_body(
i: int, carry: Tuple[TileShardedArray, TileShardedArray, TileShardedArray]
) -> Tuple[TileShardedArray, TileShardedArray, TileShardedArray]:

Q, RT, sdiag_full = carry
Q, R, sdiag_full = carry

# Rcol_array = RT.array[i].reshape(1, -1)
# sdiag_array = sdiag_full.array[i].reshape(1, -1)
dim_numbers = jax.lax.GatherDimensionNumbers(offset_dims=tuple(), collapsed_slice_dims=(0,), start_index_map=(0,))

# # These too are very expensive
# # Rcol_array = jnp.take_along_axis(RT.array, jnp.array([[i]],dtype=np.int32), axis=0)
# # sdiag_array = jnp.take_along_axis(sdiag_full.array, jnp.array([[i]],dtype=np.int32), axis=0)
i_rep = tile_put_replicated(jax.numpy.array([[i]], dtype=np.uint32), R.tiles)

# # Rcol_array = jnp.take(RT.array, i, axis=0).reshape(1,-1)
# # sdiag_array = jnp.take(sdiag_full.array, i, axis=0).reshape(1,-1)

# # Rcol and sdiag are put on an arbitrarily picked tile
# correction_vector_tile = 736
# Rcol = tile_put_sharded(Rcol_array, [correction_vector_tile])
# sdiag = tile_put_sharded(sdiag_array, [correction_vector_tile])
Rcol = tile_map(
jax.lax.gather_p,
R,
i_rep,
dimension_numbers=dim_numbers,
slice_sizes=(1,),
mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
) # => TileShardedArray() (Num_tiles, 1)

Rcol_replicated = tile_put_replicated(Rcol.array, tiles=[736]) # type:ignore

sdiag = tile_map(
jax.lax.gather_p,
sdiag_full,
i_rep,
dimension_numbers=dim_numbers,
slice_sizes=(1,),
mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
) # => TileShardedArray() (Num_tiles, 1)

sdiag_rep = tile_put_replicated(sdiag.array, Rcol_replicated.tiles) # type:ignore

# start_idx = (i // 2) * 2
start_idx = 0

start_idxQ = tile_put_replicated(start_idx, Q.tiles)
start_idxR = tile_put_replicated(start_idx, RT.tiles)
start_idxR = tile_put_replicated(start_idx, R.tiles)

# Alternative: we pass the whole RT and sdiag; then we extract the result from the i-th tile

# Correction vector. Computed
# v, vrescale = tile_map(
# hessenberg_correction_vector_p, Rcol, sdiag, tile_put_replicated(i + 1, Rcol.tiles)
# ) # type:ignore
v, vrescale = tile_map(
hessenberg_correction_vector_p, RT, sdiag_full, tile_put_replicated(i + 1, RT.tiles)
hessenberg_correction_vector_p, Rcol_replicated, sdiag_rep, tile_put_replicated(i + 1, Rcol_replicated.tiles)
) # type:ignore
# v, vrescale = tile_map(
# hessenberg_correction_vector_p, RT, sdiag_full, tile_put_replicated(i + 1, RT.tiles)
# ) # type:ignore

# This compiles
# vi = tile_gather(v.array[i], [0], [0])
# vi = tile_gather(v.array[i], [0], [0]

# Replicate to all Q and R tiles.
vQ = tile_put_replicated(v.array[i], Q.tiles) # 0
vR = tile_put_replicated(v.array[i], RT.tiles) # 0
vQ = tile_put_replicated(v.array, Q.tiles) # 0
vR = tile_put_replicated(v.array, R.tiles) # 0
# v normalization factor to pass to householder update.
vrescaleQ = tile_put_replicated(vrescale.array[i], Q.tiles) # 0
vrescaleR = tile_put_replicated(vrescale.array[i], RT.tiles) # 0
vrescaleQ = tile_put_replicated(vrescale.array, Q.tiles) # 0
vrescaleR = tile_put_replicated(vrescale.array, R.tiles) # 0

# Alternative using tile_gather
# vQ = tile_gather(v, [i]*len(Q.tiles), list(Q.tiles), copy=False) # 0
Expand All @@ -154,6 +171,8 @@ def ipu_hessenberg_body(
# vrescaleQ = tile_gather(vrescale, [i]*len(Q.tiles), Q.tiles) # 0
# vrescaleR = tile_gather(vrescale, [i]*len(RT.tiles), RT.tiles) #

RT = tile_put_sharded(R.array.T, R.tiles)

# w = R^T @ v
w = tile_map(
# dot_product1d_indexed_p, vR, RT, start_idxR
Expand Down Expand Up @@ -193,13 +212,11 @@ def ipu_hessenberg_body(
hessenberg_householder_row_update_p, R, vR, w, vrescaleR, start_idxR # type:ignore
)

RT = tile_put_sharded(R.array.T, R.tiles)

return (Q, RT, sdiag_full)
return (Q, R, sdiag_full)


def ipu_hessenberg_iterations(
Q: TileShardedArray, RT: TileShardedArray, sdiag_full: TileShardedArray
Q: TileShardedArray, R: TileShardedArray, sdiag_full: TileShardedArray
) -> Tuple[TileShardedArray, TileShardedArray]:
"""IPU Hessenberg algorithm iterations.
Expand All @@ -210,12 +227,12 @@ def ipu_hessenberg_iterations(
Returns:
(Q, R) after N-2 iterations.
"""
assert len(Q) == len(RT)
assert len(Q) == len(R)
N = len(Q)

Q, RT, sdiag_full = jax.lax.fori_loop(0, N - 2, ipu_hessenberg_body, (Q, RT, sdiag_full))
Q, R, sdiag_full = jax.lax.fori_loop(0, N - 2, ipu_hessenberg_body, (Q, R, sdiag_full))

return (Q, RT)
return (Q, R)


def ipu_hessenberg(x: Array) -> Tuple[Array, Array]:
Expand All @@ -230,6 +247,6 @@ def ipu_hessenberg(x: Array) -> Tuple[Array, Array]:
Q, R matrices (as tile sharded arrays).
"""
# Initialize Q, R, sdiag.
Q, RT, sdiag_full = ipu_hessenberg_shard_inputs(x, jax.numpy.sign(jax.numpy.diag(x)))
Q, R, sdiag_full = ipu_hessenberg_shard_inputs(x, jax.numpy.sign(jax.numpy.diag(x)))
# IPU QR iterations.
return ipu_hessenberg_iterations(Q, RT, sdiag_full)
return ipu_hessenberg_iterations(Q, R, sdiag_full)

0 comments on commit a695c5e

Please sign in to comment.