Refactoring of QR: stabilized Gram-Schmidt for split=1 and TS-QR for split=0 (#1329)
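Distilled from the new unit tests further down this page, the refactored routine is called roughly as follows (a usage sketch, not part of the committed diff):

import heat as ht

# split=0 requires a tall-skinny matrix whose local chunks are at least square
a = ht.random.randn(40 * ht.MPI_WORLD.size, 20, split=0)

# procs_to_merge steers how many local R factors the TS-QR reduction merges
# at a time; the tests below exercise the values 0, 2, and 3
qr = ht.linalg.qr(a, mode="reduced", procs_to_merge=2)
assert qr.Q.shape == (a.shape[0], min(a.shape))
assert qr.R.shape == (min(a.shape), a.shape[1])

# mode="r" skips the assembly of Q entirely
assert ht.linalg.qr(a, mode="r").Q is None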

* added file `myqr.py` that contains an experiment toward blockwise stabilized Gram-Schmidt for QR with split=1 (see the Gram-Schmidt sketch after this list)

* Update __init__.py

* removed `myrandn` since `randn` has been fixed in the meantime

* ...

* replaced the previous QR implementation with the new one

* ...

* refactoring

* ...

* added some unit tests for the refactored qr with split=1

* ...

* resolved some bugs related to calc_q, calc_r etc.

* removed option overwrite_a

* extended comments in the code

* ...

* removed the odd option `full_q` (to be addressed elsewhere later)

* formatting of warning

* adapted the API of the new qr to follow NumPy's

* first version of split=0 (TS-QR), still with bugs

* fixed some bugs; some others are still there...

* fixed some bugs: one problem, e.g., was that torch's qr returns non-contiguous q and r even if the input was contiguous (see the contiguity snippet after this list)

* still errors, but I can't find them...

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* found bugs in TS-QR; we should now have a running version (see the TS-QR sketch after this list).
TODOs: extend comments (to keep the code maintainable), check ROCm support, add unit tests for wrong inputs/errors

* removed skips from some SVD tests on AMD (they had been introduced due to missing QR support),
added tests for catching wrong inputs in QR,
slightly extended the docs of QR

* Update test_matrixgallery.py

removed skips for AMD devices

* Update linalg.py

modified continuous benchmarks for the new qr implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update linalg.py

fixed a bug in the benchmarks

* resolved the first batch of review comments

* updated the docs of the refactored QR to state that batches of matrices are not supported

* modified docs according to review

* removed "from time import sleep", completed dead end sentence in docs

---------

Co-authored-by: Hoppe <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Claudia Comito <[email protected]>
4 people authored Apr 17, 2024

1 parent 86fe8a6 commit 67bc254
Showing 5 changed files with 397 additions and 1,099 deletions.
16 changes: 9 additions & 7 deletions benchmarks/cb/linalg.py
@@ -16,14 +16,12 @@ def matmul_split_1(a, b):
 
 @monitor()
 def qr_split_0(a):
-    for t in range(2, 3):
-        qr = a.qr(tiles_per_proc=t)
+    qr = ht.linalg.qr(a)
 
 
 @monitor()
 def qr_split_1(a):
-    for t in range(1, 3):
-        qr = a.qr(tiles_per_proc=t)
+    qr = ht.linalg.qr(a)
 
 
 @monitor()
@@ -53,12 +51,16 @@ def run_linalg_benchmarks():
     matmul_split_1(a, b)
     del a, b
 
+    n = int((4000000 // MPI.COMM_WORLD.size) ** 0.5)
+    m = MPI.COMM_WORLD.size * n
+    a_0 = ht.random.random((m, n), split=0)
+    qr_split_0(a_0)
+    del a_0
+
     n = 2000
-    a_0 = ht.random.random((n, n), split=0)
     a_1 = ht.random.random((n, n), split=1)
-    qr_split_0(a_0)
     qr_split_1(a_1)
-    del a_0, a_1
+    del a_1
 
     n = 50
     A = ht.random.random((n, n), dtype=ht.float64, split=0)
1,240 changes: 254 additions & 986 deletions heat/core/linalg/qr.py

Large diffs are not rendered by default.

235 changes: 134 additions & 101 deletions heat/core/linalg/tests/test_qr.py
@@ -1,118 +1,151 @@
import heat as ht
import numpy as np
import os
import torch
import unittest
import warnings
import torch
import numpy as np

from ...tests.test_suites.basic_test import TestCase

if os.environ.get("EXTENDED_TESTS"):
extended_tests = True
warnings.warn("Extended Tests will take roughly 100x longer than the standard tests")
else:
extended_tests = False


@unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
class TestQR(TestCase):
@unittest.skipIf(not extended_tests, "extended tests")
def test_qr_sp0_ext(self):
st_whole = torch.randn(70, 70, device=self.device.torch_device)
sp = 0
for m in range(50, st_whole.shape[0] + 1, 1):
for n in range(50, st_whole.shape[1] + 1, 1):
for t in range(2, 3):
st = st_whole[:m, :n].clone()
a_comp = ht.array(st, split=0)
a = ht.array(st, split=sp)
qr = a.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))
def test_qr_split1orNone(self):
for split in [1, None]:
for mode in ["reduced", "r"]:
# note that split = 1 can be handled for arbitrary shapes
for shape in [
(20 * ht.MPI_WORLD.size + 1, 40 * ht.MPI_WORLD.size),
(20 * ht.MPI_WORLD.size, 20 * ht.MPI_WORLD.size),
(40 * ht.MPI_WORLD.size - 1, 20 * ht.MPI_WORLD.size),
]:
for dtype in [ht.float32, ht.float64]:
dtypetol = 1e-3 if dtype == ht.float32 else 1e-6
mat = ht.random.randn(*shape, dtype=dtype, split=split)
qr = ht.linalg.qr(mat, mode=mode)

@unittest.skipIf(not extended_tests, "extended tests")
def test_qr_sp1_ext(self):
st_whole = torch.randn(70, 70, device=self.device.torch_device)
sp = 1
for m in range(50, st_whole.shape[0] + 1, 1):
for n in range(50, st_whole.shape[1] + 1, 1):
for t in range(2, 3):
st = st_whole[:m, :n].clone()
a_comp = ht.array(st, split=0)
a = ht.array(st, split=sp)
qr = a.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))
if mode == "reduced":
self.assertTrue(
ht.allclose(qr.Q @ qr.R, mat, atol=dtypetol, rtol=dtypetol)
)
self.assertIsInstance(qr.Q, ht.DNDarray)

def test_qr(self):
m, n = 20, 40
st = torch.randn(m, n, device=self.device.torch_device, dtype=torch.float)
a_comp = ht.array(st, split=0)
for t in range(2, 3):
for sp in range(2):
a = ht.array(st, split=sp, dtype=torch.float)
qr = a.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose((a_comp - (qr.Q @ qr.R)), 0, rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))
m, n = 40, 40
st1 = torch.randn(m, n, device=self.device.torch_device)
a_comp1 = ht.array(st1, split=0)
for t in range(2, 3):
for sp in range(2):
a1 = ht.array(st1, split=sp)
qr1 = a1.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose((a_comp1 - (qr1.Q @ qr1.R)), 0, rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(qr1.Q.T @ qr1.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr1.Q @ qr1.Q.T, rtol=1e-5, atol=1e-5))
m, n = 40, 20
st2 = torch.randn(m, n, dtype=torch.double, device=self.device.torch_device)
a_comp2 = ht.array(st2, split=0, dtype=ht.double)
for t in range(2, 3):
for sp in range(2):
a2 = ht.array(st2, split=sp)
qr2 = a2.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose(a_comp2, qr2.Q @ qr2.R, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(qr2.Q.T @ qr2.Q, ht.eye(m, dtype=ht.double), rtol=1e-5, atol=1e-5)
)
self.assertTrue(
ht.allclose(ht.eye(m, dtype=ht.double), qr2.Q @ qr2.Q.T, rtol=1e-5, atol=1e-5)
)
# test if Q is orthogonal
self.assertTrue(
ht.allclose(
qr.Q.T @ qr.Q,
ht.eye(qr.Q.shape[1], dtype=dtype),
atol=dtypetol,
rtol=dtypetol,
)
)
# test correct shape of Q
self.assertEqual(qr.Q.shape, (shape[0], min(shape)))
else:
self.assertIsNone(qr.Q)

# test if calc R alone works
a2_0 = ht.array(st2, split=0)
a2_1 = ht.array(st2, split=1)
qr_0 = ht.qr(a2_0, calc_q=False, overwrite_a=True)
self.assertTrue(qr_0.Q is None)
qr_1 = ht.qr(a2_1, calc_q=False, overwrite_a=True)
self.assertTrue(qr_1.Q is None)
# test correct type and shape of R
self.assertIsInstance(qr.R, ht.DNDarray)
self.assertEqual(qr.R.shape, (min(shape), shape[1]))

m, n = 40, 20
st = torch.randn(m, n, device=self.device.torch_device)
a_comp = ht.array(st, split=None)
a = ht.array(st, split=None)
qr = a.qr()
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))
# compare with torch qr, due to different signs we can only compare absolute values
mat_t = mat.resplit_(None).larray
q_t, r_t = torch.linalg.qr(mat_t, mode=mode)
r_ht = qr.R.resplit_(None).larray
self.assertTrue(
torch.allclose(
torch.abs(r_t), torch.abs(r_ht), atol=dtypetol, rtol=dtypetol
)
)
if mode == "reduced":
q_ht = qr.Q.resplit_(None).larray
self.assertTrue(
torch.allclose(
torch.abs(q_t), torch.abs(q_ht), atol=dtypetol, rtol=dtypetol
)
)

# raises
with self.assertRaises(TypeError):
ht.qr(np.zeros((10, 10)))
def test_qr_split0(self):
split = 0
for procs_to_merge in [0, 2, 3]:
for mode in ["reduced", "r"]:
# split = 0 can be handled only for tall-skinny matrices s.t. the local chunks are at least square too
for shape in [(40 * ht.MPI_WORLD.size + 1, 40), (40 * ht.MPI_WORLD.size, 20)]:
for dtype in [ht.float32, ht.float64]:
dtypetol = 1e-3 if dtype == ht.float32 else 1e-6
mat = ht.random.randn(*shape, dtype=dtype, split=split)

qr = ht.linalg.qr(mat, mode=mode, procs_to_merge=procs_to_merge)

if mode == "reduced":
self.assertTrue(
ht.allclose(qr.Q @ qr.R, mat, atol=dtypetol, rtol=dtypetol)
)
self.assertIsInstance(qr.Q, ht.DNDarray)

# test if Q is orthogonal
self.assertTrue(
ht.allclose(
qr.Q.T @ qr.Q,
ht.eye(qr.Q.shape[1], dtype=dtype),
atol=dtypetol,
rtol=dtypetol,
)
)
# test correct shape of Q
self.assertEqual(qr.Q.shape, (shape[0], min(shape)))
else:
self.assertIsNone(qr.Q)

# test correct type and shape of R
self.assertIsInstance(qr.R, ht.DNDarray)
self.assertEqual(qr.R.shape, (min(shape), shape[1]))

# compare with torch qr, due to different signs we can only compare absolute values
mat_t = mat.resplit_(None).larray
q_t, r_t = torch.linalg.qr(mat_t, mode=mode)
r_ht = qr.R.resplit_(None).larray
self.assertTrue(
torch.allclose(
torch.abs(r_t), torch.abs(r_ht), atol=dtypetol, rtol=dtypetol
)
)
if mode == "reduced":
q_ht = qr.Q.resplit_(None).larray
self.assertTrue(
torch.allclose(
torch.abs(q_t), torch.abs(q_ht), atol=dtypetol, rtol=dtypetol
)
)

def test_wronginputs(self):
# test wrong input type
with self.assertRaises(TypeError):
ht.qr(a_comp, tiles_per_proc="ls")
ht.linalg.qr([1, 2, 3])
# test too many input dimensions
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10, 10)))
# wrong data type for mode
with self.assertRaises(TypeError):
ht.qr(a_comp, tiles_per_proc=1, calc_q=30)
ht.linalg.qr(ht.zeros((10, 10)), mode=1)
# test wrong mode (such mode is not available for Torch)
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10)), mode="full")
# test mode that is available for Torch but not for Heat
with self.assertRaises(NotImplementedError):
ht.linalg.qr(ht.zeros((10, 10)), mode="complete")
with self.assertRaises(NotImplementedError):
ht.linalg.qr(ht.zeros((10, 10)), mode="raw")
# wrong dtype for procs_to_merge
with self.assertRaises(TypeError):
ht.qr(a_comp, tiles_per_proc=1, overwrite_a=30)
ht.linalg.qr(ht.zeros((10, 10)), procs_to_merge="abc")
# test wrong procs_to_merge
with self.assertRaises(ValueError):
ht.qr(a_comp, tiles_per_proc=torch.tensor([1, 2, 3]))
ht.linalg.qr(ht.zeros((10, 10)), procs_to_merge=1)
# test wrong shape
with self.assertRaises(ValueError):
ht.qr(ht.zeros((3, 4, 5)))

a_comp.resplit_(0)
with self.assertWarns(Warning):
ht.qr(a_comp, tiles_per_proc=1)
ht.linalg.qr(ht.zeros((10, 10, 10)))
# test wrong dtype
with self.assertRaises(TypeError):
ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32))
# test wrong shape for split=0
if ht.MPI_WORLD.size > 1:
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10), split=0))
2 changes: 0 additions & 2 deletions heat/core/linalg/tests/test_svdtools.py
@@ -143,10 +143,8 @@ def test_hsvd_rank_part1(self):
         self.assertEqual(len(ht.linalg.hsvd_rank(test_matrices[0], 5)), 2)
         self.assertEqual(len(ht.linalg.hsvd_rtol(test_matrices[0], 5e-1)), 2)
 
-    @unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
     def test_hsvd_rank_part2(self):
         # check if hsvd_rank yields correct results for maxrank <= truerank
-        # this needs to be skipped on AMD because generation of test data relies on QR...
         nprocs = MPI.COMM_WORLD.Get_size()
         true_rk = max(10, nprocs)
         test_matrices_low_rank = [
3 changes: 0 additions & 3 deletions heat/utils/data/tests/test_matrixgallery.py
@@ -61,7 +61,6 @@ def test_parter(self):
         with self.assertRaises(ValueError):
             ht.utils.data.matrixgallery.parter(20, split=2, comm=ht.MPI_WORLD)
 
-    @unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
     def test_random_orthogonal(self):
         with self.assertRaises(RuntimeError):
             ht.utils.data.matrixgallery.random_orthogonal(10, 20)
@@ -74,7 +73,6 @@ def test_random_orthogonal(self):
         # self.assertTrue(Q_orth_err <= 1e-6)
         self.__check_orthogonality(Q)
 
-    @unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
     def test_random_known_singularvalues(self):
         with self.assertRaises(RuntimeError):
             ht.utils.data.matrixgallery.random_known_singularvalues(30, 20, "abc", split=1)
@@ -100,7 +98,6 @@ def test_random_known_singularvalues(self):
         A_err = ht.norm(A - U @ ht.diag(S) @ V.T) / ht.norm(A)
         self.assertTrue(A_err <= dtype_tol)
 
-    @unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
     def test_random_known_rank(self):
         with self.assertRaises(RuntimeError):
             ht.utils.data.matrixgallery.random_known_rank(30, 20, 25, split=1)
