From 5d2671bff767c96d2be031646d16622020ea7124 Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 09:24:12 -0700 Subject: [PATCH 1/8] add generate_inv_roots option --- gpytorch/lazy/lazy_tensor.py | 30 +++++++++++++------------- gpytorch/test/lazy_tensor_test_case.py | 11 ++++++++-- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/gpytorch/lazy/lazy_tensor.py b/gpytorch/lazy/lazy_tensor.py index cb50c18c2..d6f65c1fd 100644 --- a/gpytorch/lazy/lazy_tensor.py +++ b/gpytorch/lazy/lazy_tensor.py @@ -735,7 +735,7 @@ def add_jitter(self, jitter_val=1e-3): diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device) return self.add_diag(diag) - def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs): + def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs): """ Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g. C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat) @@ -762,8 +762,8 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs If :math:`A` is n x n, then this matrix should be n x k. new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`. If :math:`B` is n x k, then this matrix should be k x k. - generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it - has not been created yet. + generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it has not been created yet. + generate_inv_roots (:obj:`bool`): whether to generate the root inv decomposition of :math:`A` even if it has not been created yet. Returns: :obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns. 
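For readers skimming the series, here is a minimal sketch of what this first patch enables, mirroring the unit test added later in this same patch. The matrices, the sizes, and the +10.0 diagonal offset (kept large so the Schur complement D - B A^{-1} B^T stays PSD) are illustrative assumptions, not part of the patch; `get_from_cache` raising `CachingError` on a cache miss is exactly what the new test relies on.

import torch
from gpytorch.lazy import lazify
from gpytorch.utils.errors import CachingError
from gpytorch.utils.memoize import get_from_cache

n = 5
root = torch.randn(n, n)
# Existing n x n PSD LazyTensor A; the diagonal shift keeps it well-conditioned.
A = lazify(root @ root.transpose(-1, -2) + n * torch.eye(n))
new_rows = torch.randn(1, n)         # cross_mat: row block B being appended
new_point = torch.rand(1, 1) + 10.0  # new_mat: diagonal block D, kept large
                                     # enough that D - B A^{-1} B^T stays PSD

# Default behavior: both the root and the inverse-root decompositions of the
# concatenated tensor are computed eagerly and cached.
C = A.cat_rows(new_rows, new_point)
get_from_cache(C, "root_decomposition")      # cache hit
get_from_cache(C, "root_inv_decomposition")  # cache hit

# New in this patch: skip the (potentially expensive) inverse root.
C = A.cat_rows(new_rows, new_point, generate_inv_roots=False)
try:
    get_from_cache(C, "root_inv_decomposition")
except CachingError:
    pass  # not cached, by design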
@@ -809,20 +809,20 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs new_root[..., :m, :n] = E.evaluate() new_root[..., m:, : lower_left.shape[-1]] = lower_left new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root - - if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor): - # make sure these are actually upper triangular - if getattr(E, "upper", False) or getattr(schur_root, "upper", False): - raise NotImplementedError - # in this case we know new_root is triangular as well - new_root = TriangularLazyTensor(new_root) - new_inv_root = new_root.inverse().transpose(-1, -2) - else: - # otherwise we use the pseudo-inverse of Z as new inv root - new_inv_root = stable_pinverse(new_root).transpose(-2, -1) + if generate_inv_roots: + if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor): + # make sure these are actually upper triangular + if getattr(E, "upper", False) or getattr(schur_root, "upper", False): + raise NotImplementedError + # in this case we know new_root is triangular as well + new_root = TriangularLazyTensor(new_root) + new_inv_root = new_root.inverse().transpose(-1, -2) + else: + # otherwise we use the pseudo-inverse of Z as new inv root + new_inv_root = stable_pinverse(new_root).transpose(-2, -1) + add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root))) add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root))) - add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root))) return new_lazy_tensor diff --git a/gpytorch/test/lazy_tensor_test_case.py b/gpytorch/test/lazy_tensor_test_case.py index 2ba8d6ec0..3dca581dc 100644 --- a/gpytorch/test/lazy_tensor_test_case.py +++ b/gpytorch/test/lazy_tensor_test_case.py @@ -11,7 +11,8 @@ import gpytorch from gpytorch.settings import linalg_dtypes from gpytorch.utils.cholesky import CHOLESKY_METHOD - +from gpytorch.utils.memoize import get_from_cache +from gpytorch.utils.errors import CachingError from .base_test_case import BaseTestCase @@ -457,13 +458,19 @@ def test_cat_rows(self): root_rhs = new_lt.root_decomposition().matmul(rhs) self.assertAllClose(root_rhs, concat_rhs, **self.tolerances["root_decomposition"]) + # check that root inv is cached + root_inv = get_from_cache(new_lt, "root_inv_decomposition") # check that the inverse root decomposition is close concat_solve = torch.linalg.solve(concatenated_lt, rhs.unsqueeze(-1)).squeeze(-1) - root_inv_solve = new_lt.root_inv_decomposition().matmul(rhs) + root_inv_solve = root_inv.matmul(rhs) self.assertLess( (root_inv_solve - concat_solve).norm() / concat_solve.norm(), self.tolerances["root_inv_decomposition"]["rtol"], ) + # test generate_inv_roots=False + new_lt = lazy_tensor.cat_rows(new_rows, new_point, generate_inv_roots=False) + with self.assertRaises(CachingError): + get_from_cache(new_lt, "root_inv_decomposition") def test_cholesky(self): lazy_tensor = self.create_lazy_tensor() From 09992e11f3b6680f57ee3fea345ffae1cf895ab6 Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 09:30:47 -0700 Subject: [PATCH 2/8] lint --- gpytorch/lazy/lazy_tensor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gpytorch/lazy/lazy_tensor.py b/gpytorch/lazy/lazy_tensor.py index d6f65c1fd..4ecd8fa0a 100644 --- a/gpytorch/lazy/lazy_tensor.py +++ b/gpytorch/lazy/lazy_tensor.py @@ -762,8 +762,10 @@ def cat_rows(self, cross_mat, new_mat, 
generate_roots=True, generate_inv_roots=T If :math:`A` is n x n, then this matrix should be n x k. new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`. If :math:`B` is n x k, then this matrix should be k x k. - generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it has not been created yet. - generate_inv_roots (:obj:`bool`): whether to generate the root inv decomposition of :math:`A` even if it has not been created yet. + generate_roots (:obj:`bool`): whether to generate the root + decomposition of :math:`A` even if it has not been created yet. + generate_inv_roots (:obj:`bool`): whether to generate the root inv + decomposition of :math:`A` even if it has not been created yet. Returns: :obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns. From e4f3b3f215ae6f70eb2c7eff2b40a9ea972a2af2 Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 09:37:48 -0700 Subject: [PATCH 3/8] lint imports --- gpytorch/test/lazy_tensor_test_case.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/test/lazy_tensor_test_case.py b/gpytorch/test/lazy_tensor_test_case.py index 3dca581dc..e9c4421c9 100644 --- a/gpytorch/test/lazy_tensor_test_case.py +++ b/gpytorch/test/lazy_tensor_test_case.py @@ -11,8 +11,8 @@ import gpytorch from gpytorch.settings import linalg_dtypes from gpytorch.utils.cholesky import CHOLESKY_METHOD -from gpytorch.utils.memoize import get_from_cache from gpytorch.utils.errors import CachingError +from gpytorch.utils.memoize import get_from_cache from .base_test_case import BaseTestCase From 32924e2834922af7146ff2d0468757e52ab75abc Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 09:42:01 -0700 Subject: [PATCH 4/8] lint imports --- gpytorch/test/lazy_tensor_test_case.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gpytorch/test/lazy_tensor_test_case.py b/gpytorch/test/lazy_tensor_test_case.py index e9c4421c9..adf8f1d3b 100644 --- a/gpytorch/test/lazy_tensor_test_case.py +++ b/gpytorch/test/lazy_tensor_test_case.py @@ -13,6 +13,7 @@ from gpytorch.utils.cholesky import CHOLESKY_METHOD from gpytorch.utils.errors import CachingError from gpytorch.utils.memoize import get_from_cache + from .base_test_case import BaseTestCase From 1cd7c9d65c2a9ddb4cdc7f212ddc65844c219a12 Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 10:45:14 -0700 Subject: [PATCH 5/8] add warning and update variable names --- gpytorch/lazy/lazy_tensor.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/gpytorch/lazy/lazy_tensor.py b/gpytorch/lazy/lazy_tensor.py index 4ecd8fa0a..f477acc91 100644 --- a/gpytorch/lazy/lazy_tensor.py +++ b/gpytorch/lazy/lazy_tensor.py @@ -774,7 +774,12 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T from .cat_lazy_tensor import CatLazyTensor from .root_lazy_tensor import RootLazyTensor from .triangular_lazy_tensor import TriangularLazyTensor - + if not generate_roots and generate_inv_roots: + warnings.warn( + "root_inv_decomposition is only generated when " + "root_decomposition is generated.", + UserWarning, + ) B_, B = cross_mat, lazify(cross_mat) D = lazify(new_mat) batch_shape = B.shape[:-2] @@ -791,13 +796,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition # don't 
create one - does_not_have_roots = any( - _is_in_cache_ignore_args(self, key) for key in ("root_inv_decomposition", "root_inv_decomposition") + has_roots = any( + _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition") ) - if not generate_roots and not does_not_have_roots: + if not generate_roots and not has_roots: return new_lazy_tensor - # Get compomnents for new root Z = [E 0; F G] + # Get components for new root Z = [E 0; F G] E = self.root_decomposition(**root_decomp_kwargs).root # E = L, LL^T = A m, n = E.shape[-2:] R = self.root_inv_decomposition().root.evaluate() # RR^T = A^{-1} (this is fast if L is triangular) @@ -875,10 +880,10 @@ def add_low_rank( # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition # don't create one - does_not_have_roots = any( + has_roots = any( _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition") ) - if not generate_roots and not does_not_have_roots: + if not generate_roots and not has_roots: return new_lazy_tensor # we are going to compute the following From 1d684e5b879699f57e7f20061df7f2ddb2b6fa6e Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 18:18:12 -0700 Subject: [PATCH 6/8] black --- gpytorch/lazy/lazy_tensor.py | 430 +++++++++++++++++++++++++++-------- 1 file changed, 337 insertions(+), 93 deletions(-) diff --git a/gpytorch/lazy/lazy_tensor.py b/gpytorch/lazy/lazy_tensor.py index f477acc91..db25fdeb7 100644 --- a/gpytorch/lazy/lazy_tensor.py +++ b/gpytorch/lazy/lazy_tensor.py @@ -19,13 +19,28 @@ from ..functions._matmul import Matmul from ..functions._root_decomposition import RootDecomposition from ..functions._sqrt_inv_matmul import SqrtInvMatmul -from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper +from ..utils.broadcasting import ( + _matmul_broadcast_shape, + _mul_broadcast_shape, + _to_helper, +) from ..utils.cholesky import psd_safe_cholesky from ..utils.deprecation import _deprecate_renamed_methods from ..utils.errors import CachingError -from ..utils.getitem import _compute_getitem_size, _convert_indices_to_tensors, _is_noop_index, _noop_index +from ..utils.getitem import ( + _compute_getitem_size, + _convert_indices_to_tensors, + _is_noop_index, + _noop_index, +) from ..utils.lanczos import _postprocess_lanczos_root_inv_decomp -from ..utils.memoize import _is_in_cache_ignore_all_args, _is_in_cache_ignore_args, add_to_cache, cached, pop_from_cache +from ..utils.memoize import ( + _is_in_cache_ignore_all_args, + _is_in_cache_ignore_args, + add_to_cache, + cached, + pop_from_cache, +) from ..utils.pinverse import stable_pinverse from ..utils.pivoted_cholesky import pivoted_cholesky from ..utils.warnings import NumericalWarning @@ -128,7 +143,9 @@ def _matmul(self, rhs): Returns: :obj:`torch.tensor`: matrix * rhs """ - raise NotImplementedError("The class {} requires a _matmul function!".format(self.__class__.__name__)) + raise NotImplementedError( + "The class {} requires a _matmul function!".format(self.__class__.__name__) + ) @abstractmethod def _size(self): @@ -142,7 +159,9 @@ def _size(self): Returns: :obj:`torch.Size`: The size of the matrix :math:`K` represented by this LazyTensor """ - raise NotImplementedError("The class {} requires a _size function!".format(self.__class__.__name__)) + raise NotImplementedError( + "The class {} requires a _size function!".format(self.__class__.__name__) + ) @abstractmethod def 
_transpose_nonbatch(self): @@ -155,7 +174,9 @@ def _transpose_nonbatch(self): does some additional work. Calling this method directly is discouraged. """ raise NotImplementedError( - "The class {} requires a _transpose_nonbatch function!".format(self.__class__.__name__) + "The class {} requires a _transpose_nonbatch function!".format( + self.__class__.__name__ + ) ) #### @@ -234,18 +255,32 @@ def _getitem(self, row_index, col_index, *batch_indices): # Normal case: we have to do some processing on either the rows or columns # We will handle this through "interpolation" - row_interp_indices = torch.arange(0, self.size(-2), dtype=torch.long, device=self.device).view(-1, 1) + row_interp_indices = torch.arange( + 0, self.size(-2), dtype=torch.long, device=self.device + ).view(-1, 1) row_interp_indices = row_interp_indices.expand(*self.batch_shape, -1, 1) - row_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(row_interp_indices) + row_interp_values = torch.tensor( + 1.0, dtype=self.dtype, device=self.device + ).expand_as(row_interp_indices) - col_interp_indices = torch.arange(0, self.size(-1), dtype=torch.long, device=self.device).view(-1, 1) + col_interp_indices = torch.arange( + 0, self.size(-1), dtype=torch.long, device=self.device + ).view(-1, 1) col_interp_indices = col_interp_indices.expand(*self.batch_shape, -1, 1) - col_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(col_interp_indices) + col_interp_values = torch.tensor( + 1.0, dtype=self.dtype, device=self.device + ).expand_as(col_interp_indices) # Construct interpolated LazyTensor from . import InterpolatedLazyTensor - res = InterpolatedLazyTensor(self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values) + res = InterpolatedLazyTensor( + self, + row_interp_indices, + row_interp_values, + col_interp_indices, + col_interp_values, + ) return res._getitem(row_index, col_index, *batch_indices) def _unsqueeze_batch(self, dim): @@ -273,9 +308,15 @@ def _expand_batch(self, batch_shape): This method is used internally by the related function :func:`~gpytorch.lazy.LazyTensor.expand`, which does some additional work. Calling this method directly is discouraged. 
""" - current_shape = torch.Size([1 for _ in range(len(batch_shape) - self.dim() + 2)] + list(self.batch_shape)) + current_shape = torch.Size( + [1 for _ in range(len(batch_shape) - self.dim() + 2)] + + list(self.batch_shape) + ) batch_repeat = torch.Size( - [expand_size // current_size for expand_size, current_size in zip(batch_shape, current_shape)] + [ + expand_size // current_size + for expand_size, current_size in zip(batch_shape, current_shape) + ] ) return self.repeat(*batch_repeat, 1, 1) @@ -297,28 +338,44 @@ def _get_indices(self, row_index, col_index, *batch_indices): Returns: Tensor (size determined by broadcasted shape of indices) of selected values """ - final_shape = _mul_broadcast_shape(*(index.shape for index in batch_indices), row_index.shape, col_index.shape) + final_shape = _mul_broadcast_shape( + *(index.shape for index in batch_indices), row_index.shape, col_index.shape + ) row_index = row_index.expand(final_shape) col_index = col_index.expand(final_shape) batch_indices = tuple(index.expand(final_shape) for index in batch_indices) - base_lazy_tensor = self._getitem(_noop_index, _noop_index, *batch_indices)._expand_batch(final_shape) + base_lazy_tensor = self._getitem( + _noop_index, _noop_index, *batch_indices + )._expand_batch(final_shape) # Create some interoplation indices and values - row_interp_indices = torch.arange(0, self.size(-2), dtype=torch.long, device=self.device) + row_interp_indices = torch.arange( + 0, self.size(-2), dtype=torch.long, device=self.device + ) row_interp_indices = row_interp_indices[row_index].unsqueeze_(-1).unsqueeze_(-1) - row_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(row_interp_indices) + row_interp_values = torch.tensor( + 1.0, dtype=self.dtype, device=self.device + ).expand_as(row_interp_indices) - col_interp_indices = torch.arange(0, self.size(-1), dtype=torch.long, device=self.device) + col_interp_indices = torch.arange( + 0, self.size(-1), dtype=torch.long, device=self.device + ) col_interp_indices = col_interp_indices[col_index].unsqueeze_(-1).unsqueeze_(-1) - col_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(col_interp_indices) + col_interp_values = torch.tensor( + 1.0, dtype=self.dtype, device=self.device + ).expand_as(col_interp_indices) # Construct interpolated LazyTensor from . 
import InterpolatedLazyTensor res = ( InterpolatedLazyTensor( - base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values + base_lazy_tensor, + row_interp_indices, + row_interp_values, + col_interp_indices, + col_interp_values, ) .evaluate() .squeeze(-2) @@ -353,7 +410,9 @@ def _quad_form_derivative(self, left_vecs, right_vecs): with torch.autograd.enable_grad(): loss = (left_vecs * self._matmul(right_vecs)).sum() loss.requires_grad_(True) - actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True)) + actual_grads = deque( + torch.autograd.grad(loss, args_with_grads, allow_unused=True) + ) # Now make sure that the object we return has one entry for every item in args grads = [] @@ -412,8 +471,12 @@ def _cholesky(self, upper=False): evaluated_kern_mat = self.evaluate_kernel() - if any(isinstance(sub_mat, KeOpsLazyTensor) for sub_mat in evaluated_kern_mat._args): - raise RuntimeError("Cannot run Cholesky with KeOps: it will either be really slow or not work.") + if any( + isinstance(sub_mat, KeOpsLazyTensor) for sub_mat in evaluated_kern_mat._args + ): + raise RuntimeError( + "Cannot run Cholesky with KeOps: it will either be really slow or not work." + ) evaluated_mat = evaluated_kern_mat.evaluate() @@ -435,7 +498,9 @@ def _cholesky_solve(self, rhs, upper: bool = False): Returns: (LazyTensor) Cholesky factor """ - raise NotImplementedError("_cholesky_solve not implemented for the base LazyTensor") + raise NotImplementedError( + "_cholesky_solve not implemented for the base LazyTensor" + ) def _inv_matmul_preconditioner(self): """ @@ -453,7 +518,9 @@ def _inv_matmul_preconditioner(self): if hasattr(self, "_default_preconditioner_cache"): U, S, V = self._default_preconditioner_cache else: - precond_basis_size = min(gpytorch.settings.max_preconditioner_size.value(), self.size(-1)) + precond_basis_size = min( + gpytorch.settings.max_preconditioner_size.value(), self.size(-1) + ) random_basis = torch.randn( self.batch_shape + torch.Size((self.size(-2), precond_basis_size)), device=self.device, @@ -516,9 +583,16 @@ def _mul_matrix(self, other): if isinstance(self, NonLazyTensor) or isinstance(other, NonLazyTensor): return NonLazyTensor(self.evaluate() * other.evaluate()) else: - left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other + left_lazy_tensor = ( + self + if self._root_decomposition_size() < other._root_decomposition_size() + else other + ) right_lazy_tensor = other if left_lazy_tensor is self else self - return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition()) + return MulLazyTensor( + left_lazy_tensor.root_decomposition(), + right_lazy_tensor.root_decomposition(), + ) def _preconditioner(self): """ @@ -559,7 +633,10 @@ def _prod_batch(self, dim): shape = list(roots.shape) shape[dim] = 1 extra_root = torch.full( - shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2))) + shape, + dtype=self.dtype, + device=self.device, + fill_value=(1.0 / math.sqrt(self.size(-2))), ) roots = torch.cat([roots, extra_root], dim) num_batch += 1 @@ -708,7 +785,9 @@ def add_diag(self, diag): diag_shape = diag.shape if len(diag_shape) == 0: # interpret scalar tensor as constant diag - diag_tensor = ConstantDiagLazyTensor(diag.unsqueeze(-1), diag_shape=self.shape[-1]) + diag_tensor = ConstantDiagLazyTensor( + diag.unsqueeze(-1), diag_shape=self.shape[-1] + ) elif diag_shape[-1] == 1: # interpret single-trailing element as 
constant diag diag_tensor = ConstantDiagLazyTensor(diag, diag_shape=self.shape[-1]) @@ -735,7 +814,14 @@ def add_jitter(self, jitter_val=1e-3): diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device) return self.add_diag(diag) - def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs): + def cat_rows( + self, + cross_mat, + new_mat, + generate_roots=True, + generate_inv_roots=True, + **root_decomp_kwargs, + ): """ Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g. C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat) @@ -774,6 +860,7 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T from .cat_lazy_tensor import CatLazyTensor from .root_lazy_tensor import RootLazyTensor from .triangular_lazy_tensor import TriangularLazyTensor + if not generate_roots and generate_inv_roots: warnings.warn( "root_inv_decomposition is only generated when " @@ -784,20 +871,30 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T D = lazify(new_mat) batch_shape = B.shape[:-2] if self.ndimension() < cross_mat.ndimension(): - expand_shape = _mul_broadcast_shape(self.shape[:-2], B.shape[:-2]) + self.shape[-2:] + expand_shape = ( + _mul_broadcast_shape(self.shape[:-2], B.shape[:-2]) + self.shape[-2:] + ) A = self.expand(expand_shape) else: A = self # form matrix C = [A B; B^T D], where A = self, B = cross_mat, D = new_mat upper_row = CatLazyTensor(A, B, dim=-2, output_device=A.device) - lower_row = CatLazyTensor(B.transpose(-1, -2), D, dim=-2, output_device=A.device) - new_lazy_tensor = CatLazyTensor(upper_row, lower_row, dim=-1, output_device=A.device) + lower_row = CatLazyTensor( + B.transpose(-1, -2), D, dim=-2, output_device=A.device + ) + new_lazy_tensor = CatLazyTensor( + upper_row, lower_row, dim=-1, output_device=A.device + ) # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition # don't create one has_roots = any( - _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition") + _is_in_cache_ignore_args(self, key) + for key in ( + "root_decomposition", + "root_inv_decomposition", + ) ) if not generate_roots and not has_roots: return new_lazy_tensor @@ -805,19 +902,29 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T # Get components for new root Z = [E 0; F G] E = self.root_decomposition(**root_decomp_kwargs).root # E = L, LL^T = A m, n = E.shape[-2:] - R = self.root_inv_decomposition().root.evaluate() # RR^T = A^{-1} (this is fast if L is triangular) + R = ( + self.root_inv_decomposition().root.evaluate() + ) # RR^T = A^{-1} (this is fast if L is triangular) lower_left = B_ @ R # F = BR - schur = D - lower_left.matmul(lower_left.transpose(-2, -1)) # GG^T = new_mat - FF^T - schur_root = lazify(schur).root_decomposition().root.evaluate() # G = (new_mat - FF^T)^{1/2} + schur = D - lower_left.matmul( + lower_left.transpose(-2, -1) + ) # GG^T = new_mat - FF^T + schur_root = ( + lazify(schur).root_decomposition().root.evaluate() + ) # G = (new_mat - FF^T)^{1/2} # Form new root matrix num_fant = schur_root.size(-2) - new_root = torch.zeros(*batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype) + new_root = torch.zeros( + *batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype + ) new_root[..., :m, :n] = E.evaluate() new_root[..., m:, : lower_left.shape[-1]] = lower_left 
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root if generate_inv_roots: - if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor): + if isinstance(E, TriangularLazyTensor) and isinstance( + schur_root, TriangularLazyTensor + ): # make sure these are actually upper triangular if getattr(E, "upper", False) or getattr(schur_root, "upper", False): raise NotImplementedError @@ -827,9 +934,15 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T else: # otherwise we use the pseudo-inverse of Z as new inv root new_inv_root = stable_pinverse(new_root).transpose(-2, -1) - add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root))) + add_to_cache( + new_lazy_tensor, + "root_inv_decomposition", + RootLazyTensor(lazify(new_inv_root)), + ) - add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root))) + add_to_cache( + new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)) + ) return new_lazy_tensor @@ -868,10 +981,13 @@ def add_low_rank( from .triangular_lazy_tensor import TriangularLazyTensor if not isinstance(self, SumLazyTensor): - new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))) + new_lazy_tensor = self + lazify( + low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)) + ) else: new_lazy_tensor = SumLazyTensor( - *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))) + *self.lazy_tensors, + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))), ) # return as a nonlazy tensor if small enough to reduce memory overhead @@ -881,7 +997,8 @@ def add_low_rank( # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition # don't create one has_roots = any( - _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition") + _is_in_cache_ignore_args(self, key) + for key in ("root_decomposition", "root_inv_decomposition") ) if not generate_roots and not has_roots: return new_lazy_tensor @@ -890,11 +1007,15 @@ def add_low_rank( # \tilde{A} = A + BB^T = L(I + L^{-1} B B^T L^{-T})L^T # first get LL^T = A - current_root = self.root_decomposition(method=root_decomp_method, **root_decomp_kwargs).root + current_root = self.root_decomposition( + method=root_decomp_method, **root_decomp_kwargs + ).root return_triangular = isinstance(current_root, TriangularLazyTensor) # and MM^T = A^{-1} - current_inv_root = self.root_inv_decomposition(method=root_inv_decomp_method).root.transpose(-1, -2) + current_inv_root = self.root_inv_decomposition( + method=root_inv_decomp_method + ).root.transpose(-1, -2) # compute p = M B and take its SVD pvector = current_inv_root.matmul(low_rank_mat) @@ -903,7 +1024,9 @@ def add_low_rank( U, S, _ = torch.svd(pvector, some=False) # we want the root decomposition of I_r + U S^2 U^T but S is q so we need to pad. - one_padding = torch.ones(*S.shape[:-1], U.shape[-2] - S.shape[-1], device=S.device, dtype=S.dtype) + one_padding = torch.ones( + *S.shape[:-1], U.shape[-2] - S.shape[-1], device=S.device, dtype=S.dtype + ) # the non zero eigenvalues get updated by S^2 + 1, so we take the square root. 
root_S_plus_identity = (S ** 2 + 1.0) ** 0.5 # pad the nonzero eigenvalues with the ones @@ -921,13 +1044,20 @@ def add_low_rank( updated_root = torch.cat( ( current_root.evaluate(), - torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype), + torch.zeros( + *current_root.shape[:-1], + 1, + device=current_root.device, + dtype=current_root.dtype, + ), ), dim=-1, ) # compute \tilde{S}^{-1} - stacked_inv_root_S = torch.cat((1.0 / root_S_plus_identity, one_padding), dim=-1) + stacked_inv_root_S = torch.cat( + (1.0 / root_S_plus_identity, one_padding), dim=-1 + ) # compute the new inverse inner root: U \tilde{S}^{-1} inner_inv_root = U.matmul(torch.diag_embed(stacked_inv_root_S)) # finally \tilde{L}^{-1} = L^{-1} U \tilde{S}^{-1} @@ -937,8 +1067,12 @@ def add_low_rank( updated_root = TriangularLazyTensor(updated_root) updated_inv_root = TriangularLazyTensor(updated_inv_root) - add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(updated_root)) - add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(updated_inv_root)) + add_to_cache( + new_lazy_tensor, "root_decomposition", RootLazyTensor(updated_root) + ) + add_to_cache( + new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(updated_inv_root) + ) return new_lazy_tensor @@ -976,7 +1110,10 @@ def clone(self): Clones the LazyTensor (creates clones of all underlying tensors) """ args = [arg.clone() if hasattr(arg, "clone") else arg for arg in self._args] - kwargs = {key: val.clone() if hasattr(val, "clone") else val for key, val in self._kwargs.items()} + kwargs = { + key: val.clone() if hasattr(val, "clone") else val + for key, val in self._kwargs.items() + } return self.__class__(*args, **kwargs) def cpu(self): @@ -1060,7 +1197,9 @@ def diag(self): if not self.is_square: raise RuntimeError("Diag works on square matrices (or batches)") - row_col_iter = torch.arange(0, self.matrix_shape[-1], dtype=torch.long, device=self.device) + row_col_iter = torch.arange( + 0, self.matrix_shape[-1], dtype=torch.long, device=self.device + ) return self[..., row_col_iter, row_col_iter] def dim(self): @@ -1179,9 +1318,17 @@ def inv_matmul(self, right_tensor, left_tensor=None): func = InvMatmul if left_tensor is None: - return func.apply(self.representation_tree(), False, right_tensor, *self.representation()) + return func.apply( + self.representation_tree(), False, right_tensor, *self.representation() + ) else: - return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation()) + return func.apply( + self.representation_tree(), + True, + left_tensor, + right_tensor, + *self.representation(), + ) def inv_quad(self, tensor, reduce_inv_quad=True): """ @@ -1212,7 +1359,9 @@ def inv_quad(self, tensor, reduce_inv_quad=True): ) ) - args = (tensor.expand(*result_shape[:-2], *tensor.shape[-2:]),) + self.representation() + args = ( + tensor.expand(*result_shape[:-2], *tensor.shape[-2:]), + ) + self.representation() func = InvQuad.apply inv_quad_term = func(self.representation_tree(), *args) @@ -1234,7 +1383,9 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True) - scalar - log determinant """ # Special case: use Cholesky to compute these terms - if settings.fast_computations.log_prob.off() or (self.size(-1) <= settings.max_cholesky_size.value()): + if settings.fast_computations.log_prob.off() or ( + self.size(-1) <= settings.max_cholesky_size.value() + ): from .chol_lazy_tensor import CholLazyTensor from .triangular_lazy_tensor import 
TriangularLazyTensor @@ -1248,7 +1399,11 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True) will_need_cholesky = False if will_need_cholesky: cholesky = CholLazyTensor(TriangularLazyTensor(self.cholesky())) - return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad) + return cholesky.inv_quad_logdet( + inv_quad_rhs=inv_quad_rhs, + logdet=logdet, + reduce_inv_quad=reduce_inv_quad, + ) # Default: use modified batch conjugate gradients to compute these terms # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165 @@ -1271,7 +1426,10 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True) "LazyTensor (size={}) and right-hand-side Tensor (size={}) should have the same number " "of dimensions.".format(self.shape, inv_quad_rhs.shape) ) - elif self.batch_shape != inv_quad_rhs.shape[:-2] or self.shape[-1] != inv_quad_rhs.shape[-2]: + elif ( + self.batch_shape != inv_quad_rhs.shape[:-2] + or self.shape[-1] != inv_quad_rhs.shape[-2] + ): raise RuntimeError( "LazyTensor (size={}) cannot be multiplied with right-hand-side Tensor (size={}).".format( self.shape, inv_quad_rhs.shape @@ -1377,7 +1535,9 @@ def mul(self, other): _mul_broadcast_shape(self.shape, other.shape) except RuntimeError: raise RuntimeError( - "Cannot multiply LazyTensor of size {} by an object of size {}".format(self.shape, other.shape) + "Cannot multiply LazyTensor of size {} by an object of size {}".format( + self.shape, other.shape + ) ) if torch.is_tensor(other): @@ -1425,7 +1585,9 @@ def permute(self, *dims): ) if dims[-2:] != (num_dims - 2, num_dims - 1): - raise ValueError("At the moment, cannot permute the non-batch dimensions of LazyTensors.") + raise ValueError( + "At the moment, cannot permute the non-batch dimensions of LazyTensors." + ) return self._permute_batch(*dims[:-2]) @@ -1459,7 +1621,9 @@ def prod(self, dim=None): >>> # Returns: torch.Tensor([[[2, 4], [0, -2]], [[6, 2], [2, 0]]]) """ if dim is None: - raise ValueError("At the moment, LazyTensor.prod requires a dim argument (got None)") + raise ValueError( + "At the moment, LazyTensor.prod requires a dim argument (got None)" + ) orig_dim = dim if dim < 0: @@ -1507,10 +1671,14 @@ def representation(self): for arg in self._args: if torch.is_tensor(arg): representation.append(arg) - elif hasattr(arg, "representation") and callable(arg.representation): # Is it a LazyTensor? + elif hasattr(arg, "representation") and callable( + arg.representation + ): # Is it a LazyTensor? representation += list(arg.representation()) else: - raise RuntimeError("Representation of a LazyTensor should consist only of Tensors") + raise RuntimeError( + "Representation of a LazyTensor should consist only of Tensors" + ) return tuple(representation) def representation_tree(self): @@ -1640,12 +1808,15 @@ def root_decomposition(self, method: Optional[str] = None): return CholLazyTensor(res) except RuntimeError as e: warnings.warn( - f"Runtime Error when computing Cholesky decomposition: {e}. Using symeig method.", NumericalWarning, + f"Runtime Error when computing Cholesky decomposition: {e}. Using symeig method.", + NumericalWarning, ) method = "symeig" if method == "pivoted_cholesky": - root = pivoted_cholesky(self.evaluate(), max_iter=self._root_decomposition_size()) + root = pivoted_cholesky( + self.evaluate(), max_iter=self._root_decomposition_size() + ) elif method == "symeig": evals, evecs = self.symeig(eigenvectors=True) # TODO: only use non-zero evals (req. 
dealing w/ batches...) @@ -1665,7 +1836,9 @@ def root_decomposition(self, method: Optional[str] = None): return RootLazyTensor(root) @cached(name="root_inv_decomposition") - def root_inv_decomposition(self, initial_vectors=None, test_vectors=None, method: Optional[str] = None): + def root_inv_decomposition( + self, initial_vectors=None, test_vectors=None, method: Optional[str] = None + ): """ Returns a (usually low-rank) root decomposotion lazy tensor of a PSD matrix. This can be used for sampling from a Gaussian distribution, or for obtaining a @@ -1709,7 +1882,10 @@ def root_inv_decomposition(self, initial_vectors=None, test_vectors=None, method "LazyTensor (size={}) and initial_vectors (size={}) should have the same number " "of dimensions.".format(self.shape, initial_vectors.shape) ) - elif self.batch_shape != initial_vectors.shape[:-2] or self.shape[-1] != initial_vectors.shape[-2]: + elif ( + self.batch_shape != initial_vectors.shape[:-2] + or self.shape[-1] != initial_vectors.shape[-2] + ): raise RuntimeError( "LazyTensor (size={}) cannot be multiplied with initial_vectors (size={}).".format( self.shape, initial_vectors.shape @@ -1718,7 +1894,9 @@ def root_inv_decomposition(self, initial_vectors=None, test_vectors=None, method inv_root = self._root_inv_decomposition(initial_vectors) if initial_vectors is not None and initial_vectors.size(-1) > 1: - inv_root = _postprocess_lanczos_root_inv_decomp(self, inv_root, initial_vectors, test_vectors) + inv_root = _postprocess_lanczos_root_inv_decomp( + self, inv_root, initial_vectors, test_vectors + ) elif method == "symeig": evals, evecs = self.symeig(eigenvectors=True) # TODO: only use non-zero evals (req. dealing w/ batches...) @@ -1771,7 +1949,9 @@ def sqrt_inv_matmul(self, rhs, lhs=None): squeeze = True func = SqrtInvMatmul - sqrt_inv_matmul_res, inv_quad_res = func.apply(self.representation_tree(), rhs, lhs, *self.representation()) + sqrt_inv_matmul_res, inv_quad_res = func.apply( + self.representation_tree(), rhs, lhs, *self.representation() + ) if squeeze: sqrt_inv_matmul_res = sqrt_inv_matmul_res.squeeze(-1) @@ -1825,7 +2005,11 @@ def sum(self, dim=None): elif dim < self.dim(): return self._sum_batch(dim) else: - raise ValueError("Invalid dim ({}) for LazyTensor of size {}".format(orig_dim, self.shape)) + raise ValueError( + "Invalid dim ({}) for LazyTensor of size {}".format( + orig_dim, self.shape + ) + ) def svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]: """ @@ -1844,7 +2028,9 @@ def svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]: return self._svd() @cached(name="symeig") - def symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTensor"]]: + def symeig( + self, eigenvectors: bool = False + ) -> Tuple[Tensor, Optional["LazyTensor"]]: """ Compute the symmetric eigendecomposition of the lazy tensor. This can be very slow for large tensors. 
Should be special-cased for tensors with particular @@ -1917,7 +2103,12 @@ def transpose(self, dim1, dim2): dim1 = ndimension + dim1 if dim2 < 0: dim2 = ndimension + dim2 - if dim1 >= ndimension or dim2 >= ndimension or not isinstance(dim1, int) or not isinstance(dim2, int): + if ( + dim1 >= ndimension + or dim2 >= ndimension + or not isinstance(dim1, int) + or not isinstance(dim2, int) + ): raise RuntimeError("Invalid dimension") # Batch case @@ -1936,7 +2127,9 @@ def transpose(self, dim1, dim2): res = self._transpose_nonbatch() else: - raise RuntimeError("Cannot transpose batch dimension with non-batch dimension") + raise RuntimeError( + "Cannot transpose batch dimension with non-batch dimension" + ) return res @@ -1995,7 +2188,11 @@ def zero_mean_mvn_samples(self, num_samples): if settings.ciq_samples.on(): base_samples = torch.randn( - *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device + *self.batch_shape, + self.size(-1), + num_samples, + dtype=self.dtype, + device=self.device, ) base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous() base_samples = base_samples.unsqueeze(-1) @@ -2015,9 +2212,17 @@ def zero_mean_mvn_samples(self, num_samples): covar_root = self.root_decomposition().root base_samples = torch.randn( - *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device + *self.batch_shape, + covar_root.size(-1), + num_samples, + dtype=self.dtype, + device=self.device, + ) + samples = ( + covar_root.matmul(base_samples) + .permute(-1, *range(self.dim() - 1)) + .contiguous() ) - samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous() return samples @@ -2051,8 +2256,16 @@ def __add__(self, other): elif isinstance(other, Tensor): other = lazify(other) shape = _mul_broadcast_shape(self.shape, other.shape) - new_self = self if self.shape[:-2] == shape[:-2] else self._expand_batch(shape[:-2]) - new_other = other if other.shape[:-2] == shape[:-2] else other._expand_batch(shape[:-2]) + new_self = ( + self + if self.shape[:-2] == shape[:-2] + else self._expand_batch(shape[:-2]) + ) + new_other = ( + other + if other.shape[:-2] == shape[:-2] + else other._expand_batch(shape[:-2]) + ) return SumLazyTensor(new_self, new_other) else: return SumLazyTensor(self, other) @@ -2073,7 +2286,9 @@ def __div__(self, other): from .zero_lazy_tensor import ZeroLazyTensor if isinstance(other, ZeroLazyTensor): - raise RuntimeError("Attempted to divide by a ZeroLazyTensor (divison by zero)") + raise RuntimeError( + "Attempted to divide by a ZeroLazyTensor (divison by zero)" + ) return self.mul(1.0 / other) @@ -2086,12 +2301,19 @@ def __getitem__(self, index): # Process the index index = index if isinstance(index, tuple) else (index,) - index = tuple(torch.tensor(idx) if isinstance(idx, list) else idx for idx in index) - index = tuple(idx.item() if torch.is_tensor(idx) and not len(idx.shape) else idx for idx in index) + index = tuple( + torch.tensor(idx) if isinstance(idx, list) else idx for idx in index + ) + index = tuple( + idx.item() if torch.is_tensor(idx) and not len(idx.shape) else idx + for idx in index + ) # Handle the ellipsis # Find the index of the ellipsis - ellipsis_locs = tuple(index for index, item in enumerate(index) if item is Ellipsis) + ellipsis_locs = tuple( + index for index, item in enumerate(index) if item is Ellipsis + ) if settings.debug.on(): if len(ellipsis_locs) > 1: raise RuntimeError( @@ -2101,7 +2323,11 @@ def __getitem__(self, index): if len(ellipsis_locs) == 
1: ellipsis_loc = ellipsis_locs[0] num_to_fill_in = ndimension - (len(index) - 1) - index = index[:ellipsis_loc] + tuple(_noop_index for _ in range(num_to_fill_in)) + index[ellipsis_loc + 1 :] + index = ( + index[:ellipsis_loc] + + tuple(_noop_index for _ in range(num_to_fill_in)) + + index[ellipsis_loc + 1 :] + ) # Pad the index with empty indices index = index + tuple(_noop_index for _ in range(ndimension - len(index))) @@ -2110,14 +2336,18 @@ def __getitem__(self, index): *batch_indices, row_index, col_index = index # Helpers to determine what the final shape will be if we're tensor indexed - batch_has_tensor_index = bool(len(batch_indices)) and any(torch.is_tensor(index) for index in batch_indices) + batch_has_tensor_index = bool(len(batch_indices)) and any( + torch.is_tensor(index) for index in batch_indices + ) row_has_tensor_index = torch.is_tensor(row_index) col_has_tensor_index = torch.is_tensor(col_index) # These are the cases where the row and/or column indices will be "absorbed" into other indices row_col_are_absorbed = any( ( - batch_has_tensor_index and (row_has_tensor_index or col_has_tensor_index), - not batch_has_tensor_index and (row_has_tensor_index and col_has_tensor_index), + batch_has_tensor_index + and (row_has_tensor_index or col_has_tensor_index), + not batch_has_tensor_index + and (row_has_tensor_index and col_has_tensor_index), ) ) @@ -2158,7 +2388,9 @@ def __getitem__(self, index): if expected_shape != res.shape: raise RuntimeError( "{}.__getitem__ failed! Expected a final shape of size {}, got {}. This is a bug with GPyTorch, " - "or your custom LazyTensor.".format(self.__class__.__name__, expected_shape, res.shape) + "or your custom LazyTensor.".format( + self.__class__.__name__, expected_shape, res.shape + ) ) # We're done! @@ -2176,16 +2408,22 @@ def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]: V = evecs return U, S, V - def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTensor"]]: + def _symeig( + self, eigenvectors: bool = False + ) -> Tuple[Tensor, Optional["LazyTensor"]]: """Method that allows implementing special-cased symeig computation. Should not be called directly""" from gpytorch.lazy.non_lazy_tensor import NonLazyTensor if settings.verbose_linalg.on(): - settings.verbose_linalg.logger.debug(f"Running symeig on a matrix of size {self.shape}.") + settings.verbose_linalg.logger.debug( + f"Running symeig on a matrix of size {self.shape}." + ) # potentially perform decomposition in double precision for numerical stability dtype = self.dtype - evals, evecs = torch.linalg.eigh(self.evaluate().to(dtype=settings._linalg_dtype_symeig.value())) + evals, evecs = torch.linalg.eigh( + self.evaluate().to(dtype=settings._linalg_dtype_symeig.value()) + ) # chop any negative eigenvalues. 
# TODO: warn if evals are significantly negative evals = evals.clamp_min(0.0).to(dtype=dtype) @@ -2232,9 +2470,15 @@ def delazify(obj): elif isinstance(obj, LazyTensor): return obj.evaluate() else: - raise TypeError("object of class {} cannot be made into a Tensor".format(obj.__class__.__name__)) + raise TypeError( + "object of class {} cannot be made into a Tensor".format( + obj.__class__.__name__ + ) + ) -_deprecate_renamed_methods(LazyTensor, inv_quad_log_det="inv_quad_logdet", log_det="logdet") +_deprecate_renamed_methods( + LazyTensor, inv_quad_log_det="inv_quad_logdet", log_det="logdet" +) __all__ = ["LazyTensor", "delazify"] From 0351a1a52d81192eb3cacfba3999cc905e8c1444 Mon Sep 17 00:00:00 2001 From: Ax Website Deployment Script Date: Tue, 14 Sep 2021 18:24:20 -0700 Subject: [PATCH 7/8] black --- gpytorch/lazy/lazy_tensor.py | 337 +++++++++-------------------------- 1 file changed, 82 insertions(+), 255 deletions(-) diff --git a/gpytorch/lazy/lazy_tensor.py b/gpytorch/lazy/lazy_tensor.py index db25fdeb7..730e66fce 100644 --- a/gpytorch/lazy/lazy_tensor.py +++ b/gpytorch/lazy/lazy_tensor.py @@ -143,9 +143,7 @@ def _matmul(self, rhs): Returns: :obj:`torch.tensor`: matrix * rhs """ - raise NotImplementedError( - "The class {} requires a _matmul function!".format(self.__class__.__name__) - ) + raise NotImplementedError("The class {} requires a _matmul function!".format(self.__class__.__name__)) @abstractmethod def _size(self): @@ -159,9 +157,7 @@ def _size(self): Returns: :obj:`torch.Size`: The size of the matrix :math:`K` represented by this LazyTensor """ - raise NotImplementedError( - "The class {} requires a _size function!".format(self.__class__.__name__) - ) + raise NotImplementedError("The class {} requires a _size function!".format(self.__class__.__name__)) @abstractmethod def _transpose_nonbatch(self): @@ -174,9 +170,7 @@ def _transpose_nonbatch(self): does some additional work. Calling this method directly is discouraged. """ raise NotImplementedError( - "The class {} requires a _transpose_nonbatch function!".format( - self.__class__.__name__ - ) + "The class {} requires a _transpose_nonbatch function!".format(self.__class__.__name__) ) #### @@ -255,21 +249,13 @@ def _getitem(self, row_index, col_index, *batch_indices): # Normal case: we have to do some processing on either the rows or columns # We will handle this through "interpolation" - row_interp_indices = torch.arange( - 0, self.size(-2), dtype=torch.long, device=self.device - ).view(-1, 1) + row_interp_indices = torch.arange(0, self.size(-2), dtype=torch.long, device=self.device).view(-1, 1) row_interp_indices = row_interp_indices.expand(*self.batch_shape, -1, 1) - row_interp_values = torch.tensor( - 1.0, dtype=self.dtype, device=self.device - ).expand_as(row_interp_indices) + row_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(row_interp_indices) - col_interp_indices = torch.arange( - 0, self.size(-1), dtype=torch.long, device=self.device - ).view(-1, 1) + col_interp_indices = torch.arange(0, self.size(-1), dtype=torch.long, device=self.device).view(-1, 1) col_interp_indices = col_interp_indices.expand(*self.batch_shape, -1, 1) - col_interp_values = torch.tensor( - 1.0, dtype=self.dtype, device=self.device - ).expand_as(col_interp_indices) + col_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(col_interp_indices) # Construct interpolated LazyTensor from . 
import InterpolatedLazyTensor @@ -308,15 +294,9 @@ def _expand_batch(self, batch_shape): This method is used internally by the related function :func:`~gpytorch.lazy.LazyTensor.expand`, which does some additional work. Calling this method directly is discouraged. """ - current_shape = torch.Size( - [1 for _ in range(len(batch_shape) - self.dim() + 2)] - + list(self.batch_shape) - ) + current_shape = torch.Size([1 for _ in range(len(batch_shape) - self.dim() + 2)] + list(self.batch_shape)) batch_repeat = torch.Size( - [ - expand_size // current_size - for expand_size, current_size in zip(batch_shape, current_shape) - ] + [expand_size // current_size for expand_size, current_size in zip(batch_shape, current_shape)] ) return self.repeat(*batch_repeat, 1, 1) @@ -338,33 +318,21 @@ def _get_indices(self, row_index, col_index, *batch_indices): Returns: Tensor (size determined by broadcasted shape of indices) of selected values """ - final_shape = _mul_broadcast_shape( - *(index.shape for index in batch_indices), row_index.shape, col_index.shape - ) + final_shape = _mul_broadcast_shape(*(index.shape for index in batch_indices), row_index.shape, col_index.shape) row_index = row_index.expand(final_shape) col_index = col_index.expand(final_shape) batch_indices = tuple(index.expand(final_shape) for index in batch_indices) - base_lazy_tensor = self._getitem( - _noop_index, _noop_index, *batch_indices - )._expand_batch(final_shape) + base_lazy_tensor = self._getitem(_noop_index, _noop_index, *batch_indices)._expand_batch(final_shape) # Create some interoplation indices and values - row_interp_indices = torch.arange( - 0, self.size(-2), dtype=torch.long, device=self.device - ) + row_interp_indices = torch.arange(0, self.size(-2), dtype=torch.long, device=self.device) row_interp_indices = row_interp_indices[row_index].unsqueeze_(-1).unsqueeze_(-1) - row_interp_values = torch.tensor( - 1.0, dtype=self.dtype, device=self.device - ).expand_as(row_interp_indices) + row_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(row_interp_indices) - col_interp_indices = torch.arange( - 0, self.size(-1), dtype=torch.long, device=self.device - ) + col_interp_indices = torch.arange(0, self.size(-1), dtype=torch.long, device=self.device) col_interp_indices = col_interp_indices[col_index].unsqueeze_(-1).unsqueeze_(-1) - col_interp_values = torch.tensor( - 1.0, dtype=self.dtype, device=self.device - ).expand_as(col_interp_indices) + col_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(col_interp_indices) # Construct interpolated LazyTensor from . import InterpolatedLazyTensor @@ -410,9 +378,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs): with torch.autograd.enable_grad(): loss = (left_vecs * self._matmul(right_vecs)).sum() loss.requires_grad_(True) - actual_grads = deque( - torch.autograd.grad(loss, args_with_grads, allow_unused=True) - ) + actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True)) # Now make sure that the object we return has one entry for every item in args grads = [] @@ -471,12 +437,8 @@ def _cholesky(self, upper=False): evaluated_kern_mat = self.evaluate_kernel() - if any( - isinstance(sub_mat, KeOpsLazyTensor) for sub_mat in evaluated_kern_mat._args - ): - raise RuntimeError( - "Cannot run Cholesky with KeOps: it will either be really slow or not work." 
- ) + if any(isinstance(sub_mat, KeOpsLazyTensor) for sub_mat in evaluated_kern_mat._args): + raise RuntimeError("Cannot run Cholesky with KeOps: it will either be really slow or not work.") evaluated_mat = evaluated_kern_mat.evaluate() @@ -498,9 +460,7 @@ def _cholesky_solve(self, rhs, upper: bool = False): Returns: (LazyTensor) Cholesky factor """ - raise NotImplementedError( - "_cholesky_solve not implemented for the base LazyTensor" - ) + raise NotImplementedError("_cholesky_solve not implemented for the base LazyTensor") def _inv_matmul_preconditioner(self): """ @@ -518,9 +478,7 @@ def _inv_matmul_preconditioner(self): if hasattr(self, "_default_preconditioner_cache"): U, S, V = self._default_preconditioner_cache else: - precond_basis_size = min( - gpytorch.settings.max_preconditioner_size.value(), self.size(-1) - ) + precond_basis_size = min(gpytorch.settings.max_preconditioner_size.value(), self.size(-1)) random_basis = torch.randn( self.batch_shape + torch.Size((self.size(-2), precond_basis_size)), device=self.device, @@ -583,11 +541,7 @@ def _mul_matrix(self, other): if isinstance(self, NonLazyTensor) or isinstance(other, NonLazyTensor): return NonLazyTensor(self.evaluate() * other.evaluate()) else: - left_lazy_tensor = ( - self - if self._root_decomposition_size() < other._root_decomposition_size() - else other - ) + left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other right_lazy_tensor = other if left_lazy_tensor is self else self return MulLazyTensor( left_lazy_tensor.root_decomposition(), @@ -785,9 +739,7 @@ def add_diag(self, diag): diag_shape = diag.shape if len(diag_shape) == 0: # interpret scalar tensor as constant diag - diag_tensor = ConstantDiagLazyTensor( - diag.unsqueeze(-1), diag_shape=self.shape[-1] - ) + diag_tensor = ConstantDiagLazyTensor(diag.unsqueeze(-1), diag_shape=self.shape[-1]) elif diag_shape[-1] == 1: # interpret single-trailing element as constant diag diag_tensor = ConstantDiagLazyTensor(diag, diag_shape=self.shape[-1]) @@ -863,29 +815,22 @@ def cat_rows( if not generate_roots and generate_inv_roots: warnings.warn( - "root_inv_decomposition is only generated when " - "root_decomposition is generated.", + "root_inv_decomposition is only generated when " "root_decomposition is generated.", UserWarning, ) B_, B = cross_mat, lazify(cross_mat) D = lazify(new_mat) batch_shape = B.shape[:-2] if self.ndimension() < cross_mat.ndimension(): - expand_shape = ( - _mul_broadcast_shape(self.shape[:-2], B.shape[:-2]) + self.shape[-2:] - ) + expand_shape = _mul_broadcast_shape(self.shape[:-2], B.shape[:-2]) + self.shape[-2:] A = self.expand(expand_shape) else: A = self # form matrix C = [A B; B^T D], where A = self, B = cross_mat, D = new_mat upper_row = CatLazyTensor(A, B, dim=-2, output_device=A.device) - lower_row = CatLazyTensor( - B.transpose(-1, -2), D, dim=-2, output_device=A.device - ) - new_lazy_tensor = CatLazyTensor( - upper_row, lower_row, dim=-1, output_device=A.device - ) + lower_row = CatLazyTensor(B.transpose(-1, -2), D, dim=-2, output_device=A.device) + new_lazy_tensor = CatLazyTensor(upper_row, lower_row, dim=-1, output_device=A.device) # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition # don't create one @@ -902,29 +847,19 @@ def cat_rows( # Get components for new root Z = [E 0; F G] E = self.root_decomposition(**root_decomp_kwargs).root # E = L, LL^T = A m, n = E.shape[-2:] - R = ( - self.root_inv_decomposition().root.evaluate() - ) # RR^T = A^{-1} 
(this is fast if L is triangular) + R = self.root_inv_decomposition().root.evaluate() # RR^T = A^{-1} (this is fast if L is triangular) lower_left = B_ @ R # F = BR - schur = D - lower_left.matmul( - lower_left.transpose(-2, -1) - ) # GG^T = new_mat - FF^T - schur_root = ( - lazify(schur).root_decomposition().root.evaluate() - ) # G = (new_mat - FF^T)^{1/2} + schur = D - lower_left.matmul(lower_left.transpose(-2, -1)) # GG^T = new_mat - FF^T + schur_root = lazify(schur).root_decomposition().root.evaluate() # G = (new_mat - FF^T)^{1/2} # Form new root matrix num_fant = schur_root.size(-2) - new_root = torch.zeros( - *batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype - ) + new_root = torch.zeros(*batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype) new_root[..., :m, :n] = E.evaluate() new_root[..., m:, : lower_left.shape[-1]] = lower_left new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root if generate_inv_roots: - if isinstance(E, TriangularLazyTensor) and isinstance( - schur_root, TriangularLazyTensor - ): + if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor): # make sure these are actually upper triangular if getattr(E, "upper", False) or getattr(schur_root, "upper", False): raise NotImplementedError @@ -940,9 +875,7 @@ def cat_rows( RootLazyTensor(lazify(new_inv_root)), ) - add_to_cache( - new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)) - ) + add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root))) return new_lazy_tensor @@ -981,9 +914,7 @@ def add_low_rank( from .triangular_lazy_tensor import TriangularLazyTensor if not isinstance(self, SumLazyTensor): - new_lazy_tensor = self + lazify( - low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)) - ) + new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))) else: new_lazy_tensor = SumLazyTensor( *self.lazy_tensors, @@ -996,10 +927,7 @@ def add_low_rank( # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition # don't create one - has_roots = any( - _is_in_cache_ignore_args(self, key) - for key in ("root_decomposition", "root_inv_decomposition") - ) + has_roots = any(_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition")) if not generate_roots and not has_roots: return new_lazy_tensor @@ -1007,15 +935,11 @@ def add_low_rank( # \tilde{A} = A + BB^T = L(I + L^{-1} B B^T L^{-T})L^T # first get LL^T = A - current_root = self.root_decomposition( - method=root_decomp_method, **root_decomp_kwargs - ).root + current_root = self.root_decomposition(method=root_decomp_method, **root_decomp_kwargs).root return_triangular = isinstance(current_root, TriangularLazyTensor) # and MM^T = A^{-1} - current_inv_root = self.root_inv_decomposition( - method=root_inv_decomp_method - ).root.transpose(-1, -2) + current_inv_root = self.root_inv_decomposition(method=root_inv_decomp_method).root.transpose(-1, -2) # compute p = M B and take its SVD pvector = current_inv_root.matmul(low_rank_mat) @@ -1024,9 +948,7 @@ def add_low_rank( U, S, _ = torch.svd(pvector, some=False) # we want the root decomposition of I_r + U S^2 U^T but S is q so we need to pad. 
-        one_padding = torch.ones(
-            *S.shape[:-1], U.shape[-2] - S.shape[-1], device=S.device, dtype=S.dtype
-        )
+        one_padding = torch.ones(*S.shape[:-1], U.shape[-2] - S.shape[-1], device=S.device, dtype=S.dtype)
         # the nonzero eigenvalues get updated by S^2 + 1, so we take the square root.
         root_S_plus_identity = (S ** 2 + 1.0) ** 0.5
         # pad the nonzero eigenvalues with the ones
@@ -1055,9 +977,7 @@ def add_low_rank(
         )

         # compute \tilde{S}^{-1}
-        stacked_inv_root_S = torch.cat(
-            (1.0 / root_S_plus_identity, one_padding), dim=-1
-        )
+        stacked_inv_root_S = torch.cat((1.0 / root_S_plus_identity, one_padding), dim=-1)
         # compute the new inverse inner root: U \tilde{S}^{-1}
         inner_inv_root = U.matmul(torch.diag_embed(stacked_inv_root_S))
         # finally \tilde{L}^{-1} = L^{-1} U \tilde{S}^{-1}
@@ -1067,12 +987,8 @@ def add_low_rank(
             updated_root = TriangularLazyTensor(updated_root)
             updated_inv_root = TriangularLazyTensor(updated_inv_root)

-        add_to_cache(
-            new_lazy_tensor, "root_decomposition", RootLazyTensor(updated_root)
-        )
-        add_to_cache(
-            new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(updated_inv_root)
-        )
+        add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(updated_root))
+        add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(updated_inv_root))

         return new_lazy_tensor

@@ -1110,10 +1026,7 @@ def clone(self):
         Clones the LazyTensor (creates clones of all underlying tensors)
         """
         args = [arg.clone() if hasattr(arg, "clone") else arg for arg in self._args]
-        kwargs = {
-            key: val.clone() if hasattr(val, "clone") else val
-            for key, val in self._kwargs.items()
-        }
+        kwargs = {key: val.clone() if hasattr(val, "clone") else val for key, val in self._kwargs.items()}
         return self.__class__(*args, **kwargs)

     def cpu(self):
@@ -1197,9 +1110,7 @@ def diag(self):
         if not self.is_square:
             raise RuntimeError("Diag works on square matrices (or batches)")

-        row_col_iter = torch.arange(
-            0, self.matrix_shape[-1], dtype=torch.long, device=self.device
-        )
+        row_col_iter = torch.arange(0, self.matrix_shape[-1], dtype=torch.long, device=self.device)
         return self[..., row_col_iter, row_col_iter]

     def dim(self):
@@ -1318,9 +1229,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
         func = InvMatmul
         if left_tensor is None:
-            return func.apply(
-                self.representation_tree(), False, right_tensor, *self.representation()
-            )
+            return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
         else:
             return func.apply(
                 self.representation_tree(),
@@ -1359,9 +1268,7 @@ def inv_quad(self, tensor, reduce_inv_quad=True):
                 )
             )

-        args = (
-            tensor.expand(*result_shape[:-2], *tensor.shape[-2:]),
-        ) + self.representation()
+        args = (tensor.expand(*result_shape[:-2], *tensor.shape[-2:]),) + self.representation()
         func = InvQuad.apply
         inv_quad_term = func(self.representation_tree(), *args)

@@ -1383,9 +1290,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
             - scalar - log determinant
         """
         # Special case: use Cholesky to compute these terms
-        if settings.fast_computations.log_prob.off() or (
-            self.size(-1) <= settings.max_cholesky_size.value()
-        ):
+        if settings.fast_computations.log_prob.off() or (self.size(-1) <= settings.max_cholesky_size.value()):
             from .chol_lazy_tensor import CholLazyTensor
             from .triangular_lazy_tensor import TriangularLazyTensor

@@ -1426,10 +1331,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
                     "LazyTensor (size={}) and right-hand-side Tensor (size={}) should have the same number "
                    "of dimensions.".format(self.shape, inv_quad_rhs.shape)
                )
-            elif (
-                self.batch_shape != inv_quad_rhs.shape[:-2]
-                or self.shape[-1] != inv_quad_rhs.shape[-2]
-            ):
+            elif self.batch_shape != inv_quad_rhs.shape[:-2] or self.shape[-1] != inv_quad_rhs.shape[-2]:
                 raise RuntimeError(
                     "LazyTensor (size={}) cannot be multiplied with right-hand-side Tensor (size={}).".format(
                         self.shape, inv_quad_rhs.shape
                     )
@@ -1535,9 +1437,7 @@ def mul(self, other):
             _mul_broadcast_shape(self.shape, other.shape)
         except RuntimeError:
             raise RuntimeError(
-                "Cannot multiply LazyTensor of size {} by an object of size {}".format(
-                    self.shape, other.shape
-                )
+                "Cannot multiply LazyTensor of size {} by an object of size {}".format(self.shape, other.shape)
             )

         if torch.is_tensor(other):
@@ -1585,9 +1485,7 @@ def permute(self, *dims):
             )

         if dims[-2:] != (num_dims - 2, num_dims - 1):
-            raise ValueError(
-                "At the moment, cannot permute the non-batch dimensions of LazyTensors."
-            )
+            raise ValueError("At the moment, cannot permute the non-batch dimensions of LazyTensors.")

         return self._permute_batch(*dims[:-2])

@@ -1621,9 +1519,7 @@ def prod(self, dim=None):
            >>> # Returns: torch.Tensor([[[2, 4], [0, -2]], [[6, 2], [2, 0]]])
        """
        if dim is None:
-            raise ValueError(
-                "At the moment, LazyTensor.prod requires a dim argument (got None)"
-            )
+            raise ValueError("At the moment, LazyTensor.prod requires a dim argument (got None)")

         orig_dim = dim
         if dim < 0:
@@ -1671,14 +1567,10 @@ def representation(self):
         for arg in self._args:
             if torch.is_tensor(arg):
                 representation.append(arg)
-            elif hasattr(arg, "representation") and callable(
-                arg.representation
-            ):  # Is it a LazyTensor?
+            elif hasattr(arg, "representation") and callable(arg.representation):  # Is it a LazyTensor?
                 representation += list(arg.representation())
             else:
-                raise RuntimeError(
-                    "Representation of a LazyTensor should consist only of Tensors"
-                )
+                raise RuntimeError("Representation of a LazyTensor should consist only of Tensors")
         return tuple(representation)

     def representation_tree(self):
@@ -1814,9 +1706,7 @@ def root_decomposition(self, method: Optional[str] = None):
             method = "symeig"

         if method == "pivoted_cholesky":
-            root = pivoted_cholesky(
-                self.evaluate(), max_iter=self._root_decomposition_size()
-            )
+            root = pivoted_cholesky(self.evaluate(), max_iter=self._root_decomposition_size())
         elif method == "symeig":
             evals, evecs = self.symeig(eigenvectors=True)
             # TODO: only use non-zero evals (req. dealing w/ batches...)
@@ -1836,9 +1726,7 @@ def root_decomposition(self, method: Optional[str] = None):
         return RootLazyTensor(root)

     @cached(name="root_inv_decomposition")
-    def root_inv_decomposition(
-        self, initial_vectors=None, test_vectors=None, method: Optional[str] = None
-    ):
+    def root_inv_decomposition(self, initial_vectors=None, test_vectors=None, method: Optional[str] = None):
         """
         Returns a (usually low-rank) root decomposition lazy tensor of a PSD matrix.
        This can be used for sampling from a Gaussian distribution, or for obtaining a
@@ -1882,10 +1770,7 @@ def root_inv_decomposition(
                     "LazyTensor (size={}) and initial_vectors (size={}) should have the same number "
                     "of dimensions.".format(self.shape, initial_vectors.shape)
                 )
-            elif (
-                self.batch_shape != initial_vectors.shape[:-2]
-                or self.shape[-1] != initial_vectors.shape[-2]
-            ):
+            elif self.batch_shape != initial_vectors.shape[:-2] or self.shape[-1] != initial_vectors.shape[-2]:
                 raise RuntimeError(
                     "LazyTensor (size={}) cannot be multiplied with initial_vectors (size={}).".format(
                         self.shape, initial_vectors.shape
                     )
@@ -1894,9 +1779,7 @@ def root_inv_decomposition(
             inv_root = self._root_inv_decomposition(initial_vectors)

             if initial_vectors is not None and initial_vectors.size(-1) > 1:
-                inv_root = _postprocess_lanczos_root_inv_decomp(
-                    self, inv_root, initial_vectors, test_vectors
-                )
+                inv_root = _postprocess_lanczos_root_inv_decomp(self, inv_root, initial_vectors, test_vectors)
         elif method == "symeig":
             evals, evecs = self.symeig(eigenvectors=True)
             # TODO: only use non-zero evals (req. dealing w/ batches...)
@@ -1949,9 +1832,7 @@ def sqrt_inv_matmul(self, rhs, lhs=None):
             squeeze = True

         func = SqrtInvMatmul
-        sqrt_inv_matmul_res, inv_quad_res = func.apply(
-            self.representation_tree(), rhs, lhs, *self.representation()
-        )
+        sqrt_inv_matmul_res, inv_quad_res = func.apply(self.representation_tree(), rhs, lhs, *self.representation())

         if squeeze:
             sqrt_inv_matmul_res = sqrt_inv_matmul_res.squeeze(-1)
@@ -2005,11 +1886,7 @@ def sum(self, dim=None):
         elif dim < self.dim():
             return self._sum_batch(dim)
         else:
-            raise ValueError(
-                "Invalid dim ({}) for LazyTensor of size {}".format(
-                    orig_dim, self.shape
-                )
-            )
+            raise ValueError("Invalid dim ({}) for LazyTensor of size {}".format(orig_dim, self.shape))

     def svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
         """
@@ -2028,9 +1905,7 @@ def svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
         return self._svd()

     @cached(name="symeig")
-    def symeig(
-        self, eigenvectors: bool = False
-    ) -> Tuple[Tensor, Optional["LazyTensor"]]:
+    def symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTensor"]]:
         """
         Compute the symmetric eigendecomposition of the lazy tensor. This can be very slow for large tensors.
        Should be special-cased for tensors with particular
@@ -2103,12 +1978,7 @@ def transpose(self, dim1, dim2):
             dim1 = ndimension + dim1
         if dim2 < 0:
             dim2 = ndimension + dim2
-        if (
-            dim1 >= ndimension
-            or dim2 >= ndimension
-            or not isinstance(dim1, int)
-            or not isinstance(dim2, int)
-        ):
+        if dim1 >= ndimension or dim2 >= ndimension or not isinstance(dim1, int) or not isinstance(dim2, int):
             raise RuntimeError("Invalid dimension")

         # Batch case
@@ -2127,9 +1997,7 @@ def transpose(self, dim1, dim2):
             res = self._transpose_nonbatch()

         else:
-            raise RuntimeError(
-                "Cannot transpose batch dimension with non-batch dimension"
-            )
+            raise RuntimeError("Cannot transpose batch dimension with non-batch dimension")

         return res
@@ -2218,11 +2086,7 @@ def zero_mean_mvn_samples(self, num_samples):
                 dtype=self.dtype,
                 device=self.device,
             )
-            samples = (
-                covar_root.matmul(base_samples)
-                .permute(-1, *range(self.dim() - 1))
-                .contiguous()
-            )
+            samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous()

         return samples
@@ -2256,16 +2120,8 @@ def __add__(self, other):
         elif isinstance(other, Tensor):
             other = lazify(other)
             shape = _mul_broadcast_shape(self.shape, other.shape)
-            new_self = (
-                self
-                if self.shape[:-2] == shape[:-2]
-                else self._expand_batch(shape[:-2])
-            )
-            new_other = (
-                other
-                if other.shape[:-2] == shape[:-2]
-                else other._expand_batch(shape[:-2])
-            )
+            new_self = self if self.shape[:-2] == shape[:-2] else self._expand_batch(shape[:-2])
+            new_other = other if other.shape[:-2] == shape[:-2] else other._expand_batch(shape[:-2])
             return SumLazyTensor(new_self, new_other)
         else:
             return SumLazyTensor(self, other)
@@ -2286,9 +2142,7 @@ def __div__(self, other):
         from .zero_lazy_tensor import ZeroLazyTensor

         if isinstance(other, ZeroLazyTensor):
-            raise RuntimeError(
-                "Attempted to divide by a ZeroLazyTensor (divison by zero)"
-            )
+            raise RuntimeError("Attempted to divide by a ZeroLazyTensor (division by zero)")

         return self.mul(1.0 / other)
@@ -2301,19 +2155,12 @@ def __getitem__(self, index):

         # Process the index
         index = index if isinstance(index, tuple) else (index,)
-        index = tuple(
-            torch.tensor(idx) if isinstance(idx, list) else idx for idx in index
-        )
-        index = tuple(
-            idx.item() if torch.is_tensor(idx) and not len(idx.shape) else idx
-            for idx in index
-        )
+        index = tuple(torch.tensor(idx) if isinstance(idx, list) else idx for idx in index)
+        index = tuple(idx.item() if torch.is_tensor(idx) and not len(idx.shape) else idx for idx in index)

         # Handle the ellipsis
         # Find the index of the ellipsis
-        ellipsis_locs = tuple(
-            index for index, item in enumerate(index) if item is Ellipsis
-        )
+        ellipsis_locs = tuple(index for index, item in enumerate(index) if item is Ellipsis)
         if settings.debug.on():
             if len(ellipsis_locs) > 1:
                 raise RuntimeError(
@@ -2323,11 +2170,7 @@ def __getitem__(self, index):
         if len(ellipsis_locs) == 1:
             ellipsis_loc = ellipsis_locs[0]
             num_to_fill_in = ndimension - (len(index) - 1)
-            index = (
-                index[:ellipsis_loc]
-                + tuple(_noop_index for _ in range(num_to_fill_in))
-                + index[ellipsis_loc + 1 :]
-            )
+            index = index[:ellipsis_loc] + tuple(_noop_index for _ in range(num_to_fill_in)) + index[ellipsis_loc + 1 :]

         # Pad the index with empty indices
         index = index + tuple(_noop_index for _ in range(ndimension - len(index)))
@@ -2336,18 +2179,14 @@ def __getitem__(self, index):
         *batch_indices, row_index, col_index = index

         # Helpers to determine what the final shape will be if we're tensor indexed
-        batch_has_tensor_index = bool(len(batch_indices)) and any(
-            torch.is_tensor(index) for index in batch_indices
-        )
+        batch_has_tensor_index = bool(len(batch_indices)) and any(torch.is_tensor(index) for index in batch_indices)
         row_has_tensor_index = torch.is_tensor(row_index)
         col_has_tensor_index = torch.is_tensor(col_index)

         # These are the cases where the row and/or column indices will be "absorbed" into other indices
         row_col_are_absorbed = any(
             (
-                batch_has_tensor_index
-                and (row_has_tensor_index or col_has_tensor_index),
-                not batch_has_tensor_index
-                and (row_has_tensor_index and col_has_tensor_index),
+                batch_has_tensor_index and (row_has_tensor_index or col_has_tensor_index),
+                not batch_has_tensor_index and (row_has_tensor_index and col_has_tensor_index),
             )
         )

@@ -2366,9 +2205,11 @@ def __getitem__(self, index):
         # Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices
         if row_col_are_absorbed:
             # Convert all indices into tensor indices
-            (*batch_indices, row_index, col_index,) = _convert_indices_to_tensors(
-                self, (*batch_indices, row_index, col_index)
-            )
+            (
+                *batch_indices,
+                row_index,
+                col_index,
+            ) = _convert_indices_to_tensors(self, (*batch_indices, row_index, col_index))
             res = self._get_indices(row_index, col_index, *batch_indices)
         else:
             res = self._getitem(row_index, col_index, *batch_indices)
@@ -2388,9 +2229,7 @@ def __getitem__(self, index):
         if expected_shape != res.shape:
             raise RuntimeError(
                 "{}.__getitem__ failed! Expected a final shape of size {}, got {}. This is a bug with GPyTorch, "
-                "or your custom LazyTensor.".format(
-                    self.__class__.__name__, expected_shape, res.shape
-                )
+                "or your custom LazyTensor.".format(self.__class__.__name__, expected_shape, res.shape)
             )

         # We're done!
@@ -2408,22 +2247,16 @@ def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
             V = evecs
         return U, S, V

-    def _symeig(
-        self, eigenvectors: bool = False
-    ) -> Tuple[Tensor, Optional["LazyTensor"]]:
+    def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTensor"]]:
         """Method that allows implementing special-cased symeig computation. Should not be called directly"""
         from gpytorch.lazy.non_lazy_tensor import NonLazyTensor

         if settings.verbose_linalg.on():
-            settings.verbose_linalg.logger.debug(
-                f"Running symeig on a matrix of size {self.shape}."
-            )
+            settings.verbose_linalg.logger.debug(f"Running symeig on a matrix of size {self.shape}.")

         # potentially perform decomposition in double precision for numerical stability
         dtype = self.dtype
-        evals, evecs = torch.linalg.eigh(
-            self.evaluate().to(dtype=settings._linalg_dtype_symeig.value())
-        )
+        evals, evecs = torch.linalg.eigh(self.evaluate().to(dtype=settings._linalg_dtype_symeig.value()))
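
# A self-contained sketch (plain torch) of the numerical-stability pattern used
# in _symeig above: factor a PSD matrix in float64, clamp the slightly negative
# eigenvalues that finite precision can produce, then cast back to the working
# dtype. The matrix below is an assumed toy example, not GPyTorch data.
import torch

torch.manual_seed(0)
W = torch.randn(50, 10, dtype=torch.float32)
psd = W @ W.T  # rank-deficient PSD, so several eigenvalues sit at (or just below) zero
evals, evecs = torch.linalg.eigh(psd.to(torch.float64))
evals = evals.clamp_min(0.0).to(dtype=psd.dtype)
evecs = evecs.to(dtype=psd.dtype)
assert torch.allclose(evecs @ torch.diag(evals) @ evecs.T, psd, atol=1e-3)
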
        # chop any negative eigenvalues.
        # TODO: warn if evals are significantly negative
        evals = evals.clamp_min(0.0).to(dtype=dtype)
@@ -2470,15 +2303,9 @@ def delazify(obj):
     elif isinstance(obj, LazyTensor):
         return obj.evaluate()
     else:
-        raise TypeError(
-            "object of class {} cannot be made into a Tensor".format(
-                obj.__class__.__name__
-            )
-        )
+        raise TypeError("object of class {} cannot be made into a Tensor".format(obj.__class__.__name__))

-_deprecate_renamed_methods(
-    LazyTensor, inv_quad_log_det="inv_quad_logdet", log_det="logdet"
-)
+_deprecate_renamed_methods(LazyTensor, inv_quad_log_det="inv_quad_logdet", log_det="logdet")

 __all__ = ["LazyTensor", "delazify"]

From e36ca9b8441ee254a401db11c16eb990901d2d8b Mon Sep 17 00:00:00 2001
From: Ax Website Deployment Script
Date: Tue, 14 Sep 2021 18:38:57 -0700
Subject: [PATCH 8/8] pre commit formatting

---
 gpytorch/lazy/lazy_tensor.py | 110 +++++++----------------------------
 1 file changed, 21 insertions(+), 89 deletions(-)

diff --git a/gpytorch/lazy/lazy_tensor.py b/gpytorch/lazy/lazy_tensor.py
index 730e66fce..839658df4 100644
--- a/gpytorch/lazy/lazy_tensor.py
+++ b/gpytorch/lazy/lazy_tensor.py
@@ -19,28 +19,13 @@
 from ..functions._matmul import Matmul
 from ..functions._root_decomposition import RootDecomposition
 from ..functions._sqrt_inv_matmul import SqrtInvMatmul
-from ..utils.broadcasting import (
-    _matmul_broadcast_shape,
-    _mul_broadcast_shape,
-    _to_helper,
-)
+from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper
 from ..utils.cholesky import psd_safe_cholesky
 from ..utils.deprecation import _deprecate_renamed_methods
 from ..utils.errors import CachingError
-from ..utils.getitem import (
-    _compute_getitem_size,
-    _convert_indices_to_tensors,
-    _is_noop_index,
-    _noop_index,
-)
+from ..utils.getitem import _compute_getitem_size, _convert_indices_to_tensors, _is_noop_index, _noop_index
 from ..utils.lanczos import _postprocess_lanczos_root_inv_decomp
-from ..utils.memoize import (
-    _is_in_cache_ignore_all_args,
-    _is_in_cache_ignore_args,
-    add_to_cache,
-    cached,
-    pop_from_cache,
-)
+from ..utils.memoize import _is_in_cache_ignore_all_args, _is_in_cache_ignore_args, add_to_cache, cached, pop_from_cache
 from ..utils.pinverse import stable_pinverse
 from ..utils.pivoted_cholesky import pivoted_cholesky
 from ..utils.warnings import NumericalWarning
@@ -261,11 +246,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
         from . import InterpolatedLazyTensor

         res = InterpolatedLazyTensor(
-            self,
-            row_interp_indices,
-            row_interp_values,
-            col_interp_indices,
-            col_interp_values,
+            self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
         )
         return res._getitem(row_index, col_index, *batch_indices)
@@ -339,11 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):
         res = (
             InterpolatedLazyTensor(
-                base_lazy_tensor,
-                row_interp_indices,
-                row_interp_values,
-                col_interp_indices,
-                col_interp_values,
+                base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
             )
             .evaluate()
             .squeeze(-2)
@@ -543,10 +520,7 @@ def _mul_matrix(self, other):
         else:
             left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other
             right_lazy_tensor = other if left_lazy_tensor is self else self
-            return MulLazyTensor(
-                left_lazy_tensor.root_decomposition(),
-                right_lazy_tensor.root_decomposition(),
-            )
+            return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition(),)

     def _preconditioner(self):
         """
@@ -587,10 +561,7 @@ def _prod_batch(self, dim):
             shape = list(roots.shape)
             shape[dim] = 1
             extra_root = torch.full(
-                shape,
-                dtype=self.dtype,
-                device=self.device,
-                fill_value=(1.0 / math.sqrt(self.size(-2))),
+                shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2))),
             )
             roots = torch.cat([roots, extra_root], dim)
             num_batch += 1
@@ -767,12 +738,7 @@ def add_jitter(self, jitter_val=1e-3):
         return self.add_diag(diag)

     def cat_rows(
-        self,
-        cross_mat,
-        new_mat,
-        generate_roots=True,
-        generate_inv_roots=True,
-        **root_decomp_kwargs,
+        self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs,
     ):
         """
         Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
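
For reference, a usage sketch of the option this series adds (shapes and values here are made up, and it assumes a GPyTorch build with these patches applied). Passing generate_inv_roots=False caches only the root decomposition of the concatenated matrix, which is enough for root-based matmuls while skipping the cost of forming an inverse root:

import torch
from gpytorch.lazy import lazify

torch.manual_seed(0)
L = torch.randn(6, 6)
A = lazify(L @ L.T + torch.eye(6))     # existing n x n PSD lazy tensor
new_row = 0.5 * torch.randn(1, 6)      # cross_mat: one new row (k x n, with k = 1)
new_point = torch.full((1, 1), 10.0)   # new_mat: the new k x k diagonal block
new_lt = A.cat_rows(new_row, new_point, generate_inv_roots=False)
print(new_lt.size())  # torch.Size([7, 7]); root cached, inverse root skipped
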
@@ -815,8 +781,7 @@ def cat_rows(
         if not generate_roots and generate_inv_roots:
             warnings.warn(
-                "root_inv_decomposition is only generated when " "root_decomposition is generated.",
-                UserWarning,
+                "root_inv_decomposition is only generated when " "root_decomposition is generated.", UserWarning,
             )
         B_, B = cross_mat, lazify(cross_mat)
         D = lazify(new_mat)
@@ -835,11 +800,7 @@ def cat_rows(
         # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
         # don't create one
         has_roots = any(
-            _is_in_cache_ignore_args(self, key)
-            for key in (
-                "root_decomposition",
-                "root_inv_decomposition",
-            )
+            _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition",)
         )
         if not generate_roots and not has_roots:
             return new_lazy_tensor
@@ -870,9 +831,7 @@ def cat_rows(
                 # otherwise we use the pseudo-inverse of Z as new inv root
                 new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
             add_to_cache(
-                new_lazy_tensor,
-                "root_inv_decomposition",
-                RootLazyTensor(lazify(new_inv_root)),
+                new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)),
             )

         add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)))
@@ -917,8 +876,7 @@ def add_low_rank(
             new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
         else:
             new_lazy_tensor = SumLazyTensor(
-                *self.lazy_tensors,
-                lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
+                *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
             )

         # return as a nonlazy tensor if small enough to reduce memory overhead
@@ -966,12 +924,7 @@ def add_low_rank(
         updated_root = torch.cat(
             (
                 current_root.evaluate(),
-                torch.zeros(
-                    *current_root.shape[:-1],
-                    1,
-                    device=current_root.device,
-                    dtype=current_root.dtype,
-                ),
+                torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype,),
             ),
             dim=-1,
         )
@@ -1231,13 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
         if left_tensor is None:
             return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
         else:
-            return func.apply(
-                self.representation_tree(),
-                True,
-                left_tensor,
-                right_tensor,
-                *self.representation(),
-            )
+            return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation(),)

     def inv_quad(self, tensor, reduce_inv_quad=True):
         """
@@ -1304,11 +1251,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
                 will_need_cholesky = False
         if will_need_cholesky:
             cholesky = CholLazyTensor(TriangularLazyTensor(self.cholesky()))
-            return cholesky.inv_quad_logdet(
-                inv_quad_rhs=inv_quad_rhs,
-                logdet=logdet,
-                reduce_inv_quad=reduce_inv_quad,
-            )
+            return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad,)

         # Default: use modified batch conjugate gradients to compute these terms
         # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
@@ -1700,8 +1643,7 @@ def root_decomposition(self, method: Optional[str] = None):
                 return CholLazyTensor(res)
             except RuntimeError as e:
                 warnings.warn(
-                    f"Runtime Error when computing Cholesky decomposition: {e}. Using symeig method.",
-                    NumericalWarning,
+                    f"Runtime Error when computing Cholesky decomposition: {e}. Using symeig method.", NumericalWarning,
                 )
                 method = "symeig"
@@ -2056,11 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):
         if settings.ciq_samples.on():
             base_samples = torch.randn(
-                *self.batch_shape,
-                self.size(-1),
-                num_samples,
-                dtype=self.dtype,
-                device=self.device,
+                *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous()
             base_samples = base_samples.unsqueeze(-1)
@@ -2080,11 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
             covar_root = self.root_decomposition().root

             base_samples = torch.randn(
-                *self.batch_shape,
-                covar_root.size(-1),
-                num_samples,
-                dtype=self.dtype,
-                device=self.device,
+                *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous()
@@ -2205,11 +2139,9 @@ def __getitem__(self, index):
         # Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices
         if row_col_are_absorbed:
             # Convert all indices into tensor indices
-            (
-                *batch_indices,
-                row_index,
-                col_index,
-            ) = _convert_indices_to_tensors(self, (*batch_indices, row_index, col_index))
+            (*batch_indices, row_index, col_index,) = _convert_indices_to_tensors(
+                self, (*batch_indices, row_index, col_index)
+            )
             res = self._get_indices(row_index, col_index, *batch_indices)
         else:
             res = self._getitem(row_index, col_index, *batch_indices)