Misc fixes to contrib.epidemiology #2527

Merged (4 commits, Jun 16, 2020); changes from 3 commits shown below.
8 changes: 6 additions & 2 deletions examples/contrib/epidemiology/sir.py
@@ -79,6 +79,7 @@ def generate_data(args):


def infer(args, model):
parallel = args.num_chains > 1
energies = []

def hook_fn(kernel, *unused):
@@ -91,16 +92,18 @@ def hook_fn(kernel, *unused):
heuristic_ess_threshold=args.ess_threshold,
warmup_steps=args.warmup_steps,
num_samples=args.num_samples,
num_chains=args.num_chains,
mp_context="spawn" if parallel else None,
Collaborator: Unfortunately I've found this to be somewhat buggy (a fair number of crashes).

Member Author (@fritzo, Jun 15, 2020): Do you use "forkserver" instead?

Collaborator: No, I've used "spawn", but I've tried to avoid it...

max_tree_depth=args.max_tree_depth,
arrowhead_mass=args.arrowhead_mass,
num_quant_bins=args.num_bins,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
jit_compile=args.jit,
hook_fn=hook_fn)
hook_fn=None if parallel else hook_fn)

mcmc.summary()
if args.plot:
if args.plot and energies:
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 3))
plt.plot(energies)
@@ -293,6 +296,7 @@ def main(args):
parser.add_argument("-np", "--num-particles", default=1024, type=int)
parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float)
parser.add_argument("-w", "--warmup-steps", type=int)
parser.add_argument("-c", "--num-chains", default=1, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-a", "--arrowhead-mass", action="store_true")
parser.add_argument("-r", "--rng-seed", default=0, type=int)
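The sir.py changes thread the new --num-chains flag through to MCMC: with more than one chain the example switches to the "spawn" multiprocessing context (see the reviewer caveat above about its flakiness) and drops the energy-collecting hook, presumably because a closure over a local list neither pickles under "spawn" nor reports back from worker processes. A minimal sketch of the same pattern with a toy model; the model, sample counts, and no-op hook below are placeholders, not the example's actual code:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)


def hook_fn(kernel, *unused):
    pass  # e.g. collect per-step diagnostics; only usable single-process


def infer(data, num_chains=2):
    parallel = num_chains > 1
    mcmc = MCMC(NUTS(model),
                num_samples=100,
                warmup_steps=100,
                num_chains=num_chains,
                # "spawn" sidesteps fork-after-threads issues with torch,
                # though reviewers report it can still crash occasionally.
                mp_context="spawn" if parallel else None,
                # Hooks are only wired up in the single-process case.
                hook_fn=None if parallel else hook_fn)
    mcmc.run(data)
    return mcmc.get_samples()


if __name__ == "__main__":  # guard required by the "spawn" start method
    samples = infer(torch.randn(20))
```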
45 changes: 25 additions & 20 deletions pyro/contrib/epidemiology/compartmental.py
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import logging
import operator
import re
@@ -289,7 +290,7 @@ def generate(self, fixed={}):
for name, site in trace.nodes.items()
if site["type"] == "sample")

self._concat_series(samples)
self._concat_series(samples, trace)
return samples

@set_approx_log_prob_tol(0.1)
@@ -365,20 +366,8 @@ def fit(self, **options):
heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
for k in list(options)
if k.startswith("heuristic_")}

def heuristic():
with poutine.block():
init_values = self.heuristic(**heuristic_options)
assert isinstance(init_values, dict)
assert "auxiliary" in init_values, \
".heuristic() did not define auxiliary value"
if haar:
haar.user_to_aux(init_values)
logger.info("Heuristic init: {}".format(", ".join(
"{}={:0.3g}".format(k, v.item())
for k, v in init_values.items()
if v.numel() == 1)))
return init_to_value(values=init_values)
init_strategy = init_to_generated(
generate=functools.partial(self._heuristic, haar, **heuristic_options))

# Configure a kernel.
logger.info("Running inference...")
@@ -387,7 +376,7 @@ def heuristic():
model = haar.reparam(model)
kernel = NUTS(model,
full_mass=full_mass,
init_strategy=init_to_generated(generate=heuristic),
init_strategy=init_strategy,
max_plate_nesting=self.max_plate_nesting,
jit_compile=options.pop("jit_compile", False),
jit_options=options.pop("jit_options", None),
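In fit(), the nested heuristic() closure is replaced by init_to_generated over functools.partial(self._heuristic, haar, ...). A plausible reason, not stated in the diff, is picklability: a partial over a bound method or top-level callable can cross process boundaries under "spawn", while the old closure could not. A standalone sketch of the pattern; the import path, generator, and shapes are assumptions for illustration:

```python
import functools

import torch
from pyro.infer.autoguide.initialization import init_to_generated, init_to_value


def make_heuristic_init(haar=None):
    # Stand-in for CompartmentalModel.heuristic(): some possibly expensive,
    # possibly stochastic procedure proposing starting values.
    init_values = {"auxiliary": torch.zeros(2, 10)}  # hypothetical site and shape
    if haar is not None:
        haar.user_to_aux(init_values)  # map user coordinates to Haar coordinates
    return init_to_value(values=init_values)


# init_to_generated defers to generate() whenever initialization is needed,
# and a functools.partial over a top-level callable pickles cleanly.
init_strategy = init_to_generated(
    generate=functools.partial(make_heuristic_init, haar=None))
```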
@@ -464,7 +453,7 @@ def predict(self, forecast=0):
for name, site in trace.nodes.items()
if site["type"] == "sample")

self._concat_series(samples, forecast, vectorized=True)
self._concat_series(samples, trace, forecast)
return samples

@torch.no_grad()
@@ -516,12 +505,27 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):

# Internal helpers ########################################

def _concat_series(self, samples, forecast=0, vectorized=False):
def _heuristic(self, haar, **options):
with poutine.block():
init_values = self.heuristic(**options)
assert isinstance(init_values, dict)
assert "auxiliary" in init_values, \
".heuristic() did not define auxiliary value"
if haar:
haar.user_to_aux(init_values)
logger.info("Heuristic init: {}".format(", ".join(
"{}={:0.3g}".format(k, v.item())
for k, v in init_values.items()
if v.numel() == 1)))
return init_to_value(values=init_values)

def _concat_series(self, samples, trace, forecast=0):
"""
Concatenate sequential time series into tensors, in-place.

:param dict samples: A dictionary of samples.
"""
time_dim = -2 if self.is_regional else -1
for name in set(self.compartments).union(self.series):
pattern = name + "_[0-9]+"
series = []
@@ -532,8 +536,9 @@ def _concat_series(self, samples, forecast=0, vectorized=False):
continue
assert len(series) == self.duration + forecast
series = torch.broadcast_tensors(*map(torch.as_tensor, series))
if vectorized and name != "obs": # TODO Generalize.
samples[name] = torch.cat(series, dim=1)
dim = time_dim - trace.nodes[name + "_0"]["fn"].event_dim
if series[0].dim() >= -dim:
samples[name] = torch.cat(series, dim=dim)
else:
samples[name] = torch.stack(series)

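The reworked _concat_series receives the trace so it can read each site's event_dim and place the time axis at time_dim - event_dim, concatenating when per-step tensors already carry that axis and stacking otherwise; this replaces the old vectorized flag and its name != "obs" special case. A self-contained sketch of that dispatch; the helper name and the shapes are illustrative, not Pyro API:

```python
import torch


def concat_series(series, event_dim, is_regional=False):
    # The time axis sits immediately left of the site's event dimensions.
    time_dim = -2 if is_regional else -1
    dim = time_dim - event_dim
    series = torch.broadcast_tensors(*map(torch.as_tensor, series))
    if series[0].dim() >= -dim:
        # Each step already has a (singleton) axis at the time position:
        # concatenate along it, e.g. vectorized posterior draws.
        return torch.cat(series, dim=dim)
    # Otherwise stack steps along a new leading time axis, e.g. a single trace.
    return torch.stack(series)


# Vectorized draws of a scalar site, one column per step: (100, 1) x 30 -> (100, 30).
steps = [torch.randn(100, 1) for _ in range(30)]
assert concat_series(steps, event_dim=0).shape == (100, 30)

# Scalar values from a single generated trace: () x 30 -> (30,).
steps = [torch.randn(()) for _ in range(30)]
assert concat_series(steps, event_dim=0).shape == (30,)
```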
9 changes: 5 additions & 4 deletions pyro/contrib/epidemiology/distributions.py
@@ -11,6 +11,7 @@
from pyro.distributions.util import is_validation_enabled

_RELAX = False
_RELAX_MIN_VARIANCE = 0.1


def _all(x):
@@ -98,7 +99,7 @@ def _validate_overdispersion(overdispersion):
raise ValueError("Expected overdispersion < 2")


def _relaxed_binomial(total_count, probs, *, min_variance=0.25):
def _relaxed_binomial(total_count, probs):
"""
Returns a moment-matched :class:`~pyro.distributions.Normal` approximating
a :class:`~pyro.distributions.Binomial` but allowing arbitrary real
@@ -109,11 +110,11 @@ def _relaxed_binomial(total_count, probs, *, min_variance=0.25):
mean = probs * total_count
variance = total_count * probs * (1 - probs)

scale = variance.clamp(min=min_variance).sqrt()
scale = variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt()
return dist.Normal(mean, scale)


def _relaxed_beta_binomial(concentration1, concentration0, total_count, *, min_variance=0.25):
def _relaxed_beta_binomial(concentration1, concentration0, total_count):
"""
Returns a moment-matched :class:`~pyro.distributions.Normal` approximating
a :class:`~pyro.distributions.BetaBinomial` but allowing arbitrary real
@@ -128,7 +129,7 @@ def _relaxed_beta_binomial(concentration1, concentration0, total_count, *, min_v
mean = beta_mean * total_count
variance = beta_variance * total_count * (c + total_count)

scale = variance.clamp(min=min_variance).sqrt()
scale = variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt()
return dist.Normal(mean, scale)


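Both relaxed distributions now take their variance floor from the module-level _RELAX_MIN_VARIANCE = 0.1 rather than a per-call min_variance=0.25 keyword; the moment matching itself is unchanged. A standalone illustration of what the clamp does, using torch.distributions directly rather than the Pyro helpers:

```python
import torch
from torch.distributions import Normal

_RELAX_MIN_VARIANCE = 0.1  # module-level floor replacing min_variance=0.25


def relaxed_binomial(total_count, probs):
    # Normal matched to the Binomial's mean and variance; flooring the variance
    # keeps the approximation well-conditioned when probs is near 0 or 1.
    mean = probs * total_count
    variance = total_count * probs * (1 - probs)
    return Normal(mean, variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt())


d = relaxed_binomial(torch.tensor(1000.), torch.tensor(0.9999))
print(float(d.mean), float(d.variance))  # variance clamped up to 0.1 (true value ~0.09999)
```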
6 changes: 5 additions & 1 deletion pyro/infer/mcmc/api.py
@@ -408,8 +408,12 @@ def model(data):
# If transforms is not explicitly provided, infer automatically using
# model args, kwargs.
if self.transforms is None:
# Try to initialize kernel.transforms using kernel.setup().
if getattr(self.kernel, "transforms", None) is None:
warmup_steps = 0
self.kernel.setup(warmup_steps, *args, **kwargs)
# Use `kernel.transforms` when available
if hasattr(self.kernel, 'transforms') and self.kernel.transforms is not None:
if getattr(self.kernel, "transforms", None) is not None:
Comment on lines +411 to +416

Member Author: This is required because manually calling initialize_model() below would ignore init_strategy and attempt to sample from ImproperUniform. Instead we let kernel.setup() call initialize_model() with proper arguments.

self.transforms = self.kernel.transforms
# Else, get transforms from model (e.g. in multiprocessing).
elif self.kernel.model:
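As the author's comment explains, MCMC now lets kernel.setup() run initialize_model() (honoring the kernel's init_strategy) before falling back to deriving transforms from the model. A hypothetical helper mirroring that ordering, not the actual MCMC internals:

```python
def resolve_transforms(kernel, *args, **kwargs):
    # 1. If the kernel has no transforms yet, let it initialize itself;
    #    kernel.setup() calls initialize_model() with the kernel's own
    #    init_strategy instead of naively sampling from improper priors
    #    such as ImproperUniform.
    if getattr(kernel, "transforms", None) is None:
        kernel.setup(0, *args, **kwargs)  # warmup_steps=0: initialize only
    # 2. Reuse whatever setup() populated.
    if getattr(kernel, "transforms", None) is not None:
        return kernel.transforms
    # 3. Otherwise the caller falls back to deriving transforms from kernel.model.
    return None
```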
10 changes: 5 additions & 5 deletions tests/contrib/epidemiology/test_distributions.py
@@ -9,7 +9,7 @@

import pyro.distributions as dist
from pyro.contrib.epidemiology import beta_binomial_dist, binomial_dist, infection_dist
from pyro.contrib.epidemiology.distributions import set_relaxed_distributions
from pyro.contrib.epidemiology.distributions import _RELAX_MIN_VARIANCE, set_relaxed_distributions
from tests.common import assert_close


@@ -199,7 +199,7 @@ def test_relaxed_binomial():
d2 = binomial_dist(total_count, probs)
assert isinstance(d2, dist.Normal)
assert_close(d2.mean, d1.mean)
assert_close(d2.variance, d1.variance.clamp(min=0.25))
assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE))


@pytest.mark.parametrize("overdispersion", [0.05, 0.1, 0.2, 0.5, 1.0])
@@ -214,7 +214,7 @@ def test_relaxed_overdispersed_binomial(overdispersion):
d2 = binomial_dist(total_count, probs, overdispersion=overdispersion)
assert isinstance(d2, dist.Normal)
assert_close(d2.mean, d1.mean)
assert_close(d2.variance, d1.variance.clamp(min=0.25))
assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE))


def test_relaxed_beta_binomial():
@@ -229,7 +229,7 @@ def test_relaxed_beta_binomial():
d2 = beta_binomial_dist(concentration1, concentration0, total_count)
assert isinstance(d2, dist.Normal)
assert_close(d2.mean, d1.mean)
assert_close(d2.variance, d1.variance.clamp(min=0.25))
assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE))


@pytest.mark.parametrize("overdispersion", [0.05, 0.1, 0.2, 0.5, 1.0])
@@ -247,4 +247,4 @@ def test_relaxed_overdispersed_beta_binomial(overdispersion):
overdispersion=overdispersion)
assert isinstance(d2, dist.Normal)
assert_close(d2.mean, d1.mean)
assert_close(d2.variance, d1.variance.clamp(min=0.25))
assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE))
4 changes: 4 additions & 0 deletions tests/contrib/epidemiology/test_models.py
@@ -33,6 +33,9 @@
{"jit_compile": True},
{"jit_compile": True, "haar_full_mass": 2},
{"jit_compile": True, "num_quant_bins": 2},
{"num_chains": 2, "mp_context": "spawn"},
{"num_chains": 2, "mp_context": "spawn", "num_quant_bins": 2},
{"num_chains": 2, "mp_context": "spawn", "jit_compile": True},
], ids=str)
def test_simple_sir_smoke(duration, forecast, options):
population = 100
@@ -54,6 +57,7 @@ def test_simple_sir_smoke(duration, forecast, options):

# Predict and forecast.
samples = model.predict(forecast=forecast)
num_samples *= options.get("num_chains", 1)
assert samples["S"].shape == (num_samples, duration + forecast)
assert samples["I"].shape == (num_samples, duration + forecast)

1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -34,6 +34,7 @@
'contrib/autoname/tree_data.py --num-epochs=1',
'contrib/cevae/synthetic.py --num-epochs=1',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -c=2',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1',