Fix #3255 (draft) #3265
base: dev

Changes from all commits
```diff
@@ -23,7 +23,7 @@ def _compute_log_r(model_trace, guide_trace):
     for name, model_site in model_trace.nodes.items():
         if model_site["type"] == "sample":
             log_r_term = model_site["log_prob"]
-            if not model_site["is_observed"]:
+            if not model_site["is_observed"] and name in guide_trace.nodes:
                 log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
             log_r.add((stacks[name], log_r_term.detach()))
     return log_r
```

Review thread on the changed line:

**Reviewer:** I may be forgetting something, but I thought …

**Author:** I agree with you, and I cannot think of useful cases of this. The point I make in the issue is that this triggers a warning when …

**Reviewer:** Ah, thanks for explaining, I think I better understand now. I'd like to hear what other folks think (@martinjankowiak @eb8680 @fehiepsi). One argument for erroring more often is that there is a lot of code in Pyro that tacitly assumes all sites are either observed or guided. I'm not sure what that code is, since we've only tacitly made that assumption, but it's worth thinking about: reparametrizers, … One argument for allowing "partial" guides is that it's just more general. But if we decide to support "partial" guides throughout Pyro, I think we'll need to adopt importance sampling semantics, so we'd need to replace Pyro's basic …

**Author:** Right. For context, this happened to me on sites that are implicitly created by Pyro in the model (and that are therefore not in the guide), and that subsequently caused a failure because they were in the case …
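For readers unfamiliar with the failure mode discussed in this thread, here is a minimal sketch (my own illustration, not code from this PR) of the scenario: `obs_mask` makes Pyro create an auxiliary `x_unobserved` sample site in the model trace, and a guide that does not include that site produces a model/guide mismatch. The parameter names here are made up for the example.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        # obs_mask makes Pyro add an auxiliary site "x_unobserved"
        # to the model trace for the masked-out entries.
        pyro.sample("x", dist.Normal(z, 1.0), obs=data, obs_mask=mask)

def guide(data, mask):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    # "x_unobserved" is deliberately not sampled here, so the model
    # trace contains an unobserved site that the guide trace lacks.
    pyro.sample("z", dist.Normal(loc, scale))

data = torch.tensor([-0.5, 2.0])
mask = torch.tensor([True, False])
svi = SVI(model, guide, Adam({"lr": 0.1}), loss=Trace_ELBO())
# Warns that 'x_unobserved' appears in the model but not the guide;
# with non-reparameterized sites this could fail inside the ELBO
# code paths patched in this PR.
svi.step(data, mask)
```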
```diff
@@ -108,7 +108,7 @@ def _differentiable_loss_particle(self, model_trace, guide_trace):
         if model_site["type"] == "sample":
             if model_site["is_observed"]:
                 elbo_particle = elbo_particle + model_site["log_prob_sum"]
-            else:
+            elif name in guide_trace.nodes:
                 guide_site = guide_trace.nodes[name]
                 if is_validation_enabled():
                     check_fully_reparametrized(guide_site)
```

**Reviewer:** ditto: this should never happen
```diff
@@ -218,6 +218,8 @@ def _compute_elbo(model_trace, guide_trace):
     # we include only downstream costs to reduce variance
     # optionally include baselines to further reduce variance
     for node, downstream_cost in downstream_costs.items():
+        if node not in guide_trace.nodes:
+            continue
         guide_site = guide_trace.nodes[node]
         downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
         score_function = guide_site["score_parts"].score_function
```

**Reviewer** (on lines +221 to +222): ditto: this should never happen
```diff
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import logging
 from collections import defaultdict
 
 import numpy as np
 import pytest
```
```diff
@@ -214,6 +215,102 @@ def guide(subsample):
     assert_equal(actual_grads, expected_grads, prec=precision)
 
 
+# Not including the unobserved site in the guide triggers a warning
+# that can make the test fail if we do not deactivate UserWarning.
+@pytest.mark.filterwarnings("ignore::UserWarning")
+@pytest.mark.parametrize(
+    "with_x_unobserved",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "mask",
+    [[True, True], [True, False], [False, True]],
+)
+@pytest.mark.parametrize(
+    "reparameterized,has_rsample",
+    [(True, None), (True, False), (True, True), (False, None)],
+    ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
+)
+@pytest.mark.parametrize(
+    "Elbo,local_samples",
+    [
+        (Trace_ELBO, False),
+        (DiffTrace_ELBO, False),
+        (TraceGraph_ELBO, False),
+        (TraceMeanField_ELBO, False),
+        (TraceEnum_ELBO, False),
+        (TraceEnum_ELBO, True),
+    ],
+)
+def test_mask_gradient(
+    Elbo,
+    reparameterized,
+    has_rsample,
+    local_samples,
+    mask,
+    with_x_unobserved,
+):
+    pyro.clear_param_store()
+    data = torch.tensor([-0.5, 2.0])
+    precision = 0.08
+    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+
+    def model(data, mask):
+        z = pyro.sample("z", Normal(0, 1))
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", Normal(z, 1), obs=data, obs_mask=mask)
+
+    def guide(data, mask):
+        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
+        loc = pyro.param("loc", lambda: torch.tensor([1.0]))
+        z_dist = Normal(loc, scale)
+        if has_rsample is not None:
+            z_dist.has_rsample_(has_rsample)
+        z = pyro.sample("z", z_dist)
+        if with_x_unobserved:
+            with pyro.plate("data", len(data)):
+                with pyro.poutine.mask(mask=~mask):
+                    pyro.sample("x_unobserved", Normal(z, 1))
+
+    num_particles = 50000
+    accumulation = 1
+    if local_samples:
+        # One has to limit the number of samples in this
+        # test because the memory footprint is large.
+        guide = config_enumerate(guide, num_samples=5000)
+        accumulation = num_particles // 5000
+        num_particles = 1
+
+    optim = Adam({"lr": 0.1})
+    elbo = Elbo(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+        num_particles=num_particles,
+        vectorize_particles=True,
+        strict_enumeration_warning=False,
+    )
+    actual_grads = defaultdict(lambda: np.zeros(1))
+    for _ in range(accumulation):
+        inference = SVI(model, guide, optim, loss=elbo)
+        with xfail_if_not_implemented():
+            inference.loss_and_grads(model, guide, data=data, mask=torch.tensor(mask))
+        params = dict(pyro.get_param_store().named_parameters())
+        actual_grads = {
+            name: param.grad.detach().cpu().numpy() / accumulation
+            for name, param in params.items()
+        }
+
+    # grad(loc) = (n+1) * loc - (x1 + ... + xn)
+    # grad(scale) = (n+1) * scale - 1 / scale
+    expected_grads = {
+        "loc": sum(mask) + 1.0 - data[mask].sum(0, keepdim=True).numpy(),
+        "scale": sum(mask) + 1 - np.ones(1),
+    }
+    for name in sorted(params):
+        logger.info("expected {} = {}".format(name, expected_grads[name]))
+        logger.info("actual {} = {}".format(name, actual_grads[name]))
+    assert_equal(actual_grads, expected_grads, prec=precision)
+
+
 @pytest.mark.parametrize(
     "reparameterized", [True, False], ids=["reparam", "nonreparam"]
 )
```

**Reviewer** (on the `@pytest.mark.filterwarnings` line): ditto: this should never happen
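As a cross-check on the `grad(loc)` and `grad(scale)` comments in the test above, here is a short derivation (mine, not part of the PR) of the expected gradients, writing $n = \texttt{sum(mask)}$ for the number of observed points. With $q(z) = \mathcal{N}(\mathrm{loc}, \mathrm{scale}^2)$, $p(z) = \mathcal{N}(0, 1)$, and $p(x_i \mid z) = \mathcal{N}(z, 1)$, the negative ELBO is, up to constants,

```latex
\mathcal{L} = \tfrac{1}{2}\bigl(\mathrm{loc}^2 + \mathrm{scale}^2\bigr)
  - \log \mathrm{scale}
  + \sum_{i=1}^{n} \tfrac{1}{2}\bigl((x_i - \mathrm{loc})^2 + \mathrm{scale}^2\bigr),
\quad\text{using}\quad
\mathbb{E}_q\!\left[(x_i - z)^2\right] = (x_i - \mathrm{loc})^2 + \mathrm{scale}^2,
```

so that

```latex
\frac{\partial \mathcal{L}}{\partial\,\mathrm{loc}}
  = (n+1)\,\mathrm{loc} - \sum_{i} x_i,
\qquad
\frac{\partial \mathcal{L}}{\partial\,\mathrm{scale}}
  = (n+1)\,\mathrm{scale} - \frac{1}{\mathrm{scale}}.
```

At the test's initialization `loc = scale = 1.0`, these reduce to `sum(mask) + 1 - data[mask].sum()` and `sum(mask)`, exactly the values in `expected_grads`.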
Review thread on the `score_function` line:

**Reviewer:** Can you point out exactly the scenario that is fixed by this one-line change? IIRC, `score_function` would always be multiplied by another tensor that is masked, so the mask here would be redundant.

**Author:** Here it's again for consistency. When a variable is partially observed, some parameters are created for the missing observations, and some dummy parameters are created for the non-missing ones. When `has_rsample` is true, the dummy parameters are not updated: they retain their initial values because they do not contribute to the gradient. When `has_rsample` is false, the gradient "leaks" through this line and the dummy parameters are updated during learning (but I found that inference on the non-dummy parameters was correct in the cases I checked). As above, this line does not really fix any bug; it just tries to make the behavior consistent between `has_rsample = True` and `has_rsample = False`.

**Reviewer:** Thanks for explaining, I think I better understand now.
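To illustrate the author's point about dummy parameters, here is a small sketch (my own, with made-up site and parameter names, not code from this PR): under `poutine.mask`, the imputation site's log-probability is zeroed for entries that are actually observed, so those "dummy" entries contribute nothing to the pathwise (reparameterized) gradient, while the score-function path could still reach them before this change.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

mask = torch.tensor([True, False])  # entry 0 observed, entry 1 missing

def guide():
    # One imputation parameter per data point; entry 0 is a "dummy"
    # because that observation is not actually missing.
    loc = pyro.param("imp_loc", torch.zeros(2))
    with pyro.plate("data", 2), poutine.mask(mask=~mask):
        pyro.sample("x_unobserved", dist.Normal(loc, 1.0))

tr = poutine.trace(guide).get_trace()
tr.compute_log_prob()
# The masked (observed) entry's log_prob is exactly zero, so it cannot
# contribute a pathwise gradient to imp_loc[0]:
print(tr.nodes["x_unobserved"]["log_prob"])  # first entry is 0
```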