Householder reduction to Hessenberg form #34

Merged: 3 commits, Sep 29, 2023
Changes from 1 commit
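For reviewers: the kernels below port the classic Householder reduction to Hessenberg form. As a point of reference, here is a minimal dense NumPy version of the same algorithm (illustrative only, not part of the diff; for a symmetric input the result is tridiagonal):

import numpy as np

def hessenberg_reference(a: np.ndarray):
    """Reduce a to Hessenberg form h = q.T @ a @ q via Householder reflections."""
    a = np.array(a, dtype=np.float64)
    n = a.shape[0]
    q = np.eye(n)
    for k in range(n - 2):
        # Householder vector zeroing a[k+2:, k] (assumes x[0] != 0; the IPU
        # vertex uses the diagonal-sign input to pick the stable sign).
        x = a[k + 1:, k]
        v = x.copy()
        v[0] += np.sign(x[0]) * np.linalg.norm(x)
        vnorm = np.linalg.norm(v)
        if vnorm == 0.0:
            continue
        v /= vnorm
        # Two-sided similarity update a <- P a P, with P = I - 2 v v^T embedded.
        a[k + 1:, :] -= 2.0 * np.outer(v, v @ a[k + 1:, :])
        a[:, k + 1:] -= 2.0 * np.outer(a[:, k + 1:] @ v, v)
        # Accumulate q <- q P, so that a_final = q.T @ a0 @ q.
        q[:, k + 1:] -= 2.0 * np.outer(q[:, k + 1:] @ v, v)
    return q, a

# Sanity check: q, h = hessenberg_reference(a0); np.allclose(q @ h @ q.T, a0)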
1 change: 1 addition & 0 deletions tessellate_ipu/linalg/__init__.py
@@ -3,3 +3,4 @@
from . import tile_linalg_jacobi, tile_linalg_qr
from .tile_linalg_jacobi import ipu_eigh
from .tile_linalg_qr import ipu_qr
from .tile_linalg_hessenberg import ipu_hessenberg
165 changes: 165 additions & 0 deletions tessellate_ipu/linalg/tile_linalg_hessenberg.py
@@ -0,0 +1,165 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import os
from typing import Any, Tuple

import jax.lax
import numpy as np
from jax.core import ShapedArray

from tessellate_ipu import TileShardedArray, create_ipu_tile_primitive, tile_map, tile_put_replicated, tile_put_sharded
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets

Array = Any


def get_qr_vertex_gp_filename() -> str:
    return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_qr_vertex.cpp")


dot_product1d_p = create_ipu_tile_primitive(
"dot_product1d",
"DotProduct1dVertex",
inputs=["x", "y"],
outputs={"partials": ShapedArray((12,), dtype=np.float32)},
constants={
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
inavals[0].size, vector_size=2, num_workers=6, wdtype=np.uint16
)
},
# tmp_space=ShapedArray((12,), dtype=np.float32),
gp_filename=get_qr_vertex_gp_filename(),
perf_estimate=1000,
)
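A note on the partials output (my reading of the code, not stated in the diff): the dot product is split across the 6 IPU worker threads, each accumulating partial sums, and the caller reduces the 12 partials with jax.lax.reduce_sum_p (see the loop below). A CPU model of that contract, with the exact slicing left to make_ipu_vector1d_worker_offsets:

import numpy as np

def dot_product1d_model(x: np.ndarray, y: np.ndarray, num_partials: int = 12) -> np.ndarray:
    """CPU model of DotProduct1dVertex: partial dot products over slices."""
    slices = np.array_split(np.arange(x.size), num_partials)
    return np.array([x[s] @ y[s] for s in slices], dtype=np.float32)

# np.sum(dot_product1d_model(x, y)) == np.dot(x, y), up to float32 rounding.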

"""Vertex computing QR correction vector.
"""
qr_correction_vector_p = create_ipu_tile_primitive(
"qr_correction_vector",
"QRCorrectionVectorVertex",
inputs=["Rcol", "sdiag"],
outputs={"v": 0, "vrescale": ShapedArray((1,), dtype=np.float32)},
gp_filename=get_qr_vertex_gp_filename(),
perf_estimate=1000,
)
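For context, the correction vector is the usual Householder vector. Given a column rcol and the diagonal signs sdiag, one standard construction, which I assume the C++ vertex follows up to scaling conventions, is:

import numpy as np

def householder_correction_vector(rcol: np.ndarray, sdiag: np.ndarray, col_idx: int):
    """Householder v zeroing rcol[col_idx + 1:], leaving rows above col_idx untouched."""
    v = rcol.copy()
    v[:col_idx] = 0.0
    # Sign chosen from the diagonal to avoid cancellation.
    v[col_idx] += sdiag[col_idx] * np.linalg.norm(v)
    # One common convention: P x = x - (v @ x) * vrescale * v, with
    vrescale = 2.0 / np.dot(v, v)
    return v, vrescale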

"""Vertex QR HouseHolder performing row inplace update: x -= scale1[0] * scale2[0] * v
"""
qr_householder_row_update_p = create_ipu_tile_primitive(
"qr_householder_row_update",
"QRHouseholderRowUpdateVertex",
inputs=["x", "v", "scale1", "scale2"],
outputs={"x": 0},
constants={
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
inavals[1].size, vector_size=2, wdtype=np.uint16
)
},
gp_filename=get_qr_vertex_gp_filename(),
perf_estimate=1000,
)
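In other words (my paraphrase of the vertex contract): each mapped tile holds one row x and subtracts a scaled copy of v. Applied across all rows with scale1 = w (the per-row dot products) and scale2 = vrescale, this realizes a rank-1 update. A dense CPU model:

import numpy as np

def householder_rows_update_model(X: np.ndarray, v: np.ndarray, vrescale: float, start_idx: int = 0) -> np.ndarray:
    """Dense model of tile_map(qr_householder_row_update_p, ...): every row i does
    x_i[start_idx:] -= w[i] * vrescale * v[start_idx:], i.e. X <- X (I - vrescale v v^T)."""
    w = X[:, start_idx:] @ v[start_idx:]
    X[:, start_idx:] -= vrescale * np.outer(w, v[start_idx:])
    return X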


def ipu_qr_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArray, TileShardedArray, TileShardedArray]:
"""IPU QR initial sharding of input arrays across IPU tiles.

Args:
x: X array.
sdiag: X diagonal sign.
Returns:
Tile sharded Q, RT, sdiag.
"""
assert x.shape[0] == x.shape[1]
N = x.shape[0]
# Sharding R and Q
Q_tiles = tuple(range(0, N))
R_tiles = tuple(range(N, 2 * N))

# TODO: on-device construction of identity
Q = tile_put_sharded(np.identity(N, dtype=x.dtype), Q_tiles)
RT = tile_put_sharded(x.T, R_tiles)
# Replicate once on all tiles. Faster then for the looping.
sdiag_full = tile_put_replicated(xsdiag, R_tiles)
return Q, RT, sdiag_full
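Sharding layout, for reference: tiles 0..N-1 each hold one row of Q, tiles N..2N-1 each hold one row of R^T (i.e. one column of R), and the diagonal-sign vector is replicated on every R tile up front so the loop never re-broadcasts it.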


def ipu_hess_iterations(
    Q: TileShardedArray, RT: TileShardedArray, sdiag_full: TileShardedArray
) -> Tuple[TileShardedArray, TileShardedArray]:
"""IPU Hessenberg algorithm iterations.

Args:
Q: Initial Q sharded array.
RT: Initial R.T sharded array.
sdiag_full: Diagonal sign (replicated).
Returns:
(Q, RT) after N-1 iterations.
"""
assert len(Q) == len(RT)
N = len(Q)
# Sharding of R and Q on tiles.
Q_tiles = Q.tiles
R_tiles = RT.tiles

    for cidx in range(N - 2):
        # Column from which to start the computation, skipping known zeros.
        # Must be a multiple of 2 for proper vectorization.
        start_idx = (cidx // 2) * 2
        # Extract the proper R column (no tile copy, pure view).
        Rcol = RT[cidx]
        sdiag = sdiag_full[cidx]
        # Correction (Householder) vector. NOTE: computed on a single tile, changing at every loop.
        v, vrescale = tile_map(qr_correction_vector_p, Rcol, sdiag, col_idx=cidx + 1)  # type:ignore

        # Replicate v to all Q and R tiles.
        vQ = tile_put_replicated(v.array[0], Q_tiles)
        vR = tile_put_replicated(v.array[0], R_tiles)
        # v normalization factor to pass to the Householder update.
        vrescaleQ = tile_put_replicated(vrescale.array[0], Q_tiles)
        vrescaleR = tile_put_replicated(vrescale.array[0], R_tiles)

# Using "smart" slicing to reduce compute to do.
# w = R^T @ v
w = tile_map(dot_product1d_p, vR[:, start_idx:], RT[:, start_idx:]) # this returns size 12 array (6 worker threads)
w = tile_map(jax.lax.reduce_sum_p, w, axes=(0,)) # type:ignore
# Inplace update of R.
RT = tile_map( # type:ignore
qr_householder_row_update_p, RT, vR[:, start_idx:], w, vrescaleR, start_idx=start_idx # type:ignore
)

        # Transpose back to the R layout to apply the reflector on the other side.
        R = tile_put_sharded(RT.array.T, RT.tiles)
        # Using "smart" slicing to reduce the compute.
        # w = R @ v
        w = tile_map(dot_product1d_p, vR[:, start_idx:], R[:, start_idx:])  # size-12 output (6 worker threads).
        w = tile_map(jax.lax.reduce_sum_p, w, axes=(0,))  # type:ignore
        # In-place update of R (row side).
        R = tile_map(  # type:ignore
            qr_householder_row_update_p, R, vR[:, start_idx:], w, vrescaleR, start_idx=start_idx  # type:ignore
        )

        # w = Q @ v
        w = tile_map(dot_product1d_p, vQ[:, start_idx:], Q[:, start_idx:])
        w = tile_map(jax.lax.reduce_sum_p, w, axes=(0,))  # type:ignore
        # In-place update of Q.
        Q = tile_map(
            qr_householder_row_update_p, Q, vQ[:, start_idx:], w, vrescaleQ, start_idx=start_idx  # type:ignore
        )

        # Back to the R^T layout for the next iteration.
        RT = tile_put_sharded(R.array.T, R.tiles)

    return (Q, RT)
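Unlike the one-sided QR loop this is adapted from, the Hessenberg reduction is a similarity transform, which is why R is hit twice per iteration: once in the R^T layout (column side) and once in the R layout (row side). A dense model of one iteration, derived from the row-update semantics above:

import numpy as np

def hess_iteration_dense(R: np.ndarray, Q: np.ndarray, v: np.ndarray, vrescale: float):
    """Dense model of one loop iteration: R <- P R P and Q <- Q P,
    with P = I - vrescale * v v^T (slicing/sharding details omitted)."""
    R = R - vrescale * np.outer(R.T @ v, v).T  # column-side update (done on RT on IPU)
    R = R - vrescale * np.outer(R @ v, v)      # row-side update
    Q = Q - vrescale * np.outer(Q @ v, v)      # accumulate Q <- Q P
    return R, Q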


def ipu_hessenberg(x: Array) -> Tuple[Array, Array]:
"""IPU implementation of the QR algorithm.

This implementation is returing R^T instead of R, as it is more
efficient to store the former while iterating.

Args:
x: Symmetric matrix.
Returns:
Q, R^T matrices (as tile sharded arrays).
"""
# Initialize Q, RT, sdiag.
Q, RT, sdiag_full = ipu_qr_shard_inputs(x, jax.numpy.sign(jax.numpy.diag(x)))
# IPU QR iterations.
return ipu_hess_iterations(Q, RT, sdiag_full)
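A hedged usage sketch (the backend="ipu" jit pattern and the .array accessor are assumed from other tessellate_ipu examples, not shown in this diff):

import jax
import numpy as np
from tessellate_ipu.linalg import ipu_hessenberg

N = 32
x = np.random.randn(N, N).astype(np.float32)
x = (x + x.T) / 2  # the implementation expects a symmetric input

Q, RT = jax.jit(ipu_hessenberg, backend="ipu")(x)
Q, RT = np.asarray(Q.array), np.asarray(RT.array)
# Reconstruction check: RT.T should be the (tridiagonal) Hessenberg form of x.
# Exact transpose convention per my reading; the PR's unit tests are the authority here.
print(np.allclose(Q.T @ x @ Q, RT.T, atol=1e-4))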