Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inference with RenyiELBO for local latent variable models #3123

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ tutorial: FORCE

lint: FORCE
flake8
black --check *.py pyro examples tests scripts profiler
black --check .
isort --check .
python scripts/update_headers.py --check
mypy --install-types --non-interactive pyro scripts
Expand All @@ -28,7 +28,7 @@ license: FORCE
python scripts/update_headers.py

format: license FORCE
black *.py pyro examples tests scripts profiler
black .
isort .

version: FORCE
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[tool.black]
include = '''
(
pyro/.*\.py
| examples/.*\.py
| tests/.*\.py
| scripts/.*\.py
| profiler/.*\.py
)
'''
6 changes: 3 additions & 3 deletions pyro/infer/renyi_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import get_dependent_plate_dims, is_validation_enabled, torch_sum
from pyro.infer.util import get_nonparticle_plate_dims, is_validation_enabled, torch_sum
from pyro.util import check_if_enumerated, warn_if_nan


Expand Down Expand Up @@ -104,7 +104,7 @@ def loss(self, model, guide, *args, **kwargs):
# grab a vectorized trace from the generator
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
elbo_particle = 0.0
sum_dims = get_dependent_plate_dims(model_trace.nodes.values())
sum_dims = get_nonparticle_plate_dims(model_trace.nodes.values())

# compute elbo
for name, site in model_trace.nodes.items():
Expand Down Expand Up @@ -152,7 +152,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
elbo_particle = 0
surrogate_elbo_particle = 0
sum_dims = get_dependent_plate_dims(model_trace.nodes.values())
sum_dims = get_nonparticle_plate_dims(model_trace.nodes.values())

# compute elbo and surrogate elbo
for name, site in model_trace.nodes.items():
Expand Down
14 changes: 9 additions & 5 deletions pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,21 @@ def get_plate_stacks(trace):
}


def get_dependent_plate_dims(sites):
def get_nonparticle_plate_dims(sites):
"""
Return a list of unique dims for plates that are not common to all sites.
Return a list of unique dims of all plates except vectorized particles
"""
plate_sets = [
site["cond_indep_stack"] for site in sites if site["type"] == "sample"
]
all_plates = set().union(*plate_sets)
common_plates = all_plates.intersection(*plate_sets)
sum_plates = all_plates - common_plates
sum_dims = sorted({f.dim for f in sum_plates if f.dim is not None})
sum_dims = sorted(
{
f.dim
for f in all_plates
if f.dim is not None and f.name != "num_particles_vectorized"
}
)
return sum_dims


Expand Down
69 changes: 69 additions & 0 deletions tests/infer/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,3 +1005,72 @@ def guide(data, weights):
loss = svi.step(data, weights)
if step % 20 == 0:
logger.info("step {} loss = {:0.4g}".format(step, loss))


@pytest.mark.stage("integration", "integration_batch_2")
class OneWayNormalRandomEffects(TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how much time do these tests take?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how much time do these tests take?

In CI :

Time [sec]
test_renyi_nonreparameterized 50.67
test_renyi_reparameterized 36.59
test_renyi_vectorized 36.32

So all are in the top ten slowest for batch 2.

def setUp(self) -> None:
self.n_groups = 3
self.n_experiments = 5
self.data = torch.tensor(
[
[4.1, 3.5, 0.2, -3.3, 3.3],
[2.4, -6.5, -0.7, 4.4, -4.8],
[1.1, -0.6, 1.3, -1.3, -1.1],
]
)
self.group_locs = torch.tensor([[3.0], [-2.0], [0.0]])
self.group_prec = torch.tensor([[0.2], [0.1], [0.3]])
self.obs_prec = torch.tensor(6.0)
obs_prec = self.obs_prec * self.n_experiments
self.post_group_locs = (
self.data.mean(1, keepdim=True) * obs_prec
+ self.group_locs * self.group_prec
) / (obs_prec + self.group_prec)

def test_renyi_reparameterized(self):
self.do_elbo_test(True, 10_000, RenyiELBO(num_particles=2))

def test_renyi_vectorized(self):
self.do_elbo_test(
True,
15_000,
RenyiELBO(num_particles=2, vectorize_particles=True, max_plate_nesting=3),
)

def test_renyi_nonreparameterized(self):
self.do_elbo_test(False, 15000, RenyiELBO(alpha=0.2, num_particles=2))

def do_elbo_test(self, reparameterized, n_steps, loss, debug=False):
def model():
with pyro.plate("groups", self.n_groups, dim=-2):
group_loc = pyro.sample(
"group_loc",
dist.Normal(self.group_locs, torch.pow(self.group_prec, -0.5)),
)
with pyro.plate("data", self.n_experiments, dim=-1):
pyro.sample(
"y",
dist.Normal(group_loc, torch.pow(self.obs_prec, -0.5)),
obs=self.data,
)

def guide():
gloc = pyro.param(
"group_loc_param",
self.post_group_locs + torch.tensor([[0.05], [-0.08], [0.14]]),
)
with pyro.plate("groups", self.n_groups, dim=-2):
Normal = (
dist.Normal if reparameterized else fakes.NonreparameterizedNormal
)
pyro.sample("group_loc", Normal(gloc, torch.pow(self.group_prec, -0.5)))

adam = optim.Adam({"lr": 0.0005, "betas": (0.97, 0.999)})
svi = SVI(model, guide, adam, loss=loss)

for k in range(n_steps):
svi.step()

group_loc_error = param_abs_error("group_loc_param", self.post_group_locs)
assert_equal(0.0, group_loc_error, prec=0.08)