Skip to content

Commit

Permalink
Reorganize contrib.epidemiology models (#2499)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored May 22, 2020
1 parent d3510f8 commit 99222cf
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 428 deletions.
12 changes: 3 additions & 9 deletions docs/source/contrib.epidemiology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,9 @@ Base Compartmental Model
:members:
:member-order: bysource

SIR Models
----------
.. automodule:: pyro.contrib.epidemiology.sir
:members:
:member-order: bysource

SEIR Models
-----------
.. automodule:: pyro.contrib.epidemiology.seir
Example Models
--------------
.. automodule:: pyro.contrib.epidemiology.models
:members:
:member-order: bysource

Expand Down
8 changes: 4 additions & 4 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.distributions import biject_to, constraints

import pyro
from pyro.contrib.epidemiology import OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel, SimpleSIRModel
from pyro.contrib.epidemiology import SimpleSEIRModel, SimpleSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel

logging.basicConfig(format='%(message)s', level=logging.INFO)

Expand All @@ -26,13 +26,13 @@ def Model(args, data):
return SimpleSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
else:
return OverdispersedSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
return SuperspreadingSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
else:
if args.concentration == math.inf:
return SimpleSIRModel(args.population, args.recovery_time, data)
else:
return OverdispersedSIRModel(args.population, args.recovery_time, data)
return SuperspreadingSIRModel(args.population, args.recovery_time, data)


def generate_data(args):
Expand Down
8 changes: 4 additions & 4 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .seir import OverdispersedSEIRModel, SimpleSEIRModel
from .sir import OverdispersedSIRModel, RegionalSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel
from .models import (RegionalSIRModel, SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, SuperspreadingSEIRModel,
SuperspreadingSIRModel, UnknownStartSIRModel)

__all__ = [
"CompartmentalModel",
"OverdispersedSEIRModel",
"OverdispersedSIRModel",
"RegionalSIRModel",
"SimpleSEIRModel",
"SimpleSIRModel",
"SparseSIRModel",
"SuperspreadingSEIRModel",
"SuperspreadingSIRModel",
"UnknownStartSIRModel",
"infection_dist",
]
285 changes: 261 additions & 24 deletions pyro/contrib/epidemiology/sir.py → pyro/contrib/epidemiology/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,114 @@ def transition_bwd(self, params, prev, curr, t):
obs=self.data[t])


class OverdispersedSIRModel(CompartmentalModel):
class SimpleSEIRModel(CompartmentalModel):
"""
Overdispersed Susceptible-Infected-Recovered model.
Susceptible-Exposed-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with three
compartments: "S" for susceptible, "I" for infected, and "R" for
recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.
This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "E" for exposed, "I" for infected,
and "R" for recovered individuals (the recovered individuals are
implicit: ``R = population - S - E - I``) with transitions
``S -> E -> I -> R``.
:param int population: Total ``population = S + E + I + R``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""

def __init__(self, population, incubation_time, recovery_time, data):
compartments = ("S", "E", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)

assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time

assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time

self.data = data

series = ("S2E", "E2I", "I2R", "obs")
full_mass = [("R0", "rho")]

def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
rho = pyro.sample("rho", dist.Beta(2, 2))
return R0, tau_e, tau_i, rho

def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1}

def transition_fwd(self, params, state, t):
R0, tau_e, tau_i, rho = params

# Sample flows between compartments.
S2E = pyro.sample("S2E_{}".format(t),
infection_dist(individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population))
E2I = pyro.sample("E2I_{}".format(t),
dist.Binomial(state["E"], 1 / tau_e))
I2R = pyro.sample("I2R_{}".format(t),
dist.Binomial(state["I"], 1 / tau_i))

# Update compartments with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R

# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2E, rho),
obs=self.data[t] if t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
R0, tau_e, tau_i, rho = params

# Reverse the flow computation.
S2E = prev["S"] - curr["S"]
E2I = prev["E"] - curr["E"] + S2E
I2R = prev["I"] - curr["I"] + E2I

# Condition on flows between compartments.
pyro.sample("S2E_{}".format(t),
infection_dist(individual_rate=R0 / tau_i,
num_susceptible=prev["S"],
num_infectious=prev["I"],
population=self.population),
obs=S2E)
pyro.sample("E2I_{}".format(t),
dist.ExtendedBinomial(prev["E"], 1 / tau_e),
obs=E2I)
pyro.sample("I2R_{}".format(t),
dist.ExtendedBinomial(prev["I"], 1 / tau_i),
obs=I2R)

# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2E, rho),
obs=self.data[t])


class SuperspreadingSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual
reproductive number) by assuming each infected individual infects
Expand Down Expand Up @@ -215,17 +313,164 @@ def transition_bwd(self, params, prev, curr, t):
obs=self.data[t])


class SuperspreadingSEIRModel(CompartmentalModel):
r"""
Generalizes :class:`SimpleSEIRModel` by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual
reproductive number) by assuming each infected individual infects
BetaBinomial-many susceptible individuals, where the BetaBinomial
distribution acts as an overdispersed Binomial distribution, adapting the
more standard NegativeBinomial distribution that acts as an overdispersed
Poisson distribution [1,2] to the setting of finite populations. To
preserve Markov structure, we follow [2] and assume all infections by a
single individual occur on the single time step where that individual makes
an ``I -> R`` transition. That is, whereas the :class:`SimpleSEIRModel`
assumes infected individuals infect `Binomial(S,R/tau)`-many susceptible
individuals during each infected time step (over `tau`-many steps on
average), this model assumes they infect `BetaBinomial(k,...,S)`-many
susceptible individuals but only on the final time step before recovering.
This model also adds an optional likelihood for observed phylogenetic data
in the form of coalescent times. These are provided as a pair
``(leaf_times, coal_times)`` of times at which genomes are sequenced and
lineages coalesce, respectively. We incorporate this data using the
:class:`~pyro.distributions.CoalescentRateLikelihood` with base coalescence
rate computed from the ``S`` and ``I`` populations. This likelihood is
independent across time and preserves the Markov propert needed for
inference.
**References**
[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
"Superspreading and the effect of individual variation on disease
emergence"
https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
"Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies
and Incidence Time Series"
https://academic.oup.com/mbe/article/34/11/2982/3952784
:param int population: Total ``population = S + E + I + R``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""

def __init__(self, population, incubation_time, recovery_time, data, *,
leaf_times=None, coal_times=None):
compartments = ("S", "E", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)

assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time

assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time

self.data = data

assert (leaf_times is None) == (coal_times is None)
if leaf_times is None:
self.coal_likelihood = None
else:
self.coal_likelihood = dist.CoalescentRateLikelihood(
leaf_times, coal_times, duration)

series = ("S2E", "E2I", "I2R", "obs")
full_mass = [("R0", "rho", "k")]

def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
k = pyro.sample("k", dist.Exponential(1.))
rho = pyro.sample("rho", dist.Beta(2, 2))
return R0, k, tau_e, tau_i, rho

def initialize(self, params):
# Start with a single exposure.
return {"S": self.population - 1, "E": 0, "I": 1}

def transition_fwd(self, params, state, t):
R0, k, tau_e, tau_i, rho = params

# Sample flows between compartments.
E2I = pyro.sample("E2I_{}".format(t),
dist.Binomial(state["E"], 1 / tau_e))
I2R = pyro.sample("I2R_{}".format(t),
dist.Binomial(state["I"], 1 / tau_i))
S2E = pyro.sample("S2E_{}".format(t),
infection_dist(individual_rate=R0,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
concentration=k))

# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2E, rho),
obs=self.data[t] if t < self.duration else None)
if self.coal_likelihood is not None and t < self.duration:
R = R0 * state["S"] / self.population
coal_rate = R * (1. + 1. / k) / (tau_i * state["I"] + 1e-8)
pyro.factor("coalescent_{}".format(t),
self.coal_likelihood(coal_rate, t))

# Update compartements with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R

def transition_bwd(self, params, prev, curr, t):
R0, k, tau_e, tau_i, rho = params

# Reverse the flow computation.
S2E = prev["S"] - curr["S"]
E2I = prev["E"] - curr["E"] + S2E
I2R = prev["I"] - curr["I"] + E2I

# Condition on flows between compartments.
pyro.sample("S2E_{}".format(t),
infection_dist(individual_rate=R0,
num_susceptible=prev["S"],
num_infectious=prev["I"],
population=self.population,
concentration=k),
obs=S2E)
pyro.sample("E2I_{}".format(t),
dist.ExtendedBinomial(prev["E"], 1 / tau_e),
obs=E2I)
pyro.sample("I2R_{}".format(t),
dist.ExtendedBinomial(prev["I"], 1 / tau_i),
obs=I2R)

# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2E, rho),
obs=self.data[t])
if self.coal_likelihood is not None:
R = R0 * prev["S"] / self.population
coal_rate = R * (1. + 1. / k) / (tau_i * prev["I"] + 1e-8)
pyro.factor("coalescent_{}".format(t),
self.coal_likelihood(coal_rate, t))


class SparseSIRModel(CompartmentalModel):
"""
Susceptible-Infected-Recovered model with sparsely observed infections.
Generalizes :class:`SimpleSIRModel` to allow sparsely observed infections.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "I" for infected, and "R" for
recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.
This model allows observations of **cumulative** infections at uneven time
intervals. To preserve Markov structure (and hence tractable inference)
this model adds an auxiliary compartment ``O`` denoting the fully-observed
Expand Down Expand Up @@ -325,15 +570,11 @@ def transition_bwd(self, params, prev, curr, t):

class UnknownStartSIRModel(CompartmentalModel):
"""
Susceptible-Infected-Recovered model with unknown date of first infection.
Generalizes :class:`SimpleSIRModel` by allowing unknown date of first
infection.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with three
compartments: "S" for susceptible, "I" for infected, and "R" for
recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.
This model demonstrates:
1. How to incorporate spontaneous infections from external sources;
Expand Down Expand Up @@ -480,15 +721,11 @@ def predict(self, forecast=0):

class RegionalSIRModel(CompartmentalModel):
r"""
Susceptible-Infected-Recovered model with coupling across regions.
Generalizes :class:`SimpleSIRModel` to simultaneously model multiple
regions with weak coupling across regions.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with three
compartments in each region: "S" for susceptible, "I" for infected, and "R"
for recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.
Regions are coupled by a ``coupling`` matrix with entries in ``[0,1]``.
The all ones matrix is equivalent to a single region. The identity matrix
is equivalent to a set of independent regions. This need not be symmetric,
Expand Down
Loading

0 comments on commit 99222cf

Please sign in to comment.