add combine_terms option to exact MLL #1863
base: main
Changes from all commits: d8bd497, 328ebd0, b85418e, d6ca1bf, 8069b7e, 71ba3bf, 2160a7f, e13a318, 877f271
Changes to the `ExactMarginalLogLikelihood` MLL:

```diff
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+import torch
+
 from ..distributions import MultivariateNormal
 from ..likelihoods import _GaussianLikelihoodBase
 from .marginal_log_likelihood import MarginalLogLikelihood
@@ -17,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
 
     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
 
     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -28,10 +31,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
         >>> loss.backward()
     """
 
-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("Likelihood must be Gaussian for exact inference")
-        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model, combine_terms)
 
     def _add_other_terms(self, res, params):
         # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
@@ -59,9 +62,18 @@ def forward(self, function_dist, target, *params):
 
         # Get the log prob of the marginal distribution
         output = self.likelihood(function_dist, *params)
-        res = output.log_prob(target)
-        res = self._add_other_terms(res, params)
+        split_terms = output.log_prob_terms(target)
 
         # Scale by the amount of data we have
         num_data = function_dist.event_shape.numel()
-        return res.div_(num_data)
+
+        if self.combine_terms:
+            term_sum = sum(split_terms)
+            term_sum = self._add_other_terms(term_sum, params)
+            return term_sum.div(num_data)
+        else:
+            norm_const = split_terms[-1]
+            other_terms = torch.zeros_like(norm_const)
+            other_terms = self._add_other_terms(other_terms, params)
+            split_terms.append(other_terms)
+            return [term.div(num_data) for term in split_terms]
```

Review comment (on the `combine_terms` docstring): Should probably also describe what happens if there are "other terms" (i.e. that they are added to the return elements)
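The docstring above says `combine_terms=False` makes the MLL call return each term separately. As a minimal usage sketch (the `SimpleGP` model, kernel choices, and training data below are illustrative assumptions, not part of this PR), summing the returned list recovers the usual scalar loss:

```python
import torch
import gpytorch

# Hypothetical minimal ExactGP model, just to exercise the new flag.
class SimpleGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.rand(20, 2)
train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SimpleGP(train_x, train_y, likelihood)
model.train()
likelihood.train()

# combine_terms=False: the MLL call returns a list of per-term values
# (three log-prob terms, the last being the normalizing constant, plus
# an "other terms" entry) instead of a single scalar.
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)
output = model(train_x)
terms = mll(output, train_y)
loss = -sum(terms)  # matches the combine_terms=True loss
loss.backward()
```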
Changes to the `MultivariateNormal` log-prob test:

```diff
@@ -219,20 +219,32 @@ def test_log_prob(self, cuda=False):
         var = torch.randn(4, device=device, dtype=dtype).abs_()
         values = torch.randn(4, device=device, dtype=dtype)
 
-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
         actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
         self.assertLess((res - actual).div(res).abs().item(), 1e-2)
 
+        res2 = mvn.log_prob_terms(values)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().item(), 1e-2)
+
         mean = torch.randn(3, 4, device=device, dtype=dtype)
         var = torch.randn(3, 4, device=device, dtype=dtype).abs_()
         values = torch.randn(3, 4, device=device, dtype=dtype)
 
-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
         actual = TMultivariateNormal(
             mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1)
         ).log_prob(values)
         self.assertLess((res - actual).div(res).abs().norm(), 1e-2)
 
+        res2 = mvn.log_prob_terms(values)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().norm(), 1e-2)
+
     def test_log_prob_cuda(self):
         if torch.cuda.is_available():
             with least_used_cuda_device():
```

Review comment (on `assert len(res2) == 3`): Suggested change; also in other places in the tests below.
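For context on the `len(res2) == 3` assertions: a multivariate Gaussian log density decomposes as log p(y) = -1/2 (y - mu)^T K^{-1} (y - mu) - 1/2 log|K| - (n/2) log(2*pi), which gives three natural terms whose sum is `log_prob`. A standalone sanity sketch along the lines of the test above (assuming this PR's `log_prob_terms` method):

```python
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import DiagLazyTensor

mean = torch.randn(4)
var = torch.randn(4).abs_()
values = torch.randn(4)

mvn = MultivariateNormal(mean, DiagLazyTensor(var))

# log_prob_terms is the method added by this PR; the MLL diff above treats
# the last entry as the normalizing constant (`norm_const`).
terms = mvn.log_prob_terms(values)
assert len(terms) == 3

# The split is only useful if the terms sum back to the ordinary log prob.
assert torch.allclose(sum(terms), mvn.log_prob(values), atol=1e-5)
```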
@@ -0,0 +1,57 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
import gpytorch | ||
|
||
from .test_leave_one_out_pseudo_likelihood import ExactGPModel | ||
|
||
|
||
class TestExactMarginalLogLikelihood(unittest.TestCase): | ||
def get_data(self, shapes, combine_terms, dtype=None, device=None): | ||
train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True) | ||
train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1]) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device) | ||
model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device) | ||
exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood( | ||
likelihood=likelihood, model=model, combine_terms=combine_terms | ||
) | ||
return train_x, train_y, exact_mll | ||
|
||
def test_smoke(self): | ||
"""Make sure the exact_mll works without batching.""" | ||
train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True) | ||
output = exact_mll.model(train_x) | ||
loss = -exact_mll(output, train_y) | ||
loss.backward() | ||
self.assertTrue(train_x.grad is not None) | ||
|
||
train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False) | ||
output = exact_mll.model(train_x) | ||
mll_out = exact_mll(output, train_y) | ||
loss = -1 * sum(mll_out) | ||
loss.backward() | ||
assert len(mll_out) == 4 | ||
self.assertTrue(train_x.grad is not None) | ||
|
||
def test_smoke_batch(self): | ||
"""Make sure the exact_mll works without batching.""" | ||
train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=True) | ||
output = exact_mll.model(train_x) | ||
loss = -exact_mll(output, train_y) | ||
assert loss.shape == (3, 3, 3) | ||
loss.sum().backward() | ||
self.assertTrue(train_x.grad is not None) | ||
|
||
train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=False) | ||
output = exact_mll.model(train_x) | ||
mll_out = exact_mll(output, train_y) | ||
loss = -1 * sum(mll_out) | ||
assert len(mll_out) == 4 | ||
assert loss.shape == (3, 3, 3) | ||
loss.sum().backward() | ||
self.assertTrue(train_x.grad is not None) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
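Note on the `len(mll_out) == 4` assertions: with `combine_terms=False`, the `forward` method above returns the three `log_prob_terms` entries plus one appended tensor carrying the "other terms" (SGPR / learned inducing points, heteroskedastic likelihood models), so the list always has length four, even when that last entry is zero.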