Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a factory for Poisson / NegativeBinomial / Binomial / BetaBinomial #2450

Merged
merged 6 commits into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/contrib.epidemiology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ SEIR Models
.. automodule:: pyro.contrib.epidemiology.seir
:members:
:member-order: bysource

Distributions
-------------
.. automodule:: pyro.contrib.epidemiology.distributions
:members:
:member-order: bysource
2 changes: 2 additions & 0 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .seir import SimpleSEIRModel
from .sir import SimpleSIRModel

__all__ = [
"CompartmentalModel",
"SimpleSEIRModel",
"SimpleSIRModel",
"infection_dist",
]
98 changes: 98 additions & 0 deletions pyro/contrib/epidemiology/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch

import pyro.distributions as dist


def infection_dist(*,
individual_rate,
num_infectious,
num_susceptible=math.inf,
population=math.inf,
concentration=math.inf):
"""
Create a :class:`~pyro.distributions.Distribution` over the number of new
infections at a discrete time step.

This returns a Poisson, Negative-Binomial, Binomial, or Beta-Binomial
distribution depending on whether ``population`` and ``concentration`` are
finite. In Pyro models, the population is usually finite. In the limit
``population → ∞`` and ``num_susceptible/population → 1``, the Binomial
converges to Poisson and the Beta-Binomial converges to Negative-Binomial.
In the limit ``concentration → ∞``, the Negative-Binomial converges to
Poisson and the Beta-Binomial converges to Binomial.

The overdispersed distributions (Negative-Binomial and Beta-Binomial
returned when ``concentration < ∞``) are useful for modeling superspreader
individuals [1,2]. The finitely supported distributions Binomial and
fritzo marked this conversation as resolved.
Show resolved Hide resolved
Beta-Binomial are useful in small populations and in probabilistic
programming systems where truncation or censoring are expensive [3].

**References**

[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
"Superspreading and the effect of individual variation on disease
emergence"
https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
"Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies
and Incidence Time Series"
https://academic.oup.com/mbe/article/34/11/2982/3952784
[3] Lawrence Murray et al. (2018)
"Delayed Sampling and Automatic Rao-Blackwellization of Probabilistic
Programs"
https://arxiv.org/pdf/1708.07787.pdf

:param individual_rate: The mean number of infections per infectious
individual per time step in the limit of large population, equal to
``R0 / tau`` where ``R0`` is the basic reproductive number and ``tau``
is the mean duration of infectiousness.
:param num_infectious: The number of infectious individuals at this
time step, sometimes ``I``, sometimes ``E+I``.
:param num_susceptible: The number ``S`` of susceptible individuals at this
time step. This defaults to an infinite population.
:param population: The total number of individuals in a population.
This defaults to an infinite population.
:param concentration: The concentration or dispersion parameter ``k`` in
overdispersed models of superspreaders [1,2]. This defaults to minimum
fritzo marked this conversation as resolved.
Show resolved Hide resolved
variance ``concentration = ∞``.
"""
# Convert to colloquial variable names.
R = individual_rate
I = num_infectious
S = num_susceptible
N = population
k = concentration

if population == math.inf:
if k == math.inf:
# Return a Poisson distribution.
return dist.Poisson(R * I)
else:
# Return an overdispersed Negative-Binomial distribution.
combined_k = k * I
logits = torch.as_tensor(R / k).log()
return dist.NegativeBinomial(combined_k, logits=logits)
else:
# Compute the probability that any given (susceptible, infectious)
# pair of individuals results in an infection at this time step.
p = torch.as_tensor(R / N).clamp(max=1 - 1e-6)
# Combine infections from all individuals.
combined_p = p.neg().log1p().mul(I).expm1().neg() # = 1 - (1 - p)**I
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not consistent w.r.t. before. Previously, combined_p is 1 - (e^-rate_s) ** I = 1 - (e^-p) ** I.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right, I am changing the formula. This PR seems to make it more mathematically plausible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi I am not super confident in these formulas, but the plots do show they are at least in agreement. Let me know if you have any suggestions.

Copy link
Member

@fehiepsi fehiepsi Apr 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be similar to Reed-Frost model. I agree that this formula seems to be more reasonable when N is small. Two formulas are similar when N is large (when p is small: e^-p ~ 1-p).

combined_p = combined_p.clamp(min=1e-6)

if k == math.inf:
# Return a pure Binomial model, combining the independent Binomial
# models of each infectious individual.
return dist.ExtendedBinomial(S, combined_p)
else:
# Return an overdispersed Beta-Binomial model, combining
# independent BetaBinomial(c1,c0,S) models for each infectious
# individual.
c1 = k * I
c0 = c1 * (combined_p.reciprocal() - 1).clamp(min=1e-6)
return dist.ExtendedBetaBinomial(c1, c0, S)
33 changes: 16 additions & 17 deletions pyro/contrib/epidemiology/seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyro.ops.tensor_utils import convolve

from .compartmental import CompartmentalModel
from .distributions import infection_dist


class SimpleSEIRModel(CompartmentalModel):
Expand Down Expand Up @@ -79,29 +80,25 @@ def global_model(self):
tau_i = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
rho = pyro.sample("rho", dist.Uniform(0, 1))

# Convert interpretable parameters to distribution parameters.
rate_s = -R0 / (tau_i * self.population)
prob_e = 1 / tau_e
prob_i = 1 / tau_i

return rate_s, prob_e, prob_i, rho
return R0, tau_e, tau_i, rho

def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1}

def transition_fwd(self, params, state, t):
rate_s, prob_e, prob_i, rho = params
R0, tau_e, tau_i, rho = params

# Sample flows between compartments.
prob_s = -(rate_s * state["I"]).expm1()
S2E = pyro.sample("S2E_{}".format(t),
dist.Binomial(state["S"], prob_s))
infection_dist(individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population))
E2I = pyro.sample("E2I_{}".format(t),
dist.Binomial(state["E"], prob_e))
dist.Binomial(state["E"], 1 / tau_e))
I2R = pyro.sample("I2R_{}".format(t),
dist.Binomial(state["I"], prob_i))
dist.Binomial(state["I"], 1 / tau_i))

# Update compartments with flows.
state["S"] = state["S"] - S2E
Expand All @@ -114,23 +111,25 @@ def transition_fwd(self, params, state, t):
obs=self.data[t] if t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
rate_s, prob_e, prob_i, rho = params
R0, tau_e, tau_i, rho = params

# Reverse the flow computation.
S2E = prev["S"] - curr["S"]
E2I = prev["E"] - curr["E"] + S2E
I2R = prev["I"] - curr["I"] + E2I

# Condition on flows between compartments.
prob_s = -(rate_s * prev["I"]).expm1()
pyro.sample("S2E_{}".format(t),
dist.ExtendedBinomial(prev["S"], prob_s),
infection_dist(individual_rate=R0 / tau_i,
num_susceptible=prev["S"],
num_infectious=prev["I"],
population=self.population),
obs=S2E)
pyro.sample("E2I_{}".format(t),
dist.ExtendedBinomial(prev["E"], prob_e),
dist.ExtendedBinomial(prev["E"], 1 / tau_e),
obs=E2I)
pyro.sample("I2R_{}".format(t),
dist.ExtendedBinomial(prev["I"], prob_i),
dist.ExtendedBinomial(prev["I"], 1 / tau_i),
obs=I2R)

# Condition on observations.
Expand Down
28 changes: 14 additions & 14 deletions pyro/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyro.ops.tensor_utils import convolve

from .compartmental import CompartmentalModel
from .distributions import infection_dist


class SimpleSIRModel(CompartmentalModel):
Expand Down Expand Up @@ -65,26 +66,23 @@ def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
rho = pyro.sample("rho", dist.Uniform(0, 1))

# Convert interpretable parameters to distribution parameters.
rate_s = -R0 / (tau * self.population)
prob_i = 1 / tau

return rate_s, prob_i, rho
return R0, tau, rho

def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "I": 1}

def transition_fwd(self, params, state, t):
rate_s, prob_i, rho = params
R0, tau, rho = params

# Sample flows between compartments.
prob_s = -(rate_s * state["I"]).expm1()
S2I = pyro.sample("S2I_{}".format(t),
dist.Binomial(state["S"], prob_s))
infection_dist(individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population))
I2R = pyro.sample("I2R_{}".format(t),
dist.Binomial(state["I"], prob_i))
dist.Binomial(state["I"], 1 / tau))

# Update compartments with flows.
state["S"] = state["S"] - S2I
Expand All @@ -96,19 +94,21 @@ def transition_fwd(self, params, state, t):
obs=self.data[t] if t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
rate_s, prob_i, rho = params
R0, tau, rho = params

# Reverse the flow computation.
S2I = prev["S"] - curr["S"]
I2R = prev["I"] - curr["I"] + S2I

# Condition on flows between compartments.
prob_s = -(rate_s * prev["I"]).expm1()
pyro.sample("S2I_{}".format(t),
dist.ExtendedBinomial(prev["S"], prob_s),
infection_dist(individual_rate=R0 / tau,
num_susceptible=prev["S"],
num_infectious=prev["I"],
population=self.population),
obs=S2I)
pyro.sample("I2R_{}".format(t),
dist.ExtendedBinomial(prev["I"], prob_i),
dist.ExtendedBinomial(prev["I"], 1 / tau),
obs=I2R)

# Condition on observations.
Expand Down
112 changes: 112 additions & 0 deletions tests/contrib/epidemiology/test_distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import pyro.distributions as dist
from pyro.contrib.epidemiology import infection_dist

from tests.common import assert_close


def assert_dist_close(d1, d2):
    """
    Assert that two count distributions nearly agree on ``{0, ..., 199}``.

    Three checks are made on the probability mass evaluated over that range:
    each distribution's mass sums to 1 within ``1e-3``, the means computed
    from that mass match to 5% relative tolerance, and the largest pointwise
    difference is below 5% of the largest probability of either distribution.

    :param d1: A distribution exposing ``log_prob``.
    :param d2: A distribution exposing ``log_prob``.
    """
    support = torch.arange(float(200))
    pmf1, pmf2 = (d.log_prob(support).exp() for d in (d1, d2))

    # Nearly all mass must lie inside the evaluated range, otherwise the
    # truncated mean and pointwise comparisons below would be meaningless.
    for pmf in (pmf1, pmf2):
        assert (pmf.sum() - 1).abs() < 1e-3, "incomplete mass"

    assert_close((pmf1 * support).sum(), (pmf2 * support).sum(), rtol=0.05)

    peak = torch.max(pmf1.max(), pmf2.max())
    worst_gap = (pmf1 - pmf2).abs().max()
    assert worst_gap / peak < 0.05


@pytest.mark.parametrize("R0,I", [
    (1., 1),
    (1., 10),
    (10., 1),
    (5., 5),
])
def test_binomial_vs_poisson(R0, I):
    """
    Compare the default (Poisson) result with the Binomial result obtained
    when ``num_susceptible`` and ``population`` are both set to 1000.
    """
    shared = {
        "individual_rate": torch.tensor(R0),
        "num_infectious": torch.tensor(I),
    }

    poisson = infection_dist(**shared)
    binomial = infection_dist(num_susceptible=1000., population=1000.,
                              **shared)

    assert isinstance(poisson, dist.Poisson)
    assert isinstance(binomial, dist.Binomial)
    assert_dist_close(poisson, binomial)


@pytest.mark.parametrize("R0,I,k", [
    (1., 1., 0.5),
    (1., 1., 1.),
    (1., 1., 2.),
    (1., 10., 0.5),
    (1., 10., 1.),
    (1., 10., 2.),
    (10., 1., 0.5),
    (10., 1., 1.),
    (10., 1., 2.),
    (5., 5, 0.5),
    (5., 5, 1.),
    (5., 5, 2.),
])
def test_beta_binomial_vs_negative_binomial(R0, I, k):
    """
    Compare the Negative-Binomial result (finite ``concentration`` only) with
    the Beta-Binomial result obtained when ``num_susceptible`` and
    ``population`` are additionally set to 1000.
    """
    shared = {
        "individual_rate": torch.tensor(R0),
        "num_infectious": torch.tensor(I),
        "concentration": k,
    }

    neg_binomial = infection_dist(**shared)
    beta_binomial = infection_dist(num_susceptible=1000., population=1000.,
                                   **shared)

    assert isinstance(neg_binomial, dist.NegativeBinomial)
    assert isinstance(beta_binomial, dist.BetaBinomial)
    assert_dist_close(neg_binomial, beta_binomial)


@pytest.mark.parametrize("R0,I", [
    (1., 1.),
    (1., 10.),
    (10., 1.),
    (5., 5.),
])
def test_beta_binomial_vs_binomial(R0, I):
    """
    Compare the Binomial result for a population of 30 with 20 susceptible
    against the Beta-Binomial result for the same inputs at
    ``concentration=200``.
    """
    shared = {
        "individual_rate": torch.tensor(R0),
        "num_infectious": torch.tensor(I),
        "num_susceptible": 20.,
        "population": 30.,
    }

    binomial = infection_dist(**shared)
    beta_binomial = infection_dist(concentration=200., **shared)

    assert isinstance(binomial, dist.Binomial)
    assert isinstance(beta_binomial, dist.BetaBinomial)
    assert_dist_close(binomial, beta_binomial)


@pytest.mark.parametrize("R0,I", [
    (1., 1.),
    (1., 10.),
    (10., 1.),
    (5., 5.),
])
def test_negative_binomial_vs_poisson(R0, I):
    """
    Compare the default (Poisson) result with the Negative-Binomial result
    obtained at ``concentration=200``.
    """
    shared = {
        "individual_rate": torch.tensor(R0),
        "num_infectious": torch.tensor(I),
    }

    poisson = infection_dist(**shared)
    neg_binomial = infection_dist(concentration=200., **shared)

    assert isinstance(poisson, dist.Poisson)
    assert isinstance(neg_binomial, dist.NegativeBinomial)
    assert_dist_close(poisson, neg_binomial)
4 changes: 3 additions & 1 deletion tests/contrib/epidemiology/test_seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
{"dct": 1.},
{"num_quant_bins": 8},
], ids=str)
def test_smoke(duration, forecast, options):
def test_simple_smoke(duration, forecast, options):
population = 100
incubation_time = 2.0
recovery_time = 7.0
Expand All @@ -36,3 +36,5 @@ def test_smoke(duration, forecast, options):
# Predict and forecast.
samples = model.predict(forecast=forecast)
assert samples["S"].shape == (num_samples, duration + forecast)
assert samples["E"].shape == (num_samples, duration + forecast)
assert samples["I"].shape == (num_samples, duration + forecast)
Loading