diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py
index a0d9559cb..8fb26b204 100644
--- a/heat/core/linalg/qr.py
+++ b/heat/core/linalg/qr.py
@@ -1,5 +1,5 @@
 """
-QR decomposition of (distributed) 2-D ``DNDarray``s.
+QR decomposition of ``DNDarray``s.
 """
 
 import collections
@@ -24,16 +24,19 @@ def qr(
     Factor the matrix ``A`` as *QR*, where ``Q`` is orthonormal and ``R`` is upper-triangular.
     If ``mode = "reduced``, function returns ``QR(Q=Q, R=R)``, if ``mode = "r"`` function returns ``QR(Q=None, R=R)``
 
+    This function also works for batches of matrices; in this case, the last two dimensions of the input array are considered as the matrix dimensions.
+    The output arrays have the same leading batch dimensions as the input array.
+
     Parameters
     ----------
-    A : DNDarray of shape (M, N)
-        Array which will be decomposed. So far only 2D arrays with datatype float32 or float64 are supported
-        For split=0, the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
+    A : DNDarray of shape (M, N), or of shape (..., M, N) in the batched case
+        Array which will be decomposed. So far only arrays with datatype float32 or float64 are supported.
+        For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
     mode : str, optional
-        default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N), respectively.
+        default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified.
         "r" returns only R, with dimensions (min(M,N), N).
     procs_to_merge : int, optional
-        This parameter is only relevant for split=0 and determines the number of processes to be merged at one step during the so-called TS-QR algorithm.
+        This parameter is only relevant for split=0 (-2, in the batched case) and determines the number of processes to be merged at one step during the so-called TS-QR algorithm.
         The default is 2. Higher choices might be faster, but will probably result in higher memory consumption.
         0 corresponds to merging all processes at once. We only recommend to modify this parameter if you are familiar with the TS-QR algorithm (see the references below).
 
@@ -49,7 +52,7 @@ def qr(
 
     Unlike ``numpy.linalg.qr()``, `ht.linalg.qr` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage.
     Heats QR function is built on top of PyTorchs QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on
-    the backend. For split=0, tall-skinny QR (TS-QR) is implemented, while for split=1 a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
+    the backend. For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
 
     References
     -----------
@@ -87,65 +90,53 @@ def qr(
     if procs_to_merge == 0:
         procs_to_merge = A.comm.size
 
-    if A.ndim != 2:
-        raise ValueError(
-            f"Array 'A' must be 2 dimensional, buts has {A.ndim} dimensions. \n Please open an issue on GitHub if you require QR for batches of matrices similar to PyTorch."
-        )
 
     if A.dtype not in [float32, float64]:
         raise TypeError(f"Array 'A' must have a datatype of float32 or float64, but has {A.dtype}")
 
     QR = collections.namedtuple("QR", "Q, R")
 
-    if not A.is_distributed():
+    if not A.is_distributed() or A.split < A.ndim - 2:
         # handle the case of a single process or split=None: just PyTorch QR
         Q, R = torch.linalg.qr(A.larray, mode=mode)
-        R = DNDarray(
-            R,
-            gshape=R.shape,
-            dtype=A.dtype,
-            split=A.split,
-            device=A.device,
-            comm=A.comm,
-            balanced=True,
-        )
+        R = factories.array(R, is_split=A.split)
         if mode == "reduced":
-            Q = DNDarray(
-                Q,
-                gshape=Q.shape,
-                dtype=A.dtype,
-                split=A.split,
-                device=A.device,
-                comm=A.comm,
-                balanced=True,
-            )
+            Q = factories.array(Q, is_split=A.split)
         else:
             Q = None
         return QR(Q, R)
 
-    if A.split == 1:
+    if A.split == A.ndim - 1:
         # handle the case that A is split along the columns
         # here, we apply a block-wise version of (stabilized) Gram-Schmidt orthogonalization
         # instead of orthogonalizing each column of A individually, we orthogonalize blocks of columns (i.e. the local arrays) at once
-        lshapes = A.lshape_map[:, 1]
+        lshapes = A.lshape_map[:, -1]
         lshapes_cum = torch.cumsum(lshapes, 0)
         nprocs = A.comm.size
 
-        if A.shape[0] >= A.shape[1]:
+        if A.shape[-2] >= A.shape[-1]:
             last_row_reached = nprocs
-            k = A.shape[1]
+            k = A.shape[-1]
         else:
-            last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[0]))[0]
-            k = A.shape[0]
+            last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[-2]))[0]
+            k = A.shape[-2]
 
         if mode == "reduced":
-            Q = factories.zeros(A.shape, dtype=A.dtype, split=1, device=A.device, comm=A.comm)
+            Q = factories.zeros(
+                A.shape, dtype=A.dtype, split=A.ndim - 1, device=A.device, comm=A.comm
+            )
 
-        R = factories.zeros((k, A.shape[1]), dtype=A.dtype, split=1, device=A.device, comm=A.comm)
+        R = factories.zeros(
+            (*A.shape[:-2], k, A.shape[-1]),
+            dtype=A.dtype,
+            split=A.ndim - 1,
+            device=A.device,
+            comm=A.comm,
+        )
         R_shapes = torch.hstack(
             [
                 torch.zeros(1, dtype=torch.int32, device=A.device.torch_device),
-                torch.cumsum(R.lshape_map[:, 1], 0),
+                torch.cumsum(R.lshape_map[:, -1], 0),
             ]
         )
 
@@ -156,9 +147,9 @@ def qr(
             # this corresponds to the loop over all columns in classical Gram-Schmidt
 
             if i < nprocs - 1:
-                k_loc_i = min(A.shape[0], A.lshape_map[i, 1])
+                k_loc_i = min(A.shape[-2], A.lshape_map[i, -1])
                 Q_buf = torch.zeros(
-                    (A.shape[0], k_loc_i), dtype=A.larray.dtype, device=A.device.torch_device
+                    (*A.shape[:-1], k_loc_i), dtype=A.larray.dtype, device=A.device.torch_device
                 )
 
             if A.comm.rank == i:
@@ -168,8 +159,8 @@ def qr(
                     Q_buf = Q_curr
                 if mode == "reduced":
                     Q.larray = Q_curr
-                r_size = R.larray[R_shapes[i] : R_shapes[i + 1], :].shape[0]
-                R.larray[R_shapes[i] : R_shapes[i + 1], :] = R_loc[:r_size, :]
+                r_size = R.larray[..., R_shapes[i] : R_shapes[i + 1], :].shape[-2]
+                R.larray[..., R_shapes[i] : R_shapes[i + 1], :] = R_loc[..., :r_size, :]
 
             if i < nprocs - 1:
                 # broadcast the orthogonalized block of columns to all other processes
@@ -177,22 +168,22 @@ def qr(
 
                 if A.comm.rank > i:
                     # subtract the contribution of the current block of columns from the remaining columns
-                    R_loc = Q_buf.T @ A_columns
+                    R_loc = torch.transpose(Q_buf, -2, -1) @ A_columns
                     A_columns -= Q_buf @ R_loc
-                    r_size = R.larray[R_shapes[i] : R_shapes[i + 1], :].shape[0]
-                    R.larray[R_shapes[i] : R_shapes[i + 1], :] = R_loc[:r_size, :]
+                    r_size = R.larray[..., R_shapes[i] : R_shapes[i + 1], :].shape[-2]
+                    R.larray[..., R_shapes[i] : R_shapes[i + 1], :] = R_loc[..., :r_size, :]
 
         if mode == "reduced":
-            Q = Q[:, :k].balance()
+            Q = Q[..., :, :k].balance()
         else:
             Q = None
 
         return QR(Q, R)
 
-    if A.split == 0:
+    if A.split == A.ndim - 2:
         # implementation of TS-QR for split = 0
         # check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data)
-        if A.lshape_map[:, 0].max().item() < A.shape[1]:
+        if A.lshape_map[:, -2].max().item() < A.shape[-1]:
             raise ValueError(
                 "A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub."
             )
@@ -209,10 +200,10 @@ def qr(
         while len(current_procs) > 1:
             if A.comm.rank in current_procs and local_comm.size > 1:
                 # create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes
-                shapes_R_loc = local_comm.gather(R_loc.shape[0], root=0)
+                shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0)
                 if local_comm.rank == 0:
                     gathered_R_loc = torch.zeros(
-                        (sum(shapes_R_loc), R_loc.shape[1]),
+                        (*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]),
                         device=R_loc.device,
                         dtype=R_loc.dtype,
                     )
@@ -225,7 +216,7 @@ def qr(
                     counts = None
                     displs = None
                 # gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes
-                local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=0)
+                local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2)
                 # perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc
                 if local_comm.rank == 0:
                     previous_shape = R_loc.shape
@@ -242,7 +233,7 @@ def qr(
                     dtype=R_loc.dtype,
                 )
                 # scatter the Q_buf to all processes of the process group
-                local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=0)
+                local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2)
                 del gathered_R_loc, Q_buf
 
             # for each process in the current processes, broadcast the scattered_Q_buf of this process
@@ -282,7 +273,7 @@ def qr(
            leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank)
            level += 1
         # broadcast the final R_loc to all processes
-        R_gshape = (A.shape[1], A.shape[1])
+        R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1])
         if A.comm.rank != 0:
             R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device)
         A.comm.Bcast(R_loc, root=0)
@@ -302,7 +293,7 @@ def qr(
                 Q_loc,
                 gshape=A.shape,
                 dtype=A.dtype,
-                split=0,
+                split=A.split,
                 device=A.device,
                 comm=A.comm,
                 balanced=True,
diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py
index 6de9e091d..08d130486 100644
--- a/heat/core/linalg/tests/test_qr.py
+++ b/heat/core/linalg/tests/test_qr.py
@@ -124,13 +124,40 @@ def test_qr_split0(self):
                 )
             )
 
+    def test_batched_qr_splitNone(self):
+        # two batch dimensions, float32 data type, "split = None" (split along a batch axis)
+        x = ht.random.rand(2, 2 * ht.MPI_WORLD.size, 10, 9, dtype=ht.float32, split=1)
+        _, r = ht.linalg.qr(x, mode="r")
+        self.assertEqual(r.shape, (2, 2 * ht.MPI_WORLD.size, 9, 9))
+        self.assertEqual(r.split, 1)
+
+    def test_batched_qr_split1(self):
+        # two batch dimensions, float64 data type, "split = 1" (last dimension)
+        ht.random.seed(0)
+        x = ht.random.rand(3, 2, 50, ht.MPI_WORLD.size * 5 + 3, dtype=ht.float64, split=3)
+        q, r = ht.linalg.qr(x)
+        batched_id = ht.stack([ht.eye(q.shape[3], dtype=ht.float64) for _ in range(6)]).reshape(
+            3, 2, q.shape[3], q.shape[3]
+        )
+
+        self.assertTrue(
+            ht.allclose(q.transpose([0, 1, 3, 2]) @ q, batched_id, atol=1e-6, rtol=1e-6)
+        )
+        self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6))
+
+    def test_batched_qr_split0(self):
+        # one batch dimension, float32 data type, "split = 0" (second last dimension)
+        x = ht.random.randn(8, ht.MPI_WORLD.size * 10 + 3, 9, dtype=ht.float32, split=1)
+        q, r = ht.linalg.qr(x)
+        batched_id = ht.stack([ht.eye(q.shape[2], dtype=ht.float32) for _ in range(q.shape[0])])
+
+        self.assertTrue(ht.allclose(q.transpose([0, 2, 1]) @ q, batched_id, atol=1e-3, rtol=1e-3))
+        self.assertTrue(ht.allclose(q @ r, x, atol=1e-3, rtol=1e-3))
+
     def test_wronginputs(self):
         # test wrong input type
         with self.assertRaises(TypeError):
             ht.linalg.qr([1, 2, 3])
-        # test too many input dimensions
-        with self.assertRaises(ValueError):
-            ht.linalg.qr(ht.zeros((10, 10, 10)))
         # wrong data type for mode
         with self.assertRaises(TypeError):
             ht.linalg.qr(ht.zeros((10, 10)), mode=1)
@@ -148,9 +175,6 @@ def test_wronginputs(self):
         # test wrong procs_to_merge
         with self.assertRaises(ValueError):
             ht.linalg.qr(ht.zeros((10, 10)), procs_to_merge=1)
-        # test wrong shape
-        with self.assertRaises(ValueError):
-            ht.linalg.qr(ht.zeros((10, 10, 10)))
         # test wrong dtype
         with self.assertRaises(TypeError):
             ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32))
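
For reference, a minimal usage sketch of the batched QR path added by this diff. It is not part of the patch: the array shapes, the script name demo.py, and the mpirun invocation are illustrative assumptions; only ht.linalg.qr, its mode argument, and the split semantics are taken from the docstring and tests above.

# Minimal sketch, assumed to be launched under MPI, e.g. `mpirun -n 4 python demo.py`;
# the shapes are illustrative and chosen so that the local row chunks stay tall-skinny.
import heat as ht

# batch of 4 matrices, each 50 x 10, split along the matrix rows (the second-to-last axis)
A = ht.random.randn(4, 50, 10, dtype=ht.float32, split=1)

q, r = ht.linalg.qr(A)                 # mode="reduced" is the default
print(q.shape, r.shape)                # (4, 50, 10) and (4, 10, 10); the batch dimension is untouched

_, r_only = ht.linalg.qr(A, mode="r")  # mode="r" returns QR(Q=None, R=R)

The same call works for any number of leading batch dimensions; only the last two axes are treated as matrix dimensions.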