Merge pull request #1752 from sdaulton/cat_row_skip_root_inv
add generate_inv_roots option to cat_rows
sdaulton authored Sep 15, 2021
2 parents 05ebbf2 + e36ca9b commit f06004e
Showing 2 changed files with 52 additions and 34 deletions.
76 changes: 43 additions & 33 deletions gpytorch/lazy/lazy_tensor.py
@@ -245,7 +245,9 @@ def _getitem(self, row_index, col_index, *batch_indices):
# Construct interpolated LazyTensor
from . import InterpolatedLazyTensor

res = InterpolatedLazyTensor(self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values)
res = InterpolatedLazyTensor(
self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
)
return res._getitem(row_index, col_index, *batch_indices)

def _unsqueeze_batch(self, dim):
@@ -318,7 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):

res = (
InterpolatedLazyTensor(
base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values
base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
)
.evaluate()
.squeeze(-2)
@@ -518,7 +520,7 @@ def _mul_matrix(self, other):
else:
left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other
right_lazy_tensor = other if left_lazy_tensor is self else self
return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition())
return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition(),)

def _preconditioner(self):
"""
@@ -559,7 +561,7 @@ def _prod_batch(self, dim):
shape = list(roots.shape)
shape[dim] = 1
extra_root = torch.full(
shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2)))
shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2))),
)
roots = torch.cat([roots, extra_root], dim)
num_batch += 1
@@ -735,7 +737,9 @@ def add_jitter(self, jitter_val=1e-3):
diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device)
return self.add_diag(diag)

def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs):
def cat_rows(
self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs,
):
"""
Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat)
@@ -762,8 +766,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
If :math:`A` is n x n, then this matrix should be n x k.
new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`.
If :math:`B` is n x k, then this matrix should be k x k.
generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it
has not been created yet.
generate_roots (:obj:`bool`): whether to generate the root
decomposition of :math:`A` even if it has not been created yet.
generate_inv_roots (:obj:`bool`): whether to generate the root inv
decomposition of :math:`A` even if it has not been created yet.
Returns:
:obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns.
@@ -773,6 +779,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
from .root_lazy_tensor import RootLazyTensor
from .triangular_lazy_tensor import TriangularLazyTensor

if not generate_roots and generate_inv_roots:
warnings.warn(
"root_inv_decomposition is only generated when " "root_decomposition is generated.", UserWarning,
)
B_, B = cross_mat, lazify(cross_mat)
D = lazify(new_mat)
batch_shape = B.shape[:-2]
@@ -789,13 +799,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs

# if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
# don't create one
does_not_have_roots = any(
_is_in_cache_ignore_args(self, key) for key in ("root_inv_decomposition", "root_inv_decomposition")
has_roots = any(
_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition",)
)
if not generate_roots and not does_not_have_roots:
if not generate_roots and not has_roots:
return new_lazy_tensor

# Get compomnents for new root Z = [E 0; F G]
# Get components for new root Z = [E 0; F G]
E = self.root_decomposition(**root_decomp_kwargs).root # E = L, LL^T = A
m, n = E.shape[-2:]
R = self.root_inv_decomposition().root.evaluate() # RR^T = A^{-1} (this is fast if L is triangular)
@@ -809,20 +819,22 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
new_root[..., :m, :n] = E.evaluate()
new_root[..., m:, : lower_left.shape[-1]] = lower_left
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root

if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
# make sure these are actually upper triangular
if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
raise NotImplementedError
# in this case we know new_root is triangular as well
new_root = TriangularLazyTensor(new_root)
new_inv_root = new_root.inverse().transpose(-1, -2)
else:
# otherwise we use the pseudo-inverse of Z as new inv root
new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
if generate_inv_roots:
if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
# make sure these are actually upper triangular
if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
raise NotImplementedError
# in this case we know new_root is triangular as well
new_root = TriangularLazyTensor(new_root)
new_inv_root = new_root.inverse().transpose(-1, -2)
else:
# otherwise we use the pseudo-inverse of Z as new inv root
new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
add_to_cache(
new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)),
)

add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)))
add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)))

return new_lazy_tensor
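
A short aside on why the block root Z built above reproduces the concatenated matrix (this note is not part of the commit; notation follows the inline comments, and it assumes the inverse root of A is the one paired with L, i.e. R = L^{-T}):

    % E = L with L L^T = A, R = L^{-T} with R R^T = A^{-1},
    % F = B R (lower_left), and G G^T = D - B A^{-1} B^T (schur_root G).
    \[
    Z Z^\top
      = \begin{bmatrix} E & 0 \\ F & G \end{bmatrix}
        \begin{bmatrix} E^\top & F^\top \\ 0 & G^\top \end{bmatrix}
      = \begin{bmatrix} E E^\top & E F^\top \\ F E^\top & F F^\top + G G^\top \end{bmatrix}
      = \begin{bmatrix} A & B^\top \\ B & D \end{bmatrix} = C,
    \]
    % since E F^\top = L L^{-1} B^\top = B^\top and F F^\top = B A^{-1} B^\top.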

@@ -864,7 +876,7 @@ def add_low_rank(
new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
else:
new_lazy_tensor = SumLazyTensor(
*self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
*self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
)

# return as a nonlazy tensor if small enough to reduce memory overhead
@@ -873,10 +885,8 @@ def add_low_rank(

# if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
# don't create one
does_not_have_roots = any(
_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition")
)
if not generate_roots and not does_not_have_roots:
has_roots = any(_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition"))
if not generate_roots and not has_roots:
return new_lazy_tensor

# we are going to compute the following
@@ -914,7 +924,7 @@ def add_low_rank(
updated_root = torch.cat(
(
current_root.evaluate(),
torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype),
torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype,),
),
dim=-1,
)
@@ -1174,7 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
if left_tensor is None:
return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
else:
return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation())
return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation(),)

def inv_quad(self, tensor, reduce_inv_quad=True):
"""
@@ -1241,7 +1251,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
will_need_cholesky = False
if will_need_cholesky:
cholesky = CholLazyTensor(TriangularLazyTensor(self.cholesky()))
return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad)
return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad,)

# Default: use modified batch conjugate gradients to compute these terms
# See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
@@ -1988,7 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):

if settings.ciq_samples.on():
base_samples = torch.randn(
*self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device
*self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device,
)
base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous()
base_samples = base_samples.unsqueeze(-1)
@@ -2008,7 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
covar_root = self.root_decomposition().root

base_samples = torch.randn(
*self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device
*self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device,
)
samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous()

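
The diff above only shows the implementation, so here is a minimal usage sketch of the new option (not part of the commit). The toy matrix construction and variable names are illustrative, and the sketch assumes gpytorch's public lazify helper; the option name, cache key, and the CachingError/get_from_cache imports come from the changes above and below.

    import torch
    from gpytorch.lazy import lazify
    from gpytorch.utils.errors import CachingError
    from gpytorch.utils.memoize import get_from_cache

    # Build a small PSD matrix and split it into the blocks of C = [A B^T; B D].
    n = 5
    X = torch.randn(n + 1, n + 1)
    C = X @ X.t() + torch.eye(n + 1)
    A = lazify(C[:n, :n])   # existing LazyTensor A
    new_rows = C[n:, :n]    # cross block, shaped (1, n) as in the test below
    new_point = C[n:, n:]   # new diagonal block, shaped (1, 1)

    # Default: both the root and the inverse-root decompositions of the
    # concatenated tensor are generated and cached.
    lt = A.cat_rows(new_rows, new_point)
    get_from_cache(lt, "root_inv_decomposition")  # cached

    # With generate_inv_roots=False the inverse root is skipped, so asking the
    # cache for it raises CachingError (mirroring the new test assertion).
    lt_no_inv = A.cat_rows(new_rows, new_point, generate_inv_roots=False)
    try:
        get_from_cache(lt_no_inv, "root_inv_decomposition")
    except CachingError:
        pass  # expected: no inverse root was generated
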
10 changes: 9 additions & 1 deletion gpytorch/test/lazy_tensor_test_case.py
@@ -11,6 +11,8 @@
import gpytorch
from gpytorch.settings import linalg_dtypes
from gpytorch.utils.cholesky import CHOLESKY_METHOD
from gpytorch.utils.errors import CachingError
from gpytorch.utils.memoize import get_from_cache

from .base_test_case import BaseTestCase

@@ -457,13 +459,19 @@ def test_cat_rows(self):
root_rhs = new_lt.root_decomposition().matmul(rhs)
self.assertAllClose(root_rhs, concat_rhs, **self.tolerances["root_decomposition"])

# check that root inv is cached
root_inv = get_from_cache(new_lt, "root_inv_decomposition")
# check that the inverse root decomposition is close
concat_solve = torch.linalg.solve(concatenated_lt, rhs.unsqueeze(-1)).squeeze(-1)
root_inv_solve = new_lt.root_inv_decomposition().matmul(rhs)
root_inv_solve = root_inv.matmul(rhs)
self.assertLess(
(root_inv_solve - concat_solve).norm() / concat_solve.norm(),
self.tolerances["root_inv_decomposition"]["rtol"],
)
# test generate_inv_roots=False
new_lt = lazy_tensor.cat_rows(new_rows, new_point, generate_inv_roots=False)
with self.assertRaises(CachingError):
get_from_cache(new_lt, "root_inv_decomposition")

def test_cholesky(self):
lazy_tensor = self.create_lazy_tensor()
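
The other new code path, asking for an inverse root while skipping root generation, only emits a warning. A small sketch of that behavior (not part of this commit; the toy matrix is illustrative and assumes the same lazify helper as above):

    import warnings

    import torch
    from gpytorch.lazy import lazify

    n = 4
    X = torch.randn(n + 1, n + 1)
    C = X @ X.t() + torch.eye(n + 1)
    A = lazify(C[:n, :n])

    # generate_inv_roots defaults to True; with generate_roots=False the inverse
    # root cannot be built, so cat_rows warns and returns the concatenated
    # tensor without caching any decompositions (A here has no cached roots).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        A.cat_rows(C[n:, :n], C[n:, n:], generate_roots=False)
    assert any(issubclass(w.category, UserWarning) for w in caught)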
