Effect handler that conditions a model on sample sites having the same value #3395

Merged: 4 commits, Sep 20, 2024

1 change: 1 addition & 0 deletions pyproject.toml
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

[tool.ruff]
extend-exclude = ["*.ipynb"]
line-length = 120


2 changes: 1 addition & 1 deletion pyro/ops/stats.py
@@ -511,7 +511,7 @@ def crps_empirical(pred, truth):


def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    """
    r"""
    Computes negative Energy Score ES* (see equation 22 in [1]) between a
    set of multivariate samples ``pred`` and a true data vector ``truth``. Running time
    is quadratic in the number of samples ``n``. In case of univariate samples
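For context (not part of the diff): the energy score referenced in this docstring is, for predictive samples X, X' drawn from P and an observation y,

    \mathrm{ES}^*(P, y) \;=\; \mathbb{E}\,\lVert X - y \rVert \;-\; \tfrac{1}{2}\,\mathbb{E}\,\lVert X - X' \rVert

which reduces to the CRPS for univariate samples. The change itself only switches the docstring to a raw string (``r"""``), presumably so backslash-heavy notation can live in the docstring without escape warnings.
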
33 changes: 30 additions & 3 deletions pyro/poutine/equalize_messenger.py
@@ -38,18 +38,42 @@ class EqualizeMessenger(Messenger):

>>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')

Alternatively, the ``equalize`` messenger can be used to condition a model on primitive statements
having the same value by setting ``keep_dist`` to ``True``. Consider the model below:

>>> def model():
... x = pyro.sample('x', pyro.distributions.Normal(0, 1))
... y = pyro.sample('y', pyro.distributions.Normal(5, 3))
... return x, y

The model can be conditioned on 'x' and 'y' having the same value by

>>> conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)

Note that the conditioned model defined above calculates the correct unnormalized log-probability
density, but in order to correctly sample from it one must use SVI or MCMC techniques.

:param fn: a stochastic function (callable containing Pyro primitive calls)
:param sites: a string or list of strings to match site names (the strings can be regular expressions)
:param type: a string specifying the site type (default is 'sample')
:param bool keep_dist: Whether to keep the distributions of the second and subsequent
matching primitive statements. If kept, this is equivalent to conditioning the model
on all matching primitive statements having the same value, as opposed to having the
second and subsequent matching primitive statements replaced by delta sampling functions.
Defaults to ``False``.
:returns: stochastic function decorated with a :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger`
"""

    def __init__(
        self, sites: Union[str, List[str]], type: Optional[str] = "sample"
        self,
        sites: Union[str, List[str]],
        type: Optional[str] = "sample",
        keep_dist: bool = False,
    ) -> None:
        super().__init__()
        self.sites = [sites] if isinstance(sites, str) else sites
        self.type = type
        self.keep_dist = keep_dist

    def __enter__(self) -> Self:
        self.value = None
@@ -72,6 +96,9 @@ def _process_message(self, msg: Message) -> None:
        if self.value is not None and self._is_matching(msg):  # type: ignore[unreachable]
            msg["value"] = self.value  # type: ignore[unreachable]
            if msg["type"] == "sample":
                msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False)
                msg["infer"] = {"_deterministic": True}
                msg["is_observed"] = True
                if not self.keep_dist:
                    msg["infer"] = {"_deterministic": True}
                    msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(
                        False
                    )
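
As the updated docstring notes, a model equalized with ``keep_dist=True`` computes the correct unnormalized log-density but must be sampled with SVI or MCMC. Below is a minimal sketch of the MCMC route; the model, site names, and sampler settings are illustrative and not part of this PR.

    import pyro
    import pyro.distributions as dist
    from pyro.infer import MCMC, NUTS

    def model():
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        y = pyro.sample("y", dist.Normal(5.0, 3.0))
        return x, y

    # Condition the model on 'x' and 'y' taking the same value; with keep_dist=True
    # the Normal(5, 3) factor for 'y' is kept and scored at the shared value.
    conditioned_model = pyro.poutine.equalize(model, ["x", "y"], keep_dist=True)

    # 'y' is treated as observed at the value drawn for 'x', so 'x' is the only
    # latent site and NUTS samples the shared value from the conditioned density.
    mcmc = MCMC(NUTS(conditioned_model), num_samples=500, warmup_steps=200)
    mcmc.run()
    shared_value_samples = mcmc.get_samples()["x"]

For this pair of Normals the samples should concentrate around the precision-weighted combination of the two means, which is what the new test below checks via SVI.
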
44 changes: 44 additions & 0 deletions tests/poutine/test_poutines.py
@@ -805,6 +805,50 @@ def test_render_model(self):
        pyro.render_model(model)


@pytest.mark.parametrize("keep_dist", [False, True])
@pytest.mark.parametrize(
    "loc_x, scale_x, loc_y, scale_y", [(0.0, 1.0, 5.0, 2.0), (5.0, 2.0, 0.0, 1.0)]
)
def test_condition_by_equalize(loc_x, scale_x, loc_y, scale_y, keep_dist):
    # Create model and equalize it.
    def model():
        x = pyro.sample("x", dist.Normal(loc_x, scale_x))
        y = pyro.sample("y", dist.Normal(loc_y, scale_y))
        return x, y

    equalized_model = pyro.poutine.equalize(model, ["x", "y"], keep_dist=keep_dist)

    # Fit guide to model
    guide = pyro.infer.autoguide.AutoNormal(equalized_model)
    optim = pyro.optim.Adam(dict(lr=0.1))
    svi = pyro.infer.SVI(
        equalized_model,
        guide,
        optim,
        loss=pyro.infer.TraceGraph_ELBO(num_particles=1000, vectorize_particles=True),
    )
    for step_num in range(100):
        svi.step()

    # Get guide distribution parameters
    loc, scale = guide._get_loc_and_scale("x")
    loc = float(loc.detach().numpy())
    scale = float(scale.detach().numpy())

    # Verify against expected distribution parameters
    if keep_dist:
        # Both 'x' and 'y' are sampled and the model is conditioned on 'x' and 'y' having the same value.
        expected_var = 1 / (1 / scale_x**2 + 1 / scale_y**2)
        expected_loc = (loc_x / scale_x**2 + loc_y / scale_y**2) * expected_var
        expected_scale = expected_var**0.5
    else:
        # The random variable 'x' is sampled and its value is assigned to 'y'.
        expected_loc = loc_x
        expected_scale = scale_x
    assert_close(loc, expected_loc, atol=0.05)
    assert_close(scale, expected_scale, atol=0.05)
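
For reference, the expected values asserted above follow from the standard product-of-two-Gaussians identity (a worked check, not part of the PR). With ``keep_dist=True`` the shared value v is scored under both Normal factors, so its unnormalized density is N(v; loc_x, scale_x) * N(v; loc_y, scale_y), which is again Gaussian with

    \sigma^2 = \left(\frac{1}{\mathrm{scale}_x^2} + \frac{1}{\mathrm{scale}_y^2}\right)^{-1},
    \qquad
    \mu = \sigma^2 \left(\frac{\mathrm{loc}_x}{\mathrm{scale}_x^2} + \frac{\mathrm{loc}_y}{\mathrm{scale}_y^2}\right),

matching ``expected_var`` and ``expected_loc`` in the ``keep_dist`` branch. With ``keep_dist=False`` the site 'y' is replaced by a masked Delta at the value drawn for 'x', so the fitted guide should simply recover ``loc_x`` and ``scale_x``.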


@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
@pytest.mark.parametrize("depth", [0, 1, 2])
def test_enumerate_poutine(depth, first_available_dim):