diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 8fb26b204..8544c132f 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -7,6 +7,7 @@ from typing import Tuple from ..dndarray import DNDarray +from ..manipulations import concatenate from .. import factories from .. import communication from ..types import float32, float64 @@ -31,7 +32,6 @@ def qr( ---------- A : DNDarray of shape (M, N), of shape (...,M,N) in the batched case Array which will be decomposed. So far only arrays with datatype float32 or float64 are supported - For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns. mode : str, optional default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified. "r" returns only R, with dimensions (min(M,N), N). @@ -46,13 +46,17 @@ def qr( - If ``A`` is distributed along the columns (A.split = 1), so will be ``Q`` and ``R``. - - If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`, but ``R`` won't be distributed, i.e. `R. split = None` and a full copy of ``R`` will be stored on each process. + - If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`. ``R`` won't be distributed, i.e. `R. split = None`, if ``A`` is tall-skinny, i.e., if + the largest local chunk of data of ``A`` has at least as many rows as columns. Otherwise, ``R`` will be distributed along the rows as well, i.e., `R.split = 0`. Note that the argument `calc_q` allowed in earlier Heat versions is no longer supported; `calc_q = False` is equivalent to `mode = "r"`. Unlike ``numpy.linalg.qr()``, `ht.linalg.qr` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage. Heats QR function is built on top of PyTorchs QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on - the backend. 
For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used. + the backend. Both cases split=0 and split=1 build on a column-block-wise version of stabilized Gram-Schmidt orthogonalization. + For split=1 (-1, in the batched case), this is directly applied to the local arrays of the input array. + For split=0, a tall-skinny QR (TS-QR) is implemented for the case of tall-skinny matrices (i.e., the largest local chunk of data has at least as many rows as columns), + and extended to non tall-skinny matrices by applying a block-wise version of stabilized Gram-Schmidt orthogonalization. References ----------- @@ -181,121 +185,171 @@ def qr( return QR(Q, R) if A.split == A.ndim - 2: - # implementation of TS-QR for split = 0 - # check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data) - if A.lshape_map[:, -2].max().item() < A.shape[-1]: - raise ValueError( - "A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub." 
+ # check that data distribution is reasonable for TS-QR + # we regard a matrix with split = 0 as suitable for TS-QR if the largest local chunk of data has at least as many rows as columns + biggest_number_of_local_rows = A.lshape_map[:, -2].max().item() + if biggest_number_of_local_rows < A.shape[-1]: + column_idx = torch.cumsum(A.lshape_map[:, -2], 0) + column_idx = column_idx[column_idx < A.shape[-1]] + column_idx = torch.cat( + ( + torch.tensor([0], device=column_idx.device), + column_idx, + torch.tensor([A.shape[-1]], device=column_idx.device), + ) ) + A_copy = A.copy() + R = A.copy() + # Block-wise Gram-Schmidt orthogonalization, applied to groups of columns + offset = 1 if A.shape[-1] <= A.shape[-2] else 2 + for k in range(len(column_idx) - offset): + # since we only consider a group of columns, TS QR is applied to a tall-skinny matrix + Qnew, Rnew = qr( + A_copy[..., :, column_idx[k] : column_idx[k + 1]], + mode="reduced", + procs_to_merge=procs_to_merge, + ) - current_procs = [i for i in range(A.comm.size)] - current_comm = A.comm - local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) - Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode) - R_loc = R_loc.contiguous() # required for all the communication ops lateron - if mode == "reduced": - leave_comm = current_comm.Split(current_comm.rank, A.comm.rank) - - level = 1 - while len(current_procs) > 1: - if A.comm.rank in current_procs and local_comm.size > 1: - # create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes - shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0) - if local_comm.rank == 0: - gathered_R_loc = torch.zeros( - (*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]), - device=R_loc.device, - dtype=R_loc.dtype, + # usual update of the remaining columns + if R.comm.rank == k: + R.larray[ + ..., + : (column_idx[k + 1] - column_idx[k]), + column_idx[k] : column_idx[k + 1], + ] = Rnew.larray + if R.comm.rank
> k: + R.larray[..., :, column_idx[k] : column_idx[k + 1]] *= 0 + if k < len(column_idx) - 2: + coeffs = ( + torch.transpose(Qnew.larray, -2, -1) + @ A_copy.larray[..., :, column_idx[k + 1] :] ) - counts = list(shapes_R_loc) - displs = torch.cumsum( - torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0 - ).tolist()[:-1] - else: - gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) - counts = None - displs = None - # gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes - local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2) - # perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc - if local_comm.rank == 0: - previous_shape = R_loc.shape - Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode) - R_loc = R_loc.contiguous() - else: - Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) + R.comm.Allreduce(communication.MPI.IN_PLACE, coeffs) + if R.comm.rank == k: + R.larray[..., :, column_idx[k + 1] :] = coeffs + A_copy.larray[..., :, column_idx[k + 1] :] -= Qnew.larray @ coeffs if mode == "reduced": - if local_comm.rank == 0: - Q_buf = Q_buf.contiguous() - scattered_Q_buf = torch.empty( - R_loc.shape if local_comm.rank != 0 else previous_shape, - device=R_loc.device, - dtype=R_loc.dtype, - ) - # scatter the Q_buf to all processes of the process group - local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2) - del gathered_R_loc, Q_buf + Q = Qnew if k == 0 else concatenate((Q, Qnew), axis=-1) + if A.shape[-1] < A.shape[-2]: + R = R[..., : A.shape[-1], :].balance() + if mode == "reduced": + return QR(Q, R) + else: + return QR(None, R) - # for each process in the current processes, broadcast the scattered_Q_buf of this process - # to all leaves (i.e. 
all original processes that merge to the current process) - if mode == "reduced" and leave_comm.size > 1: + else: + # in this case the input is tall-skinny and we apply the TS-QR algorithm + # it follows the implementation of TS-QR for split = 0 + current_procs = [i for i in range(A.comm.size)] + current_comm = A.comm + local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) + Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode) + R_loc = R_loc.contiguous() # required for all the communication ops lateron + if mode == "reduced": + leave_comm = current_comm.Split(current_comm.rank, A.comm.rank) + + level = 1 + while len(current_procs) > 1: + if A.comm.rank in current_procs and local_comm.size > 1: + # create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes + shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0) + if local_comm.rank == 0: + gathered_R_loc = torch.zeros( + (*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]), + device=R_loc.device, + dtype=R_loc.dtype, + ) + counts = list(shapes_R_loc) + displs = torch.cumsum( + torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0 + ).tolist()[:-1] + else: + gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) + counts = None + displs = None + # gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes + local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2) + # perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc + if local_comm.rank == 0: + previous_shape = R_loc.shape + Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode) + R_loc = R_loc.contiguous() + else: + Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) + if mode == "reduced": + if local_comm.rank == 0: + Q_buf = Q_buf.contiguous() + scattered_Q_buf = torch.empty( + R_loc.shape if local_comm.rank != 0 else previous_shape, + 
device=R_loc.device, + dtype=R_loc.dtype, + ) + # scatter the Q_buf to all processes of the process group + local_comm.Scatterv( + (Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2 + ) + del gathered_R_loc, Q_buf + + # for each process in the current processes, broadcast the scattered_Q_buf of this process + # to all leaves (i.e. all original processes that merge to the current process) + if mode == "reduced" and leave_comm.size > 1: + try: + scattered_Q_buf_shape = scattered_Q_buf.shape + except UnboundLocalError: + scattered_Q_buf_shape = None + scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0) + if scattered_Q_buf_shape is not None: + # this is needed to ensure that only those Q_loc get updates that are actually part of the current process group + if leave_comm.rank != 0: + scattered_Q_buf = torch.empty( + scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype + ) + leave_comm.Bcast(scattered_Q_buf, root=0) + # update the local Q_loc by multiplying it with the scattered_Q_buf try: - scattered_Q_buf_shape = scattered_Q_buf.shape + Q_loc = Q_loc @ scattered_Q_buf + del scattered_Q_buf except UnboundLocalError: - scattered_Q_buf_shape = None - scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0) - if scattered_Q_buf_shape is not None: - # this is needed to ensure that only those Q_loc get updates that are actually part of the current process group - if leave_comm.rank != 0: - scattered_Q_buf = torch.empty( - scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype + pass + + # update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering + current_procs = [ + current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0 + ] + if len(current_procs) > 1: + new_group = A.comm.group.Incl(current_procs) + current_comm = A.comm.Create_group(new_group) + if A.comm.rank in current_procs: + local_comm = 
communication.MPICommunication( + current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) ) - leave_comm.Bcast(scattered_Q_buf, root=0) - # update the local Q_loc by multiplying it with the scattered_Q_buf - try: - Q_loc = Q_loc @ scattered_Q_buf - del scattered_Q_buf - except UnboundLocalError: - pass - - # update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering - current_procs = [ - current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0 - ] - if len(current_procs) > 1: - new_group = A.comm.group.Incl(current_procs) - current_comm = A.comm.Create_group(new_group) - if A.comm.rank in current_procs: - local_comm = communication.MPICommunication( - current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) - ) - if mode == "reduced": - leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank) - level += 1 - # broadcast the final R_loc to all processes - R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1]) - if A.comm.rank != 0: - R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device) - A.comm.Bcast(R_loc, root=0) - R = DNDarray( - R_loc, - gshape=R_gshape, - dtype=A.dtype, - split=None, - device=A.device, - comm=A.comm, - balanced=True, - ) - if mode == "r": - Q = None - else: - Q = DNDarray( - Q_loc, - gshape=A.shape, + if mode == "reduced": + leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank) + level += 1 + # broadcast the final R_loc to all processes + R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1]) + if A.comm.rank != 0: + R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device) + A.comm.Bcast(R_loc, root=0) + R = DNDarray( + R_loc, + gshape=R_gshape, dtype=A.dtype, - split=A.split, + split=None, device=A.device, comm=A.comm, balanced=True, ) - return QR(Q, R) + if mode == "r": + Q = None + else: + Q = DNDarray( + Q_loc, + gshape=A.shape, + dtype=A.dtype, + 
split=A.split, + device=A.device, + comm=A.comm, + balanced=True, + ) + return QR(Q, R) diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index 08d130486..f5c9b25ec 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -76,7 +76,11 @@ def test_qr_split0(self): for procs_to_merge in [0, 2, 3]: for mode in ["reduced", "r"]: # split = 0 can be handeled only for tall skinny matrices s.t. the local chunks are at least square too - for shape in [(40 * ht.MPI_WORLD.size + 1, 40), (40 * ht.MPI_WORLD.size, 20)]: + for shape in [ + (20 * ht.MPI_WORLD.size + 1, 40 * ht.MPI_WORLD.size), + (20 * ht.MPI_WORLD.size, 20 * ht.MPI_WORLD.size), + (40 * ht.MPI_WORLD.size - 1, 20 * ht.MPI_WORLD.size), + ]: for dtype in [ht.float32, ht.float64]: dtypetol = 1e-3 if dtype == ht.float32 else 1e-6 mat = ht.random.randn(*shape, dtype=dtype, split=split) @@ -146,8 +150,11 @@ def test_batched_qr_split1(self): self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6)) def test_batched_qr_split0(self): + ht.random.seed(424242) # one batch dimension, float32 data type, "split = 0" (second last dimension) - x = ht.random.randn(8, ht.MPI_WORLD.size * 10 + 3, 9, dtype=ht.float32, split=1) + x = ht.random.randn( + 8, ht.MPI_WORLD.size * 10 + 3, ht.MPI_WORLD.size * 10 - 1, dtype=ht.float32, split=1 + ) q, r = ht.linalg.qr(x) batched_id = ht.stack([ht.eye(q.shape[2], dtype=ht.float32) for _ in range(q.shape[0])]) @@ -178,7 +185,3 @@ def test_wronginputs(self): # test wrong dtype with self.assertRaises(TypeError): ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32)) - # test wrong shape for split=0 - if ht.MPI_WORLD.size > 1: - with self.assertRaises(ValueError): - ht.linalg.qr(ht.zeros((10, 10), split=0))