Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable AutoNormal,AutoMultivariateNormal in contrib.epidemiology #2538

Merged
merged 1 commit into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import pyro.poutine as poutine
from pyro.distributions.transforms import HaarTransform
from pyro.infer import MCMC, NUTS, SVI, JitTrace_ELBO, SMCFilter, Trace_ELBO, infer_discrete
from pyro.infer.autoguide import AutoLowRankMultivariateNormal, init_to_generated, init_to_value
from pyro.infer.autoguide import (AutoLowRankMultivariateNormal, AutoMultivariateNormal, AutoNormal, init_to_generated,
init_to_value)
from pyro.infer.mcmc import ArrowheadMassMatrix
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
Expand Down Expand Up @@ -82,6 +83,22 @@ def transition(self, params, state, t): ...
effect = samples2["my_result"].mean() - samples1["my_result"].mean()
print("average effect = {:0.3g}".format(effect))

An example workflow is to use cheaper approximate inference while finding
good model structure and priors, then move to more accurate but more
expensive inference once the model is plausible.

1. Start with ``.fit_svi(guide_rank=1, num_steps=2000)`` for cheap
inference while you search for a good model.
2. Additionally infer long-range correlations by moving to a low-rank
multivariate normal guide via ``.fit_svi(guide_rank=None,
num_steps=5000)``.
3. Optionally additionally infer non-Gaussian posterior by moving to the
more expensive (but still approximate via moment matching)
``.fit_mcmc(num_quant_bins=1, num_samples=10000, num_chains=2)``.
4. Optionally improve fit around small counts by moving the the more
expensive enumeration-based algorithm ``.fit_mcmc(num_quant_bins=4,
num_samples=10000, num_chains=2)`` (GPU recommended).

:ivar dict samples: Dictionary of posterior samples.
:param list compartments: A list of strings of compartment names.
:param int duration: The number of discrete time steps in this model.
Expand Down Expand Up @@ -298,24 +315,31 @@ def generate(self, fixed={}):

def fit_svi(self, *,
num_samples=100,
num_steps=5000,
num_steps=2000,
num_particles=32,
learning_rate=0.1,
learning_rate_decay=0.01,
betas=(0.8, 0.99),
haar=True,
init_scale=0.1,
guide_rank=None,
init_scale=0.01,
guide_rank=0,
jit=False,
log_every=200,
**options):
"""
Runs stochastic variational inference to generate posterior samples.

This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
attribute on completion.

This approximate inference method is useful for quickly iterating on
probabilistic models.

:param int num_samples: Number of posterior samples to draw from the
trained guide. Defaults to 100.
:param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
:param int num_particles: Number of :class:`~pyro.infer.svi.SVI` particles per step.
:param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
particles per step.
:param int learning_rate: Learning rate for the
:class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
:param int learning_rate_decay: Learning rate for the
Expand All @@ -324,8 +348,13 @@ def fit_svi(self, *,
:param tuple betas: Momentum parameters for the
:class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
:param bool haar: Whether to use a Haar wavelet reparameterizer.
:param int guide_rank: Rank of the
:param int guide_rank: Rank of the auto normal guide. If zero (default)
use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
positive integer or None, use an
:class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
If the string "full", use an
:class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
latter two require more ``num_steps`` to fit.
:param float init_scale: Initial scale of the
:class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
:param bool jit: Whether to use a jit compiled ELBO.
Expand Down Expand Up @@ -361,8 +390,16 @@ def fit_svi(self, *,
model = self._relaxed_model
if haar:
model = haar.reparam(model)
guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
init_scale=init_scale, rank=guide_rank)
if guide_rank == 0:
guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
elif guide_rank == "full":
guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
init_scale=init_scale)
elif guide_rank is None or isinstance(guide_rank, int):
guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
init_scale=init_scale, rank=guide_rank)
else:
raise ValueError("Invalid guide_rank: {}".format(guide_rank))
Elbo = JitTrace_ELBO if jit else Trace_ELBO
elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
num_particles=num_particles, vectorize_particles=True,
Expand Down Expand Up @@ -410,6 +447,10 @@ def fit_mcmc(self, **options):
:class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
attribute on completion.

This uses an asymptotically exact enumeration-based model when
``num_quant_bins > 1``, and a cheaper moment-matched approximate model
when ``num_quant_bins == 1``.

:param \*\*options: Options passed to
:class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
pulled out and have special meaning.
Expand Down Expand Up @@ -639,7 +680,7 @@ def _heuristic(self, haar, **options):
haar.user_to_aux(init_values)
logger.info("Heuristic init: {}".format(", ".join(
"{}={:0.3g}".format(k, v.item())
for k, v in init_values.items()
for k, v in sorted(init_values.items())
if v.numel() == 1)))
return init_to_value(values=init_values)

Expand Down
3 changes: 3 additions & 0 deletions tests/contrib/epidemiology/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
@pytest.mark.parametrize("algo,options", [
("svi", {}),
("svi", {"haar": False}),
("svi", {"guide_rank": None}),
("svi", {"guide_rank": 2}),
("svi", {"guide_rank": "full"}),
("mcmc", {}),
("mcmc", {"haar": True}),
("mcmc", {"haar_full_mass": 2}),
Expand Down