diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py
index 6da841a7f..e0341398a 100644
--- a/gpytorch/distributions/multitask_multivariate_normal.py
+++ b/gpytorch/distributions/multitask_multivariate_normal.py
@@ -203,12 +203,12 @@ def get_base_samples(self, sample_shape=torch.Size()):
             return base_samples.view(new_shape).transpose(-1, -2).contiguous()
         return base_samples.view(*sample_shape, *self._output_shape)
 
-    def log_prob(self, value):
+    def log_prob_terms(self, value):
         if not self._interleaved:
             # flip shape of last two dimensions
             new_shape = value.shape[:-2] + value.shape[:-3:-1]
             value = value.view(new_shape).transpose(-1, -2).contiguous()
-        return super().log_prob(value.view(*value.shape[:-2], -1))
+        return super().log_prob_terms(value.view(*value.shape[:-2], -1))
 
     @property
     def mean(self):
diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py
index 06b8ad6f5..d75d01057 100644
--- a/gpytorch/distributions/multivariate_normal.py
+++ b/gpytorch/distributions/multivariate_normal.py
@@ -48,7 +48,11 @@ def __init__(self, mean, covariance_matrix, validate_args=False):
             # TODO: Integrate argument validation for LazyTensors into torch.distribution validation logic
             super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
         else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+            super().__init__(
+                loc=mean,
+                covariance_matrix=covariance_matrix,
+                validate_args=validate_args,
+            )
 
     @property
     def _unbroadcasted_scale_tril(self):
@@ -142,10 +146,7 @@ def lazy_covariance_matrix(self):
         else:
             return lazify(super().covariance_matrix)
 
-    def log_prob(self, value):
-        if settings.fast_computations.log_prob.off():
-            return super().log_prob(value)
-
+    def log_prob_terms(self, value):
         if self._validate_args:
             self._validate_sample(value)
 
@@ -157,7 +158,10 @@ def log_prob(self, value):
         if len(diff.shape[:-1]) < len(covar.batch_shape):
             diff = diff.expand(covar.shape[:-1])
         else:
-            padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
+            padded_batch_shape = (
+                *(1 for _ in range(diff.dim() + 1 - covar.dim())),
+                *covar.batch_shape,
+            )
             covar = covar.repeat(
                 *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                 1,
@@ -167,9 +171,18 @@ def log_prob(self, value):
         # Get log determininant and first part of quadratic form
         covar = covar.evaluate_kernel()
         inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
+        norm_const = torch.tensor(diff.size(-1) * math.log(2 * math.pi)).to(inv_quad)
 
-        res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
-        return res
+        split_terms = [inv_quad, logdet, norm_const]
+        split_terms = [-0.5 * term for term in split_terms]
+
+        return split_terms
+
+    def log_prob(self, value):
+        if settings.fast_computations.log_prob.off():
+            return super().log_prob(value)
+        split_terms = self.log_prob_terms(value)
+        return sum(split_terms)
 
     def rsample(self, sample_shape=torch.Size(), base_samples=None):
         covar = self.lazy_covariance_matrix
@@ -286,7 +299,10 @@ def __mul__(self, other):
             raise RuntimeError("Can only multiply by scalars")
         if other == 1:
             return self
-        return self.__class__(mean=self.mean * other, covariance_matrix=self.lazy_covariance_matrix * (other ** 2))
+        return self.__class__(
+            mean=self.mean * other,
+            covariance_matrix=self.lazy_covariance_matrix * (other ** 2),
+        )
 
     def __truediv__(self, other):
         return self.__mul__(1.0 / other)
@@ -341,5 +357,12 @@ def kl_mvn_mvn(p_dist, q_dist):
     trace_plus_inv_quad_form, logdet_q_covar = q_covar.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=True)
 
     # Compute the KL Divergence.
-    res = 0.5 * sum([logdet_q_covar, logdet_p_covar.mul(-1), trace_plus_inv_quad_form, -float(mean_diffs.size(-1))])
+    res = 0.5 * sum(
+        [
+            logdet_q_covar,
+            logdet_p_covar.mul(-1),
+            trace_plus_inv_quad_form,
+            -float(mean_diffs.size(-1)),
+        ]
+    )
     return res
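
The `log_prob_terms` split introduced above is the standard three-term decomposition of the Gaussian log-density. A plain-PyTorch sketch (illustration only, not part of the patch) that checks the identity with dense tensors in place of gpytorch LazyTensors:

# Sketch only: verifies the decomposition that log_prob_terms returns,
#     log p(y) = -0.5 * [(y - mu)' K^{-1} (y - mu) + log det K + n * log(2*pi)]
import math

import torch

n = 4
mean = torch.randn(n, dtype=torch.float64)
root = torch.randn(n, n, dtype=torch.float64)
covar = root @ root.T + n * torch.eye(n, dtype=torch.float64)  # well-conditioned SPD matrix
y = torch.randn(n, dtype=torch.float64)

diff = y - mean
inv_quad = diff @ torch.linalg.solve(covar, diff)  # (y - mu)' K^{-1} (y - mu)
logdet = torch.logdet(covar)
norm_const = torch.tensor(n * math.log(2 * math.pi), dtype=torch.float64)
split_terms = [-0.5 * term for term in (inv_quad, logdet, norm_const)]

expected = torch.distributions.MultivariateNormal(mean, covar).log_prob(y)
assert torch.allclose(sum(split_terms), expected)
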
diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py
index f33408868..043011c71 100644
--- a/gpytorch/mlls/exact_marginal_log_likelihood.py
+++ b/gpytorch/mlls/exact_marginal_log_likelihood.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+import torch
+
 from ..distributions import MultivariateNormal
 from ..likelihoods import _GaussianLikelihoodBase
 from .marginal_log_likelihood import MarginalLogLikelihood
@@ -17,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
 
     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param bool combine_terms: (optional) If `False`, the MLL call returns each MLL term separately
 
     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -28,10 +31,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
         >>> loss.backward()
     """
 
-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("Likelihood must be Gaussian for exact inference")
-        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model, combine_terms)
 
     def _add_other_terms(self, res, params):
         # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
@@ -59,9 +62,18 @@ def forward(self, function_dist, target, *params):
 
         # Get the log prob of the marginal distribution
        output = self.likelihood(function_dist, *params)
-        res = output.log_prob(target)
-        res = self._add_other_terms(res, params)
+        split_terms = output.log_prob_terms(target)
 
         # Scale by the amount of data we have
         num_data = function_dist.event_shape.numel()
-        return res.div_(num_data)
+
+        if self.combine_terms:
+            term_sum = sum(split_terms)
+            term_sum = self._add_other_terms(term_sum, params)
+            return term_sum.div(num_data)
+        else:
+            norm_const = split_terms[-1]
+            other_terms = torch.zeros_like(norm_const)
+            other_terms = self._add_other_terms(other_terms, params)
+            split_terms.append(other_terms)
+            return [term.div(num_data) for term in split_terms]
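
With the new flag, training code can inspect or reweight the individual objective terms. A minimal usage sketch (not part of the patch; `MyGP` is a hypothetical stand-in for any `gpytorch.models.ExactGP` subclass):

import gpytorch
import torch

train_x = torch.rand(20, 2)
train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])


class MyGP(gpytorch.models.ExactGP):  # hypothetical example model
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = MyGP(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)

model.train()
likelihood.train()
output = model(train_x)
# Four terms, each already divided by num_data: the (negated) quadratic data-fit
# term, the log-determinant term, the 2*pi normalization constant, and any
# added prior / inducing-point terms.
data_fit, logdet, norm_const, other = mll(output, train_y)
loss = -(data_fit + logdet + norm_const + other)  # same value as combine_terms=True
loss.backward()
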
diff --git a/gpytorch/mlls/leave_one_out_pseudo_likelihood.py b/gpytorch/mlls/leave_one_out_pseudo_likelihood.py
index 6252515f6..5fd111b4d 100644
--- a/gpytorch/mlls/leave_one_out_pseudo_likelihood.py
+++ b/gpytorch/mlls/leave_one_out_pseudo_likelihood.py
@@ -29,6 +29,7 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
 
     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param bool combine_terms: (optional) If `False`, the MLL call returns each MLL term separately
 
     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -40,10 +41,25 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
         >>> loss.backward()
     """
 
-    def __init__(self, likelihood, model):
-        super().__init__(likelihood=likelihood, model=model)
-        self.likelihood = likelihood
-        self.model = model
+    def log_prob_terms(self, function_dist, target, *params):
+        output = self.likelihood(function_dist, *params)
+        m, L = output.mean, output.lazy_covariance_matrix.cholesky(upper=False)
+        m = m.reshape(*target.shape)
+        identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
+        sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
+        mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
+
+        # Scale by the amount of data we have and then add on the scaled constant
+        num_data = target.size(-1)
+        data_fit = ((target - mu).pow(2.0) / sigma2).sum(-1)
+        approx_logdet = sigma2.log().sum(-1)
+        norm_const = torch.tensor(num_data * math.log(2 * math.pi)).to(approx_logdet)
+        other_term = self._add_other_terms(torch.zeros_like(approx_logdet), params)
+        split_terms = [data_fit, approx_logdet, norm_const]
+        split_terms = [-0.5 / num_data * term for term in split_terms]
+        split_terms.append(other_term.div(num_data))  # added directly, not negated, as in ExactMarginalLogLikelihood
+
+        return split_terms
 
     def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
         r"""
@@ -54,18 +70,7 @@ def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
         :param torch.Tensor target: :math:`\mathbf y` The target values
         :param dict kwargs: Additional arguments to pass to the likelihood's :attr:`forward` function.
         """
-        output = self.likelihood(function_dist, *params)
-        m, L = output.mean, output.lazy_covariance_matrix.cholesky(upper=False)
-        m = m.reshape(*target.shape)
-        identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
-        sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
-        mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
-        term1 = -0.5 * sigma2.log()
-        term2 = -0.5 * (target - mu).pow(2.0) / sigma2
-        res = (term1 + term2).sum(dim=-1)
-
-        res = self._add_other_terms(res, params)
-
-        # Scale by the amount of data we have and then add on the scaled constant
-        num_data = target.size(-1)
-        return res.div_(num_data) - 0.5 * math.log(2 * math.pi)
+        split_terms = self.log_prob_terms(function_dist, target, *params)
+        if self.combine_terms:
+            return sum(split_terms)
+        return split_terms
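
The bordered-system computation above relies on the standard leave-one-out identities for a joint Gaussian. A dense-tensor sketch (not part of the patch) that checks them against naive conditioning:

# Sketch only: the LOO predictive moments used above come from one solve with
# the full covariance K:
#     sigma_i^2 = 1 / [K^{-1}]_{ii}
#     mu_i      = y_i - [K^{-1}(y - m)]_i * sigma_i^2
import torch

n = 6
root = torch.randn(n, n, dtype=torch.float64)
K = root @ root.T + n * torch.eye(n, dtype=torch.float64)
m = torch.zeros(n, dtype=torch.float64)
y = torch.randn(n, dtype=torch.float64)

K_inv = torch.linalg.inv(K)
sigma2 = 1.0 / K_inv.diagonal()
mu = y - (K_inv @ (y - m)) * sigma2

# Naive check for one held-out point i: condition the joint Gaussian on the rest.
i = 0
rest = torch.arange(1, n)
K_rr = K[rest][:, rest]
k_ri = K[rest, i]
cond_mean = m[i] + k_ri @ torch.linalg.solve(K_rr, y[rest] - m[rest])
cond_var = K[i, i] - k_ri @ torch.linalg.solve(K_rr, k_ri)
assert torch.allclose(mu[i], cond_mean) and torch.allclose(sigma2[i], cond_var)
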
""" - def __init__(self, likelihood, model): + def __init__(self, likelihood, model, combine_terms=True): super(MarginalLogLikelihood, self).__init__() if not isinstance(model, GP): raise RuntimeError( @@ -35,6 +35,7 @@ def __init__(self, likelihood, model): ) self.likelihood = likelihood self.model = model + self.combine_terms = combine_terms def forward(self, output, target, **kwargs): r""" diff --git a/test/distributions/test_multivariate_normal.py b/test/distributions/test_multivariate_normal.py index 0b342b16f..e3cb82f54 100644 --- a/test/distributions/test_multivariate_normal.py +++ b/test/distributions/test_multivariate_normal.py @@ -219,20 +219,32 @@ def test_log_prob(self, cuda=False): var = torch.randn(4, device=device, dtype=dtype).abs_() values = torch.randn(4, device=device, dtype=dtype) - res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values) + mvn = MultivariateNormal(mean, DiagLazyTensor(var)) + res = mvn.log_prob(values) actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values) self.assertLess((res - actual).div(res).abs().item(), 1e-2) + res2 = mvn.log_prob_terms(values) + assert len(res2) == 3 + res2 = sum(res2) + self.assertLess((res2 - actual).div(res).abs().item(), 1e-2) + mean = torch.randn(3, 4, device=device, dtype=dtype) var = torch.randn(3, 4, device=device, dtype=dtype).abs_() values = torch.randn(3, 4, device=device, dtype=dtype) - res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values) + mvn = MultivariateNormal(mean, DiagLazyTensor(var)) + res = mvn.log_prob(values) actual = TMultivariateNormal( mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1) ).log_prob(values) self.assertLess((res - actual).div(res).abs().norm(), 1e-2) + res2 = mvn.log_prob_terms(values) + assert len(res2) == 3 + res2 = sum(res2) + self.assertLess((res2 - actual).div(res).abs().norm(), 1e-2) + def test_log_prob_cuda(self): if torch.cuda.is_available(): with least_used_cuda_device(): diff --git a/test/mlls/test_exact_marginal_log_likelihood.py b/test/mlls/test_exact_marginal_log_likelihood.py new file mode 100644 index 000000000..c277dfcb0 --- /dev/null +++ b/test/mlls/test_exact_marginal_log_likelihood.py @@ -0,0 +1,57 @@ +import unittest + +import torch + +import gpytorch + +from .test_leave_one_out_pseudo_likelihood import ExactGPModel + + +class TestExactMarginalLogLikelihood(unittest.TestCase): + def get_data(self, shapes, combine_terms, dtype=None, device=None): + train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True) + train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1]) + likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device) + model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device) + exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood( + likelihood=likelihood, model=model, combine_terms=combine_terms + ) + return train_x, train_y, exact_mll + + def test_smoke(self): + """Make sure the exact_mll works without batching.""" + train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True) + output = exact_mll.model(train_x) + loss = -exact_mll(output, train_y) + loss.backward() + self.assertTrue(train_x.grad is not None) + + train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False) + output = exact_mll.model(train_x) + mll_out = exact_mll(output, train_y) + loss = -1 * sum(mll_out) + loss.backward() + assert len(mll_out) == 4 + self.assertTrue(train_x.grad is not None) 
diff --git a/test/mlls/test_exact_marginal_log_likelihood.py b/test/mlls/test_exact_marginal_log_likelihood.py
new file mode 100644
index 000000000..c277dfcb0
--- /dev/null
+++ b/test/mlls/test_exact_marginal_log_likelihood.py
@@ -0,0 +1,57 @@
+import unittest
+
+import torch
+
+import gpytorch
+
+from .test_leave_one_out_pseudo_likelihood import ExactGPModel
+
+
+class TestExactMarginalLogLikelihood(unittest.TestCase):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
+        train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
+        train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
+        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
+        model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
+        exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood(
+            likelihood=likelihood, model=model, combine_terms=combine_terms
+        )
+        return train_x, train_y, exact_mll
+
+    def test_smoke(self):
+        """Make sure the exact_mll works without batching."""
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        loss.backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
+    def test_smoke_batch(self):
+        """Make sure the exact_mll works with batching."""
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/mlls/test_leave_one_out_pseudo_likelihood.py b/test/mlls/test_leave_one_out_pseudo_likelihood.py
index 69292e56b..5beac070d 100644
--- a/test/mlls/test_leave_one_out_pseudo_likelihood.py
+++ b/test/mlls/test_leave_one_out_pseudo_likelihood.py
@@ -21,36 +21,55 @@ def forward(self, x):
 
 
 class TestLeaveOneOutPseudoLikelihood(unittest.TestCase):
-    def get_data(self, shapes, dtype=None, device=None):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
         train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
         train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
         likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
         model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
-        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood=likelihood, model=model)
+        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
+            likelihood=likelihood, model=model, combine_terms=combine_terms
+        )
         return train_x, train_y, loocv
 
     def test_smoke(self):
         """Make sure the loocv works without batching."""
-        train_x, train_y, loocv = self.get_data([5, 2])
+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=True)
         output = loocv.model(train_x)
         loss = -loocv(output, train_y)
         loss.backward()
         self.assertTrue(train_x.grad is not None)
 
+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
     def test_smoke_batch(self):
         """Make sure the loocv works without batching."""
-        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2])
+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
         output = loocv.model(train_x)
         loss = -loocv(output, train_y)
         assert loss.shape == (3, 3, 3)
         loss.sum().backward()
         self.assertTrue(train_x.grad is not None)
 
+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
     def test_check_bordered_system(self):
         """Make sure that the bordered system solves match the naive solution."""
         n = 5
 
         # Compute the pseudo-likelihood via the bordered systems in O(n^3)
-        train_x, train_y, loocv = self.get_data([n, 2], dtype=torch.float64)
+        train_x, train_y, loocv = self.get_data([n, 2], combine_terms=True, dtype=torch.float64)
         output = loocv.model(train_x)
         loocv_1 = loocv(output, train_y)