Skip to content

Commit

Permalink
Tweak defaults for CompartmentalModel.fit_mcmc() (#2563)
Browse files Browse the repository at this point in the history
* Tweak defaults for CompartmentalModel.fit_mcmc()

* Simplify tutorial
  • Loading branch information
fritzo authored Jul 16, 2020
1 parent b481e9a commit 924f5e2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 28 deletions.
11 changes: 6 additions & 5 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,10 @@ def fit_mcmc(self, **options):
Note that computational cost is exponential in `num_quant_bins`.
Defaults to 1 for relaxed inference.
:param bool haar: Whether to use a Haar wavelet reparameterizer.
Defaults to True.
:param int haar_full_mass: Number of low frequency Haar components to
include in the full mass matrix. If nonzero this implies
``haar=True``.
include in the full mass matrix. If ``haar=False`` then this is
ignored. Defaults to 10.
:param int heuristic_num_particles: Passed to :meth:`heuristic` as
``num_particles``. Defaults to 1024.
:returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
Expand All @@ -489,14 +490,14 @@ def fit_mcmc(self, **options):

# Setup Haar wavelet transform.
haar = options.pop("haar", False)
haar_full_mass = options.pop("haar_full_mass", 0)
haar_full_mass = options.pop("haar_full_mass", 10)
full_mass = options.pop("full_mass", self.full_mass)
assert isinstance(haar, bool)
assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
assert isinstance(full_mass, (bool, list))
haar_full_mass = min(haar_full_mass, self.duration)
if haar_full_mass:
haar = True
if not haar:
haar_full_mass = 0
if full_mass is True:
haar_full_mass = 0 # No need to split.
elif haar_full_mass >= self.duration:
Expand Down
45 changes: 23 additions & 22 deletions tests/contrib/epidemiology/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@
("svi", {"guide_rank": 2}),
("svi", {"guide_rank": "full"}),
("mcmc", {}),
("mcmc", {"haar": True}),
("mcmc", {"haar": False}),
("mcmc", {"haar_full_mass": 0}),
("mcmc", {"haar_full_mass": 2}),
("mcmc", {"num_quant_bins": 2}),
("mcmc", {"num_quant_bins": 4}),
("mcmc", {"num_quant_bins": 8}),
("mcmc", {"num_quant_bins": 12}),
("mcmc", {"num_quant_bins": 16}),
("mcmc", {"num_quant_bins": 2, "haar": True}),
("mcmc", {"num_quant_bins": 2, "haar": False}),
("mcmc", {"arrowhead_mass": True}),
("mcmc", {"jit_compile": True}),
("mcmc", {"jit_compile": True, "haar_full_mass": 2}),
("mcmc", {"jit_compile": True, "haar_full_mass": 0}),
("mcmc", {"jit_compile": True, "num_quant_bins": 2}),
("mcmc", {"num_chains": 2, "mp_context": "spawn"}),
("mcmc", {"num_chains": 2, "mp_context": "spawn", "num_quant_bins": 2}),
Expand Down Expand Up @@ -77,8 +78,8 @@ def test_simple_sir_smoke(duration, forecast, options, algo):
("svi", {}),
("svi", {"haar": False}),
("mcmc", {}),
("mcmc", {"haar": True}),
("mcmc", {"haar_full_mass": 2}),
("mcmc", {"haar": False}),
("mcmc", {"haar_full_mass": 0}),
("mcmc", {"num_quant_bins": 2}),
], ids=str)
def test_simple_seir_smoke(duration, forecast, options, algo):
Expand Down Expand Up @@ -116,7 +117,7 @@ def test_simple_seir_smoke(duration, forecast, options, algo):
@pytest.mark.parametrize("algo,options", [
("svi", {}),
("mcmc", {}),
("mcmc", {"haar_full_mass": 2}),
("mcmc", {"haar_full_mass": 0}),
], ids=str)
def test_simple_seird_smoke(duration, forecast, options, algo):
population = 100
Expand Down Expand Up @@ -155,8 +156,8 @@ def test_simple_seird_smoke(duration, forecast, options, algo):
@pytest.mark.parametrize("forecast", [7])
@pytest.mark.parametrize("options", [
{},
{"haar": False},
{"num_quant_bins": 2},
{"haar_full_mass": 2},
], ids=str)
def test_overdispersed_sir_smoke(duration, forecast, options):
population = 100
Expand Down Expand Up @@ -186,7 +187,7 @@ def test_overdispersed_sir_smoke(duration, forecast, options):
@pytest.mark.parametrize("forecast", [7])
@pytest.mark.parametrize("options", [
{},
{"haar_full_mass": 2},
{"haar": False},
{"num_quant_bins": 2},
], ids=str)
def test_overdispersed_seir_smoke(duration, forecast, options):
Expand Down Expand Up @@ -221,8 +222,8 @@ def test_overdispersed_seir_smoke(duration, forecast, options):
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("options", [
{},
{"haar": True},
{"haar_full_mass": 2},
{"haar": False},
{"haar_full_mass": 0},
{"num_quant_bins": 2},
], ids=str)
def test_superspreading_sir_smoke(duration, forecast, options):
Expand Down Expand Up @@ -253,8 +254,8 @@ def test_superspreading_sir_smoke(duration, forecast, options):
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("options", [
{},
{"haar": True},
{"haar_full_mass": 2},
{"haar": False},
{"haar_full_mass": 0},
{"num_quant_bins": 2},
], ids=str)
def test_superspreading_seir_smoke(duration, forecast, options):
Expand Down Expand Up @@ -335,7 +336,7 @@ def test_coalescent_likelihood_smoke(duration, forecast, options, algo):
("svi", {}),
("svi", {"haar": False}),
("mcmc", {}),
("mcmc", {"haar_full_mass": 2}),
("mcmc", {"haar": False}),
("mcmc", {"num_quant_bins": 2}),
], ids=str)
def test_heterogeneous_sir_smoke(duration, forecast, options, algo):
Expand Down Expand Up @@ -368,8 +369,8 @@ def test_heterogeneous_sir_smoke(duration, forecast, options, algo):
@pytest.mark.parametrize("options", [
xfail_param({}, reason="Delta is incompatible with relaxed inference"),
{"num_quant_bins": 2},
{"num_quant_bins": 2, "haar": True},
{"num_quant_bins": 2, "haar_full_mass": 3},
{"num_quant_bins": 2, "haar": False},
{"num_quant_bins": 2, "haar_full_mass": 0},
{"num_quant_bins": 4},
], ids=str)
def test_sparse_smoke(duration, forecast, options):
Expand Down Expand Up @@ -411,8 +412,8 @@ def test_sparse_smoke(duration, forecast, options):
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("options", [
{},
{"haar": True},
{"haar_full_mass": 4},
{"haar": False},
{"haar_full_mass": 0},
{"num_quant_bins": 2},
], ids=str)
def test_unknown_start_smoke(duration, pre_obs_window, forecast, options):
Expand Down Expand Up @@ -459,8 +460,8 @@ def test_unknown_start_smoke(duration, pre_obs_window, forecast, options):
("svi", {}),
("svi", {"haar": False}),
("mcmc", {}),
("mcmc", {"haar": True}),
("mcmc", {"haar_full_mass": 2}),
("mcmc", {"haar": False}),
("mcmc", {"haar_full_mass": 0}),
("mcmc", {"num_quant_bins": 2}),
], ids=str)
def test_regional_smoke(duration, forecast, options, algo):
Expand Down Expand Up @@ -500,11 +501,11 @@ def test_regional_smoke(duration, forecast, options, algo):
("svi", {}),
("svi", {"haar": False}),
("mcmc", {}),
("mcmc", {"haar": True}),
("mcmc", {"haar_full_mass": 2}),
("mcmc", {"haar": False}),
("mcmc", {"haar_full_mass": 0}),
("mcmc", {"num_quant_bins": 2}),
("mcmc", {"jit_compile": True}),
("mcmc", {"jit_compile": True, "haar_full_mass": 2}),
("mcmc", {"jit_compile": True, "haar": False}),
("mcmc", {"jit_compile": True, "num_quant_bins": 2}),
], ids=str)
def test_hetero_regional_smoke(duration, forecast, options, algo):
Expand Down
2 changes: 1 addition & 1 deletion tutorial/source/epi_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@
"%%time\n",
"model = SimpleSIRModel(population, recovery_time, obs)\n",
"mcmc = model.fit_mcmc(num_samples=4 if smoke_test else 400,\n",
" haar_full_mass=10, jit_compile=True)"
" jit_compile=True)"
]
},
{
Expand Down

0 comments on commit 924f5e2

Please sign in to comment.