Support batched QR decomposition (#1720)
* Batched QR, update of unit tests is missing so far

* removed old tests that threw errors for batched inputs

* final debugging of tests

* added changes to docs

* dummy change for benchmarking run

* Update heat/core/linalg/qr.py

Co-authored-by: Claudia Comito <[email protected]>

* Update qr.py

remove dead code

---------

Co-authored-by: Hoppe <[email protected]>
Co-authored-by: Claudia Comito <[email protected]>
3 people authored Dec 2, 2024
1 parent c23428e commit e6499cf
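
Before the diffs, a minimal usage sketch of the batched interface this commit introduces, following the updated docstring and the new tests below; the concrete shapes, dtype, and split axis are illustrative only.

import heat as ht

# Batched input: two leading batch dimensions, matrices of shape 10 x 9.
# Splitting along a batch axis (here axis 1) leaves the matrix dimensions
# undistributed, so each matrix is factorized locally via PyTorch.
x = ht.random.rand(4, 8, 10, 9, dtype=ht.float32, split=1)

q, r = ht.linalg.qr(x)                 # mode="reduced" (default): QR(Q=Q, R=R)
_, r_only = ht.linalg.qr(x, mode="r")  # QR(Q=None, R=R)

print(q.shape)  # (4, 8, 10, 9)
print(r.shape)  # (4, 8, 9, 9)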
Showing 2 changed files with 76 additions and 61 deletions.
101 changes: 46 additions & 55 deletions heat/core/linalg/qr.py
@@ -1,5 +1,5 @@
"""
QR decomposition of (distributed) 2-D ``DNDarray``s.
QR decomposition of ``DNDarray``s.
"""

import collections
@@ -24,16 +24,19 @@ def qr(
Factor the matrix ``A`` as *QR*, where ``Q`` is orthonormal and ``R`` is upper-triangular.
If ``mode = "reduced"``, the function returns ``QR(Q=Q, R=R)``; if ``mode = "r"``, the function returns ``QR(Q=None, R=R)``.
This function also works for batches of matrices; in this case, the last two dimensions of the input array are considered as the matrix dimensions.
The output arrays have the same leading batch dimensions as the input array.
Parameters
----------
A : DNDarray of shape (M, N)
Array which will be decomposed. So far only 2D arrays with datatype float32 or float64 are supported
For split=0, the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
A : DNDarray of shape (M, N), of shape (...,M,N) in the batched case
Array which will be decomposed. So far only arrays with datatype float32 or float64 are supported
For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
mode : str, optional
default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N), respectively.
default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified.
"r" returns only R, with dimensions (min(M,N), N).
procs_to_merge : int, optional
This parameter is only relevant for split=0 and determines the number of processes to be merged at one step during the so-called TS-QR algorithm.
This parameter is only relevant for split=0 (-2, in the batched case) and determines the number of processes to be merged at one step during the so-called TS-QR algorithm.
The default is 2. Higher choices might be faster, but will probably result in higher memory consumption. 0 corresponds to merging all processes at once.
We recommend modifying this parameter only if you are familiar with the TS-QR algorithm (see the references below).
@@ -49,7 +52,7 @@ def qr(
Unlike ``numpy.linalg.qr()``, `ht.linalg.qr` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage.
Heat's QR function is built on top of PyTorch's QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on
the backend. For split=0, tall-skinny QR (TS-QR) is implemented, while for split=1 a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
the backend. For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
References
-----------
@@ -87,65 +90,53 @@ def qr(
if procs_to_merge == 0:
procs_to_merge = A.comm.size

if A.ndim != 2:
raise ValueError(
f"Array 'A' must be 2 dimensional, buts has {A.ndim} dimensions. \n Please open an issue on GitHub if you require QR for batches of matrices similar to PyTorch."
)
if A.dtype not in [float32, float64]:
raise TypeError(f"Array 'A' must have a datatype of float32 or float64, but has {A.dtype}")

QR = collections.namedtuple("QR", "Q, R")

if not A.is_distributed():
if not A.is_distributed() or A.split < A.ndim - 2:
# handle the case of a single process or split=None: just PyTorch QR
Q, R = torch.linalg.qr(A.larray, mode=mode)
R = DNDarray(
R,
gshape=R.shape,
dtype=A.dtype,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
)
R = factories.array(R, is_split=A.split)
if mode == "reduced":
Q = DNDarray(
Q,
gshape=Q.shape,
dtype=A.dtype,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
)
Q = factories.array(Q, is_split=A.split)
else:
Q = None
return QR(Q, R)

if A.split == 1:
if A.split == A.ndim - 1:
# handle the case that A is split along the columns
# here, we apply a block-wise version of (stabilized) Gram-Schmidt orthogonalization
# instead of orthogonalizing each column of A individually, we orthogonalize blocks of columns (i.e. the local arrays) at once

lshapes = A.lshape_map[:, 1]
lshapes = A.lshape_map[:, -1]
lshapes_cum = torch.cumsum(lshapes, 0)
nprocs = A.comm.size

if A.shape[0] >= A.shape[1]:
if A.shape[-2] >= A.shape[-1]:
last_row_reached = nprocs
k = A.shape[1]
k = A.shape[-1]
else:
last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[0]))[0]
k = A.shape[0]
last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[-2]))[0]
k = A.shape[-2]

if mode == "reduced":
Q = factories.zeros(A.shape, dtype=A.dtype, split=1, device=A.device, comm=A.comm)
Q = factories.zeros(
A.shape, dtype=A.dtype, split=A.ndim - 1, device=A.device, comm=A.comm
)

R = factories.zeros((k, A.shape[1]), dtype=A.dtype, split=1, device=A.device, comm=A.comm)
R = factories.zeros(
(*A.shape[:-2], k, A.shape[-1]),
dtype=A.dtype,
split=A.ndim - 1,
device=A.device,
comm=A.comm,
)
R_shapes = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=A.device.torch_device),
torch.cumsum(R.lshape_map[:, 1], 0),
torch.cumsum(R.lshape_map[:, -1], 0),
]
)

@@ -156,9 +147,9 @@ def qr(
# this corresponds to the loop over all columns in classical Gram-Schmidt

if i < nprocs - 1:
k_loc_i = min(A.shape[0], A.lshape_map[i, 1])
k_loc_i = min(A.shape[-2], A.lshape_map[i, -1])
Q_buf = torch.zeros(
(A.shape[0], k_loc_i), dtype=A.larray.dtype, device=A.device.torch_device
(*A.shape[:-1], k_loc_i), dtype=A.larray.dtype, device=A.device.torch_device
)

if A.comm.rank == i:
@@ -168,31 +159,31 @@ def qr(
Q_buf = Q_curr
if mode == "reduced":
Q.larray = Q_curr
r_size = R.larray[R_shapes[i] : R_shapes[i + 1], :].shape[0]
R.larray[R_shapes[i] : R_shapes[i + 1], :] = R_loc[:r_size, :]
r_size = R.larray[..., R_shapes[i] : R_shapes[i + 1], :].shape[-2]
R.larray[..., R_shapes[i] : R_shapes[i + 1], :] = R_loc[..., :r_size, :]

if i < nprocs - 1:
# broadcast the orthogonalized block of columns to all other processes
A.comm.Bcast(Q_buf, root=i)

if A.comm.rank > i:
# subtract the contribution of the current block of columns from the remaining columns
R_loc = Q_buf.T @ A_columns
R_loc = torch.transpose(Q_buf, -2, -1) @ A_columns
A_columns -= Q_buf @ R_loc
r_size = R.larray[R_shapes[i] : R_shapes[i + 1], :].shape[0]
R.larray[R_shapes[i] : R_shapes[i + 1], :] = R_loc[:r_size, :]
r_size = R.larray[..., R_shapes[i] : R_shapes[i + 1], :].shape[-2]
R.larray[..., R_shapes[i] : R_shapes[i + 1], :] = R_loc[..., :r_size, :]

if mode == "reduced":
Q = Q[:, :k].balance()
Q = Q[..., :, :k].balance()
else:
Q = None

return QR(Q, R)

if A.split == 0:
if A.split == A.ndim - 2:
# implementation of TS-QR for split = 0
# check that the data distribution is reasonable for TS-QR (i.e., a tall-skinny matrix whose local chunks of data are also tall-skinny)
if A.lshape_map[:, 0].max().item() < A.shape[1]:
if A.lshape_map[:, -2].max().item() < A.shape[-1]:
raise ValueError(
"A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub."
)
@@ -209,10 +200,10 @@ def qr(
while len(current_procs) > 1:
if A.comm.rank in current_procs and local_comm.size > 1:
# create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes
shapes_R_loc = local_comm.gather(R_loc.shape[0], root=0)
shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0)
if local_comm.rank == 0:
gathered_R_loc = torch.zeros(
(sum(shapes_R_loc), R_loc.shape[1]),
(*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]),
device=R_loc.device,
dtype=R_loc.dtype,
)
@@ -225,7 +216,7 @@ def qr(
counts = None
displs = None
# gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=0)
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2)
# perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc
if local_comm.rank == 0:
previous_shape = R_loc.shape
@@ -242,7 +233,7 @@ def qr(
dtype=R_loc.dtype,
)
# scatter the Q_buf to all processes of the process group
local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=0)
local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2)
del gathered_R_loc, Q_buf

# for each process in the current processes, broadcast the scattered_Q_buf of this process
@@ -282,7 +273,7 @@ def qr(
leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank)
level += 1
# broadcast the final R_loc to all processes
R_gshape = (A.shape[1], A.shape[1])
R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1])
if A.comm.rank != 0:
R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device)
A.comm.Bcast(R_loc, root=0)
@@ -302,7 +293,7 @@ def qr(
Q_loc,
gshape=A.shape,
dtype=A.dtype,
split=0,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
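
As a reference for the split = A.ndim - 1 branch above (split along the columns), the following single-process PyTorch sketch illustrates the block-wise stabilized Gram-Schmidt idea that the distributed code applies per local block of columns. The helper name and the block size are made up for illustration; the sketch assumes the tall case m >= n and contains none of the communication logic.

import torch

def blockwise_gram_schmidt_qr(A: torch.Tensor, block_size: int):
    # Serial sketch of block-wise (stabilized) Gram-Schmidt QR.
    # A may have leading batch dimensions; the last two are the matrix
    # dimensions. Assumes the tall case m >= n for simplicity.
    *batch, m, n = A.shape
    Q = torch.zeros_like(A)
    R = torch.zeros(*batch, n, n, dtype=A.dtype, device=A.device)
    A = A.clone()
    for start in range(0, n, block_size):
        stop = min(start + block_size, n)
        # orthonormalize the current block of columns with a local QR
        Q_blk, R_blk = torch.linalg.qr(A[..., :, start:stop], mode="reduced")
        Q[..., :, start:stop] = Q_blk
        R[..., start:stop, start:stop] = R_blk
        if stop < n:
            # subtract the contribution of this block from the remaining columns
            R_rest = torch.transpose(Q_blk, -2, -1) @ A[..., :, stop:]
            R[..., start:stop, stop:] = R_rest
            A[..., :, stop:] -= Q_blk @ R_rest
    return Q, R

# quick sanity check on a random batched input
A = torch.randn(3, 2, 50, 12, dtype=torch.float64)
Q, R = blockwise_gram_schmidt_qr(A, block_size=5)
assert torch.allclose(Q @ R, A, atol=1e-10)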
36 changes: 30 additions & 6 deletions heat/core/linalg/tests/test_qr.py
@@ -124,13 +124,40 @@ def test_qr_split0(self):
)
)

def test_batched_qr_splitNone(self):
# two batch dimensions, float64 data type, "split = None" (split batch axis)
x = ht.random.rand(2, 2 * ht.MPI_WORLD.size, 10, 9, dtype=ht.float32, split=1)
_, r = ht.linalg.qr(x, mode="r")
self.assertEqual(r.shape, (2, 2 * ht.MPI_WORLD.size, 9, 9))
self.assertEqual(r.split, 1)

def test_batched_qr_split1(self):
# two batch dimensions, float64 data type, "split = 1" (last dimension)
ht.random.seed(0)
x = ht.random.rand(3, 2, 50, ht.MPI_WORLD.size * 5 + 3, dtype=ht.float64, split=3)
q, r = ht.linalg.qr(x)
batched_id = ht.stack([ht.eye(q.shape[3], dtype=ht.float64) for _ in range(6)]).reshape(
3, 2, q.shape[3], q.shape[3]
)

self.assertTrue(
ht.allclose(q.transpose([0, 1, 3, 2]) @ q, batched_id, atol=1e-6, rtol=1e-6)
)
self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6))

def test_batched_qr_split0(self):
# one batch dimension, float32 data type, "split = 0" (second last dimension)
x = ht.random.randn(8, ht.MPI_WORLD.size * 10 + 3, 9, dtype=ht.float32, split=1)
q, r = ht.linalg.qr(x)
batched_id = ht.stack([ht.eye(q.shape[2], dtype=ht.float32) for _ in range(q.shape[0])])

self.assertTrue(ht.allclose(q.transpose([0, 2, 1]) @ q, batched_id, atol=1e-3, rtol=1e-3))
self.assertTrue(ht.allclose(q @ r, x, atol=1e-3, rtol=1e-3))

def test_wronginputs(self):
# test wrong input type
with self.assertRaises(TypeError):
ht.linalg.qr([1, 2, 3])
# test too many input dimensions
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10, 10)))
# wrong data type for mode
with self.assertRaises(TypeError):
ht.linalg.qr(ht.zeros((10, 10)), mode=1)
@@ -148,9 +175,6 @@ def test_wronginputs(self):
# test wrong procs_to_merge
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10)), procs_to_merge=1)
# test wrong shape
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10, 10)))
# test wrong dtype
with self.assertRaises(TypeError):
ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32))
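
To complement the split = A.ndim - 2 branch exercised by test_batched_qr_split0 above, here is a single-process PyTorch sketch of the tall-skinny QR (TS-QR) reduction: factorize every row chunk locally, stack the small R factors and factorize them again, then fold the merge factor back into the local Q factors. The helper name and the chunking are illustrative; the distributed code instead merges procs_to_merge processes at a time over MPI.

import torch

def tsqr_one_merge(chunks):
    # Serial sketch of one TS-QR reduction over a list of row chunks.
    # Each chunk is a (possibly batched) tall-skinny tensor; concatenated
    # along the row dimension, the chunks form the row-split matrix A.
    local = [torch.linalg.qr(c, mode="reduced") for c in chunks]
    Q_locs = [q for q, _ in local]
    R_locs = [r for _, r in local]
    # QR of the stacked small R factors yields the global R
    Q_merge, R = torch.linalg.qr(torch.cat(R_locs, dim=-2), mode="reduced")
    # fold the corresponding row slice of Q_merge into each local Q
    Q_chunks, offset = [], 0
    for Q_loc, R_loc in zip(Q_locs, R_locs):
        rows = R_loc.shape[-2]
        Q_chunks.append(Q_loc @ Q_merge[..., offset:offset + rows, :])
        offset += rows
    return torch.cat(Q_chunks, dim=-2), R

# example: a batched 40 x 6 matrix split into 4 row chunks
A = torch.randn(5, 40, 6, dtype=torch.float64)
Q, R = tsqr_one_merge(list(torch.chunk(A, 4, dim=-2)))
assert torch.allclose(Q @ R, A, atol=1e-10)
assert torch.allclose(
    torch.transpose(Q, -2, -1) @ Q, torch.eye(6, dtype=torch.float64), atol=1e-10
)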
