add combine_terms option to exact MLL #1863

Open · wants to merge 9 commits into main
4 changes: 2 additions & 2 deletions gpytorch/distributions/multitask_multivariate_normal.py
@@ -203,12 +203,12 @@ def get_base_samples(self, sample_shape=torch.Size()):
            return base_samples.view(new_shape).transpose(-1, -2).contiguous()
        return base_samples.view(*sample_shape, *self._output_shape)

-    def log_prob(self, value):
+    def log_prob_terms(self, value):
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = value.shape[:-2] + value.shape[:-3:-1]
            value = value.view(new_shape).transpose(-1, -2).contiguous()
-        return super().log_prob(value.view(*value.shape[:-2], -1))
+        return super().log_prob_terms(value.view(*value.shape[:-2], -1))

    @property
    def mean(self):
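Here the multitask distribution simply flattens its (..., n, t) value into a single event of size n * t before deferring to the base class, so the renamed method needs no extra logic. A minimal shape sketch (tensor sizes assumed for illustration, not from this diff):

import torch

value = torch.randn(3, 2)                 # n=3 data points, t=2 tasks, interleaved
flat = value.view(*value.shape[:-2], -1)  # one event of size n * t = 6
assert flat.shape == (6,)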
43 changes: 33 additions & 10 deletions gpytorch/distributions/multivariate_normal.py
@@ -48,7 +48,11 @@ def __init__(self, mean, covariance_matrix, validate_args=False):
            # TODO: Integrate argument validation for LazyTensors into torch.distribution validation logic
            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
        else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+            super().__init__(
+                loc=mean,
+                covariance_matrix=covariance_matrix,
+                validate_args=validate_args,
+            )

    @property
    def _unbroadcasted_scale_tril(self):
@@ -142,10 +146,7 @@ def lazy_covariance_matrix(self):
        else:
            return lazify(super().covariance_matrix)

-    def log_prob(self, value):
-        if settings.fast_computations.log_prob.off():
-            return super().log_prob(value)
-
+    def log_prob_terms(self, value):
        if self._validate_args:
            self._validate_sample(value)

@@ -157,7 +158,10 @@ def log_prob(self, value):
            if len(diff.shape[:-1]) < len(covar.batch_shape):
                diff = diff.expand(covar.shape[:-1])
            else:
-                padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
+                padded_batch_shape = (
+                    *(1 for _ in range(diff.dim() + 1 - covar.dim())),
+                    *covar.batch_shape,
+                )
                covar = covar.repeat(
                    *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                    1,
@@ -167,9 +171,18 @@
        # Get log determinant and first part of quadratic form
        covar = covar.evaluate_kernel()
        inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
+        norm_const = torch.tensor(diff.size(-1) * math.log(2 * math.pi)).to(inv_quad)

-        res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
-        return res
+        split_terms = [inv_quad, logdet, norm_const]
+        split_terms = [-0.5 * term for term in split_terms]
Collaborator comment on lines +176 to +177:

Suggested change:
-        split_terms = [inv_quad, logdet, norm_const]
-        split_terms = [-0.5 * term for term in split_terms]
+        split_terms = [-0.5 * inv_quad, logdet, -0.5 * norm_const]


+        return split_terms
+
+    def log_prob(self, value):
+        if settings.fast_computations.log_prob.off():
+            return super().log_prob(value)
+        split_terms = self.log_prob_terms(value)
+        return sum(split_terms)
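After this change, log_prob is exactly the sum of the three returned terms, -0.5 * inv_quad, -0.5 * logdet, and -0.5 * n * log(2 * pi). A quick sketch of that invariant (mirroring the tests below; a diagonal covariance is assumed for simplicity):

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import DiagLazyTensor

mean = torch.zeros(4)
var = torch.rand(4) + 0.1
mvn = MultivariateNormal(mean, DiagLazyTensor(var))
value = torch.randn(4)

terms = mvn.log_prob_terms(value)  # [quadratic-form term, logdet term, normalization constant]
assert len(terms) == 3
assert torch.allclose(sum(terms), mvn.log_prob(value))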

    def rsample(self, sample_shape=torch.Size(), base_samples=None):
        covar = self.lazy_covariance_matrix
@@ -286,7 +299,10 @@ def __mul__(self, other):
            raise RuntimeError("Can only multiply by scalars")
        if other == 1:
            return self
-        return self.__class__(mean=self.mean * other, covariance_matrix=self.lazy_covariance_matrix * (other ** 2))
+        return self.__class__(
+            mean=self.mean * other,
+            covariance_matrix=self.lazy_covariance_matrix * (other ** 2),
+        )

    def __truediv__(self, other):
        return self.__mul__(1.0 / other)
@@ -341,5 +357,12 @@ def kl_mvn_mvn(p_dist, q_dist):
    trace_plus_inv_quad_form, logdet_q_covar = q_covar.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=True)

    # Compute the KL Divergence.
-    res = 0.5 * sum([logdet_q_covar, logdet_p_covar.mul(-1), trace_plus_inv_quad_form, -float(mean_diffs.size(-1))])
+    res = 0.5 * sum(
+        [
+            logdet_q_covar,
+            logdet_p_covar.mul(-1),
+            trace_plus_inv_quad_form,
+            -float(mean_diffs.size(-1)),
+        ]
+    )
    return res
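For reference, this is the standard closed form for the KL divergence between Gaussians $p = \mathcal{N}(\mu_p, \Sigma_p)$ and $q = \mathcal{N}(\mu_q, \Sigma_q)$ in $d$ dimensions:

$$\mathrm{KL}(p \,\|\, q) = \tfrac{1}{2}\left[\log\det\Sigma_q - \log\det\Sigma_p + \operatorname{tr}\!\left(\Sigma_q^{-1}\Sigma_p\right) + (\mu_q - \mu_p)^\top \Sigma_q^{-1} (\mu_q - \mu_p) - d\right]$$

with trace_plus_inv_quad_form covering the trace and quadratic terms in a single inv_quad_logdet call.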
22 changes: 17 additions & 5 deletions gpytorch/mlls/exact_marginal_log_likelihood.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

+import torch
+
from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood
@@ -17,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):

    :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
    :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
Collaborator comment:
Should probably also describe what happens if there are "other terms" (i.e. that they are added to the return elements)

    Example:
        >>> # model is a gpytorch.models.ExactGP
@@ -28,10 +31,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
        >>> loss.backward()
    """

-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
        if not isinstance(likelihood, _GaussianLikelihoodBase):
            raise RuntimeError("Likelihood must be Gaussian for exact inference")
-        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model, combine_terms)

    def _add_other_terms(self, res, params):
        # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
@@ -59,9 +62,18 @@ def forward(self, function_dist, target, *params):

        # Get the log prob of the marginal distribution
        output = self.likelihood(function_dist, *params)
-        res = output.log_prob(target)
-        res = self._add_other_terms(res, params)
+        split_terms = output.log_prob_terms(target)

        # Scale by the amount of data we have
        num_data = function_dist.event_shape.numel()
-        return res.div_(num_data)
+
+        if self.combine_terms:
+            term_sum = sum(split_terms)
+            term_sum = self._add_other_terms(term_sum, params)
+            return term_sum.div(num_data)
+        else:
+            norm_const = split_terms[-1]
+            other_terms = torch.zeros_like(norm_const)
+            other_terms = self._add_other_terms(other_terms, params)
+            split_terms.append(other_terms)
+            return [term.div(num_data) for term in split_terms]
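A short usage sketch for the new flag (train_x, train_y, model, and likelihood assumed to be set up as in the class docstring example; the unpacked names are illustrative only). With combine_terms=False the call returns four terms, each already divided by num_data, whose sum recovers the combined MLL:

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)
output = model(train_x)
quad_term, logdet_term, norm_const, other_terms = mll(output, train_y)
loss = -(quad_term + logdet_term + norm_const + other_terms)
loss.backward()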
42 changes: 23 additions & 19 deletions gpytorch/mlls/leave_one_out_pseudo_likelihood.py
@@ -29,6 +29,7 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):

    :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
    :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately

    Example:
        >>> # model is a gpytorch.models.ExactGP
@@ -40,10 +41,24 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
        >>> loss.backward()
    """

-    def __init__(self, likelihood, model):
-        super().__init__(likelihood=likelihood, model=model)
-        self.likelihood = likelihood
-        self.model = model
+    def log_prob_terms(self, function_dist, target, *params):
+        output = self.likelihood(function_dist, *params)
+        m, L = output.mean, output.lazy_covariance_matrix.cholesky(upper=False)
+        m = m.reshape(*target.shape)
+        identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
+        sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
+        mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
+
+        # Scale by the amount of data we have and then add on the scaled constant
+        num_data = target.size(-1)
+        data_fit = ((target - mu).pow(2.0) / sigma2).sum(-1)
+        approx_logdet = sigma2.log().sum(-1)
+        norm_const = torch.tensor(num_data * math.log(2 * math.pi)).to(approx_logdet)
+        other_term = self._add_other_terms(torch.zeros_like(approx_logdet), params)
+        split_terms = [data_fit, approx_logdet, norm_const, other_term]
+        split_terms = [-0.5 / num_data * term for term in split_terms]
+
+        return split_terms
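The new log_prob_terms assembles the standard leave-one-out predictive identities (computed with Cholesky solves rather than an explicit inverse):

$$\mu_i = y_i - \frac{\left[K^{-1}(\mathbf y - \mathbf m)\right]_i}{\left[K^{-1}\right]_{ii}}, \qquad \sigma_i^2 = \frac{1}{\left[K^{-1}\right]_{ii}},$$

so the pseudo-likelihood is the sum over $i$ of $\log p(y_i \mid \mathbf y_{-i}) = -\tfrac{1}{2}\log\sigma_i^2 - \tfrac{(y_i - \mu_i)^2}{2\sigma_i^2} - \tfrac{1}{2}\log 2\pi$, split here into data-fit, log-variance, and constant terms (plus the extra terms from _add_other_terms).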

    def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
        r"""
@@ -54,18 +69,7 @@ def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) ->
        :param torch.Tensor target: :math:`\mathbf y` The target values
        :param dict kwargs: Additional arguments to pass to the likelihood's :attr:`forward` function.
        """
-        output = self.likelihood(function_dist, *params)
-        m, L = output.mean, output.lazy_covariance_matrix.cholesky(upper=False)
-        m = m.reshape(*target.shape)
-        identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
-        sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
-        mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
-        term1 = -0.5 * sigma2.log()
-        term2 = -0.5 * (target - mu).pow(2.0) / sigma2
-        res = (term1 + term2).sum(dim=-1)
-
-        res = self._add_other_terms(res, params)
-
-        # Scale by the amount of data we have and then add on the scaled constant
-        num_data = target.size(-1)
-        return res.div_(num_data) - 0.5 * math.log(2 * math.pi)
+        split_terms = self.log_prob_terms(function_dist, target, *params)
+        if self.combine_terms:
+            return sum(split_terms)
+        return split_terms
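Those identities can also be cross-checked naively, which is what test_check_bordered_system does below in spirit; a self-contained sketch with an arbitrary SPD covariance and zero prior mean (nothing here comes from the diff):

import torch

n = 5
K = torch.randn(n, n, dtype=torch.float64)
K = K @ K.T + n * torch.eye(n, dtype=torch.float64)  # symmetric positive definite, well conditioned
y = torch.randn(n, dtype=torch.float64)

Kinv = torch.inverse(K)
sigma2 = 1.0 / Kinv.diagonal()  # LOO predictive variances: 1 / diag(inv(K))
mu = y - (Kinv @ y) * sigma2    # LOO predictive means (zero prior mean)

for i in range(n):
    keep = [j for j in range(n) if j != i]
    solve = torch.linalg.solve(K[keep][:, keep], K[keep, i])
    assert torch.isclose(mu[i], solve @ y[keep])                    # k_i' K_{-i}^{-1} y_{-i}
    assert torch.isclose(sigma2[i], K[i, i] - solve @ K[keep, i])   # K_ii - k_i' K_{-i}^{-1} k_i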
3 changes: 2 additions & 1 deletion gpytorch/mlls/marginal_log_likelihood.py
@@ -25,7 +25,7 @@ class MarginalLogLikelihood(Module):
    these functions must be negated for optimization).
    """

-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
        super(MarginalLogLikelihood, self).__init__()
        if not isinstance(model, GP):
            raise RuntimeError(
@@ -35,6 +35,7 @@ def __init__(self, likelihood, model):
            )
        self.likelihood = likelihood
        self.model = model
+        self.combine_terms = combine_terms

    def forward(self, output, target, **kwargs):
        r"""
16 changes: 14 additions & 2 deletions test/distributions/test_multivariate_normal.py
@@ -219,20 +219,32 @@ def test_log_prob(self, cuda=False):
        var = torch.randn(4, device=device, dtype=dtype).abs_()
        values = torch.randn(4, device=device, dtype=dtype)

-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
        actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
        self.assertLess((res - actual).div(res).abs().item(), 1e-2)

+        res2 = mvn.log_prob_terms(values)
+        assert len(res2) == 3
Collaborator comment:
Suggested change:
-        assert len(res2) == 3
+        self.assertEqual(len(res2), 3)

Collaborator comment:
also in other places in the tests below

+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().item(), 1e-2)

        mean = torch.randn(3, 4, device=device, dtype=dtype)
        var = torch.randn(3, 4, device=device, dtype=dtype).abs_()
        values = torch.randn(3, 4, device=device, dtype=dtype)

-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
        actual = TMultivariateNormal(
            mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1)
        ).log_prob(values)
        self.assertLess((res - actual).div(res).abs().norm(), 1e-2)

+        res2 = mvn.log_prob_terms(values)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().norm(), 1e-2)

    def test_log_prob_cuda(self):
        if torch.cuda.is_available():
            with least_used_cuda_device():
57 changes: 57 additions & 0 deletions test/mlls/test_exact_marginal_log_likelihood.py
@@ -0,0 +1,57 @@
import unittest

import torch

import gpytorch

from .test_leave_one_out_pseudo_likelihood import ExactGPModel


class TestExactMarginalLogLikelihood(unittest.TestCase):
    def get_data(self, shapes, combine_terms, dtype=None, device=None):
        train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
        train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
        model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
        exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood(
            likelihood=likelihood, model=model, combine_terms=combine_terms
        )
        return train_x, train_y, exact_mll

    def test_smoke(self):
        """Make sure the exact_mll works without batching."""
        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True)
        output = exact_mll.model(train_x)
        loss = -exact_mll(output, train_y)
        loss.backward()
        self.assertTrue(train_x.grad is not None)

        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False)
        output = exact_mll.model(train_x)
        mll_out = exact_mll(output, train_y)
        loss = -1 * sum(mll_out)
        loss.backward()
        assert len(mll_out) == 4
        self.assertTrue(train_x.grad is not None)

    def test_smoke_batch(self):
        """Make sure the exact_mll works with batching."""
        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
        output = exact_mll.model(train_x)
        loss = -exact_mll(output, train_y)
        assert loss.shape == (3, 3, 3)
        loss.sum().backward()
        self.assertTrue(train_x.grad is not None)

        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
        output = exact_mll.model(train_x)
        mll_out = exact_mll(output, train_y)
        loss = -1 * sum(mll_out)
        assert len(mll_out) == 4
        assert loss.shape == (3, 3, 3)
        loss.sum().backward()
        self.assertTrue(train_x.grad is not None)


if __name__ == "__main__":
    unittest.main()
29 changes: 24 additions & 5 deletions test/mlls/test_leave_one_out_pseudo_likelihood.py
@@ -21,36 +21,55 @@ def forward(self, x):


class TestLeaveOneOutPseudoLikelihood(unittest.TestCase):
-    def get_data(self, shapes, dtype=None, device=None):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
        train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
        train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
        model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
-        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood=likelihood, model=model)
+        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
+            likelihood=likelihood, model=model, combine_terms=combine_terms
+        )
        return train_x, train_y, loocv

    def test_smoke(self):
        """Make sure the loocv works without batching."""
-        train_x, train_y, loocv = self.get_data([5, 2])
+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=True)
        output = loocv.model(train_x)
        loss = -loocv(output, train_y)
        loss.backward()
        self.assertTrue(train_x.grad is not None)

+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
    def test_smoke_batch(self):
        """Make sure the loocv works with batching."""
-        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2])
+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
        output = loocv.model(train_x)
        loss = -loocv(output, train_y)
        assert loss.shape == (3, 3, 3)
        loss.sum().backward()
        self.assertTrue(train_x.grad is not None)

+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
    def test_check_bordered_system(self):
        """Make sure that the bordered system solves match the naive solution."""
        n = 5
        # Compute the pseudo-likelihood via the bordered systems in O(n^3)
-        train_x, train_y, loocv = self.get_data([n, 2], dtype=torch.float64)
+        train_x, train_y, loocv = self.get_data([n, 2], combine_terms=True, dtype=torch.float64)
        output = loocv.model(train_x)
        loocv_1 = loocv(output, train_y)