-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Refactoring of QR: stabilized Gram-Schmidt for split=1 and TS-QR for …
…split=0 (#1329) * added file `myqr.py` that contains experiment in direction blockwise stabilized GS for QR (split=1) * Update __init__.py * removed `myrandn` since `randn` is fixed meanwhile * ... * replaced previour QR implementation by the new one * ... * refactoring * ... * added some unit tests for the refactored qr with split=1 * ... * resolved some bugs related to calc_q, calc_r etc. * removed option overwrite_a * extended comments in the code * ... * removed strange option full_q (should be addressed somewhere else later) * formatting of warning * adapted API of new qr to follow the one in numpy * first version of split=0 (TS-QR), stil with bugs * fixed some bugs, some others are stil there... * fixed some bugs: a problem was, e.g., that torch's qr returns non-contiguous q and r even if input was contiguous * still errors but I dont find them... * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * found bugs in TS-QR, now we should have a running version TODOS: extend comments (to have maintainable code), check ROCm-support, unittests for wrong inpus/errors * removed skip from some svd_tests on AMD (they were introduced due to missing QR support) added tests for catching wrong inputs in QR slightly extended docs of QR * Update test_matrixgallery.py removed skips for AMD devices * Update linalg.py modified continuous benchmarks for new qr implementation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update linalg.py removed bug in benchmarks * resolved first parts of the comments in the review * updated docs of refactored QR since no batches of matrices are supported * modified docs according to review * removed "from time import sleep", completed dead end sentence in docs --------- Co-authored-by: Hoppe <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Claudia Comito <[email protected]>
1 parent
86fe8a6
commit 67bc254
Showing
5 changed files
with
397 additions
and
1,099 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,118 +1,151 @@ | ||
import heat as ht | ||
import numpy as np | ||
import os | ||
import torch | ||
import unittest | ||
import warnings | ||
import torch | ||
import numpy as np | ||
|
||
from ...tests.test_suites.basic_test import TestCase | ||
|
||
if os.environ.get("EXTENDED_TESTS"): | ||
extended_tests = True | ||
warnings.warn("Extended Tests will take roughly 100x longer than the standard tests") | ||
else: | ||
extended_tests = False | ||
|
||
|
||
@unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP") | ||
class TestQR(TestCase): | ||
@unittest.skipIf(not extended_tests, "extended tests") | ||
def test_qr_sp0_ext(self): | ||
st_whole = torch.randn(70, 70, device=self.device.torch_device) | ||
sp = 0 | ||
for m in range(50, st_whole.shape[0] + 1, 1): | ||
for n in range(50, st_whole.shape[1] + 1, 1): | ||
for t in range(2, 3): | ||
st = st_whole[:m, :n].clone() | ||
a_comp = ht.array(st, split=0) | ||
a = ht.array(st, split=sp) | ||
qr = a.qr(tiles_per_proc=t) | ||
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5)) | ||
def test_qr_split1orNone(self): | ||
for split in [1, None]: | ||
for mode in ["reduced", "r"]: | ||
# note that split = 1 can be handeled for arbitrary shapes | ||
for shape in [ | ||
(20 * ht.MPI_WORLD.size + 1, 40 * ht.MPI_WORLD.size), | ||
(20 * ht.MPI_WORLD.size, 20 * ht.MPI_WORLD.size), | ||
(40 * ht.MPI_WORLD.size - 1, 20 * ht.MPI_WORLD.size), | ||
]: | ||
for dtype in [ht.float32, ht.float64]: | ||
dtypetol = 1e-3 if dtype == ht.float32 else 1e-6 | ||
mat = ht.random.randn(*shape, dtype=dtype, split=split) | ||
qr = ht.linalg.qr(mat, mode=mode) | ||
|
||
@unittest.skipIf(not extended_tests, "extended tests") | ||
def test_qr_sp1_ext(self): | ||
st_whole = torch.randn(70, 70, device=self.device.torch_device) | ||
sp = 1 | ||
for m in range(50, st_whole.shape[0] + 1, 1): | ||
for n in range(50, st_whole.shape[1] + 1, 1): | ||
for t in range(2, 3): | ||
st = st_whole[:m, :n].clone() | ||
a_comp = ht.array(st, split=0) | ||
a = ht.array(st, split=sp) | ||
qr = a.qr(tiles_per_proc=t) | ||
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5)) | ||
if mode == "reduced": | ||
self.assertTrue( | ||
ht.allclose(qr.Q @ qr.R, mat, atol=dtypetol, rtol=dtypetol) | ||
) | ||
self.assertIsInstance(qr.Q, ht.DNDarray) | ||
|
||
def test_qr(self): | ||
m, n = 20, 40 | ||
st = torch.randn(m, n, device=self.device.torch_device, dtype=torch.float) | ||
a_comp = ht.array(st, split=0) | ||
for t in range(2, 3): | ||
for sp in range(2): | ||
a = ht.array(st, split=sp, dtype=torch.float) | ||
qr = a.qr(tiles_per_proc=t) | ||
self.assertTrue(ht.allclose((a_comp - (qr.Q @ qr.R)), 0, rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5)) | ||
m, n = 40, 40 | ||
st1 = torch.randn(m, n, device=self.device.torch_device) | ||
a_comp1 = ht.array(st1, split=0) | ||
for t in range(2, 3): | ||
for sp in range(2): | ||
a1 = ht.array(st1, split=sp) | ||
qr1 = a1.qr(tiles_per_proc=t) | ||
self.assertTrue(ht.allclose((a_comp1 - (qr1.Q @ qr1.R)), 0, rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(qr1.Q.T @ qr1.Q, ht.eye(m), rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(ht.eye(m), qr1.Q @ qr1.Q.T, rtol=1e-5, atol=1e-5)) | ||
m, n = 40, 20 | ||
st2 = torch.randn(m, n, dtype=torch.double, device=self.device.torch_device) | ||
a_comp2 = ht.array(st2, split=0, dtype=ht.double) | ||
for t in range(2, 3): | ||
for sp in range(2): | ||
a2 = ht.array(st2, split=sp) | ||
qr2 = a2.qr(tiles_per_proc=t) | ||
self.assertTrue(ht.allclose(a_comp2, qr2.Q @ qr2.R, rtol=1e-5, atol=1e-5)) | ||
self.assertTrue( | ||
ht.allclose(qr2.Q.T @ qr2.Q, ht.eye(m, dtype=ht.double), rtol=1e-5, atol=1e-5) | ||
) | ||
self.assertTrue( | ||
ht.allclose(ht.eye(m, dtype=ht.double), qr2.Q @ qr2.Q.T, rtol=1e-5, atol=1e-5) | ||
) | ||
# test if Q is orthogonal | ||
self.assertTrue( | ||
ht.allclose( | ||
qr.Q.T @ qr.Q, | ||
ht.eye(qr.Q.shape[1], dtype=dtype), | ||
atol=dtypetol, | ||
rtol=dtypetol, | ||
) | ||
) | ||
# test correct shape of Q | ||
self.assertEqual(qr.Q.shape, (shape[0], min(shape))) | ||
else: | ||
self.assertIsNone(qr.Q) | ||
|
||
# test if calc R alone works | ||
a2_0 = ht.array(st2, split=0) | ||
a2_1 = ht.array(st2, split=1) | ||
qr_0 = ht.qr(a2_0, calc_q=False, overwrite_a=True) | ||
self.assertTrue(qr_0.Q is None) | ||
qr_1 = ht.qr(a2_1, calc_q=False, overwrite_a=True) | ||
self.assertTrue(qr_1.Q is None) | ||
# test correct type and shape of R | ||
self.assertIsInstance(qr.R, ht.DNDarray) | ||
self.assertEqual(qr.R.shape, (min(shape), shape[1])) | ||
|
||
m, n = 40, 20 | ||
st = torch.randn(m, n, device=self.device.torch_device) | ||
a_comp = ht.array(st, split=None) | ||
a = ht.array(st, split=None) | ||
qr = a.qr() | ||
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5)) | ||
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5)) | ||
# compare with torch qr, due to different signs we can only compare absolute values | ||
mat_t = mat.resplit_(None).larray | ||
q_t, r_t = torch.linalg.qr(mat_t, mode=mode) | ||
r_ht = qr.R.resplit_(None).larray | ||
self.assertTrue( | ||
torch.allclose( | ||
torch.abs(r_t), torch.abs(r_ht), atol=dtypetol, rtol=dtypetol | ||
) | ||
) | ||
if mode == "reduced": | ||
q_ht = qr.Q.resplit_(None).larray | ||
self.assertTrue( | ||
torch.allclose( | ||
torch.abs(q_t), torch.abs(q_ht), atol=dtypetol, rtol=dtypetol | ||
) | ||
) | ||
|
||
# raises | ||
with self.assertRaises(TypeError): | ||
ht.qr(np.zeros((10, 10))) | ||
def test_qr_split0(self): | ||
split = 0 | ||
for procs_to_merge in [0, 2, 3]: | ||
for mode in ["reduced", "r"]: | ||
# split = 0 can be handeled only for tall skinny matrices s.t. the local chunks are at least square too | ||
for shape in [(40 * ht.MPI_WORLD.size + 1, 40), (40 * ht.MPI_WORLD.size, 20)]: | ||
for dtype in [ht.float32, ht.float64]: | ||
dtypetol = 1e-3 if dtype == ht.float32 else 1e-6 | ||
mat = ht.random.randn(*shape, dtype=dtype, split=split) | ||
|
||
qr = ht.linalg.qr(mat, mode=mode, procs_to_merge=procs_to_merge) | ||
|
||
if mode == "reduced": | ||
self.assertTrue( | ||
ht.allclose(qr.Q @ qr.R, mat, atol=dtypetol, rtol=dtypetol) | ||
) | ||
self.assertIsInstance(qr.Q, ht.DNDarray) | ||
|
||
# test if Q is orthogonal | ||
self.assertTrue( | ||
ht.allclose( | ||
qr.Q.T @ qr.Q, | ||
ht.eye(qr.Q.shape[1], dtype=dtype), | ||
atol=dtypetol, | ||
rtol=dtypetol, | ||
) | ||
) | ||
# test correct shape of Q | ||
self.assertEqual(qr.Q.shape, (shape[0], min(shape))) | ||
else: | ||
self.assertIsNone(qr.Q) | ||
|
||
# test correct type and shape of R | ||
self.assertIsInstance(qr.R, ht.DNDarray) | ||
self.assertEqual(qr.R.shape, (min(shape), shape[1])) | ||
|
||
# compare with torch qr, due to different signs we can only compare absolute values | ||
mat_t = mat.resplit_(None).larray | ||
q_t, r_t = torch.linalg.qr(mat_t, mode=mode) | ||
r_ht = qr.R.resplit_(None).larray | ||
self.assertTrue( | ||
torch.allclose( | ||
torch.abs(r_t), torch.abs(r_ht), atol=dtypetol, rtol=dtypetol | ||
) | ||
) | ||
if mode == "reduced": | ||
q_ht = qr.Q.resplit_(None).larray | ||
self.assertTrue( | ||
torch.allclose( | ||
torch.abs(q_t), torch.abs(q_ht), atol=dtypetol, rtol=dtypetol | ||
) | ||
) | ||
|
||
def test_wronginputs(self): | ||
# test wrong input type | ||
with self.assertRaises(TypeError): | ||
ht.qr(a_comp, tiles_per_proc="ls") | ||
ht.linalg.qr([1, 2, 3]) | ||
# test too many input dimensions | ||
with self.assertRaises(ValueError): | ||
ht.linalg.qr(ht.zeros((10, 10, 10))) | ||
# wrong data type for mode | ||
with self.assertRaises(TypeError): | ||
ht.qr(a_comp, tiles_per_proc=1, calc_q=30) | ||
ht.linalg.qr(ht.zeros((10, 10)), mode=1) | ||
# test wrong mode (such mode is not available for Torch) | ||
with self.assertRaises(ValueError): | ||
ht.linalg.qr(ht.zeros((10, 10)), mode="full") | ||
# test mode that is available for Torch but not for Heat | ||
with self.assertRaises(NotImplementedError): | ||
ht.linalg.qr(ht.zeros((10, 10)), mode="complete") | ||
with self.assertRaises(NotImplementedError): | ||
ht.linalg.qr(ht.zeros((10, 10)), mode="raw") | ||
# wrong dtype for procs_to_merge | ||
with self.assertRaises(TypeError): | ||
ht.qr(a_comp, tiles_per_proc=1, overwrite_a=30) | ||
ht.linalg.qr(ht.zeros((10, 10)), procs_to_merge="abc") | ||
# test wrong procs_to_merge | ||
with self.assertRaises(ValueError): | ||
ht.qr(a_comp, tiles_per_proc=torch.tensor([1, 2, 3])) | ||
ht.linalg.qr(ht.zeros((10, 10)), procs_to_merge=1) | ||
# test wrong shape | ||
with self.assertRaises(ValueError): | ||
ht.qr(ht.zeros((3, 4, 5))) | ||
|
||
a_comp.resplit_(0) | ||
with self.assertWarns(Warning): | ||
ht.qr(a_comp, tiles_per_proc=1) | ||
ht.linalg.qr(ht.zeros((10, 10, 10))) | ||
# test wrong dtype | ||
with self.assertRaises(TypeError): | ||
ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32)) | ||
# test wrong shape for split=0 | ||
if ht.MPI_WORLD.size > 1: | ||
with self.assertRaises(ValueError): | ||
ht.linalg.qr(ht.zeros((10, 10), split=0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters