Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate need for CompartmentalModel.transition_bwd() #2514

Merged
merged 6 commits into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 84 additions & 36 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pyro.infer.mcmc import ArrowheadMassMatrix
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
from pyro.infer.util import is_validation_enabled
from pyro.util import warn_if_nan

from .distributions import set_approx_log_prob_tol, set_approx_sample_thresh
Expand All @@ -43,9 +44,9 @@ class CompartmentalModel(ABC):
compartmental models.

Derived classes must implement methods :meth:`heuristic`,
:meth:`initialize`, :meth:`transition_fwd`, :meth:`transition_bwd`.
Derived classes may optionally implement :meth:`global_model` and override
the ``series`` attribute.
:meth:`initialize`, and :meth:`transition`. Derived classes may optionally
implement :meth:`global_model` and :meth:`compute_flows` and may override
the ``series`` and ``full_mass`` attributes.

Example usage::

Expand All @@ -54,8 +55,7 @@ class MyModel(CompartmentalModel):
def __init__(self, ...): ...
def global_model(self): ...
def initialize(self, params): ...
def transition_fwd(self, params, state, t): ...
def transition_bwd(self, params, prev, curr, t): ...
def transition(self, params, state, t): ...

# Run inference to fit the model to data.
model = MyModel(...)
Expand Down Expand Up @@ -88,9 +88,9 @@ def transition_bwd(self, params, prev, curr, t): ...
a tensor of each region's population in a regional model.
:type population: int or torch.Tensor
:param tuple approximate: Names of compartments for which pointwise
approximations should be provided in :meth:`transition_bwd`, e.g. if you
specify ``approximate=("I")`` then the ``prev["I_approx"]`` will be a
continuous-valued non-enumerated point estimate of ``prev["I"]``.
approximations should be provided in :meth:`transition`, e.g. if you
specify ``approximate=("I")`` then the ``state["I_approx"]`` will be a
continuous-valued non-enumerated point estimate of ``state["I"]``.
Approximations are useful to reduce computational cost. Approximations
are continuous-valued with support ``(-0.5, population + 0.5)``.
:param int num_quant_bins: Number of quantization bins in the auxiliary
Expand Down Expand Up @@ -219,50 +219,66 @@ def initialize(self, params):
raise NotImplementedError

@abstractmethod
def transition_fwd(self, params, state, t):
def transition(self, params, state, t):
"""
Forward generative process for dynamics.

This inputs a current ``state`` and stochastically updates that
state in-place.

Note that this method is called under two different interpretations.
Note that this method is called under multiple different
interpretations, including batched and vectorized interpretations.
During :meth:`generate` this is called to generate a single sample.
During :meth:`predict` thsi is called to forecast a batch of samples,
During :meth:`heuristic` this is called to generate a batch of sample
fritzo marked this conversation as resolved.
Show resolved Hide resolved
for SMC. During :meth:`fit` this is called both in vectorized form
(vectorizing over time) and insequential form (for a single time step);
fritzo marked this conversation as resolved.
Show resolved Hide resolved
both forms enumerate over discrete latent variables. During
:meth:`predict` this is called to forecast a batch of samples,
conditioned on posterior samples for the time interval
``[0:duration]``.

:param params: The global params returned by :meth:`global_model`.
:param dict state: A dictionary mapping compartment name to current
tensor value. This should be updated in-place.
:param int t: Time index.
:param t: A time-like index. During inference ``t`` may be either a
slice (for vectorized inference) or an integer time index. During
prediction ``t`` will be integer time index.
:type t: int or slice
"""
raise NotImplementedError

@abstractmethod
def transition_bwd(self, params, prev, curr, t):
def compute_flows(self, prev, curr, t):
"""
Backward factor graph for dynamics.
Computes flows between compartments, given compartment populations
before and after time step t.

This inputs hypotheses for two subsequent time steps
(``prev``,``curr``) and makes observe statements
``pyro.sample(..., obs=...)`` to declare probability factors.
The default implementation assumes sequential flows terminating in an
implicit compartment named "R". For example if::

Note that this method is called under two different interpretations.
During inference it is called vectorizing over time but with a single
sample. During prediction it is called sequentially for each time
step, but always vectorizing over samples.
compartment_names = ("S", "E", "I")

:param params: The global params returned by :meth:`global_model`.
:param dict prev: A dictionary mapping compartment name to previous
tensor value. This should not be modified.
:param dict curr: A dictionary mapping compartment name to current
tensor value. This should not be modified.
:param t: A time-like index. During inference ``t`` will be
an indexing tuple that reshapes data tensors. During prediction
``t`` will be an actual integer time index.
the default implementation computes at time step ``t=9``

flows["S2E_9"] = prev["S"] - curr["S"]
flows["E2I_9"] = prev["E"] - curr["E"] + flows["S2E_9"]
flows["I2R_9"] = prev["I"] - curr["I"] + flows["E2I_9"]

For more complex flows (non-sequential, branching, looping,
duplicating, etc.), users may override this method.

:param dict state: A dictionary mapping compartment name to current
tensor value. This should be updated in-place.
:param t: A time-like index. During inference ``t`` may be either a
slice (for vectorized inference) or an integer time index. During
prediction ``t`` will be integer time index.
:type t: int or slice
"""
raise NotImplementedError
flows = {}
flow = 0
for source, destin in zip(self.compartments, self.compartments[1:] + ("R",)):
flow = prev[source] - curr[source] + flow
flows["{}2{}_{}".format(source, destin, t)] = flow
return flows

# Inference interface ########################################

Expand Down Expand Up @@ -481,6 +497,25 @@ def _concat_series(self, samples, forecast=0):
series = torch.broadcast_tensors(*map(torch.as_tensor, series))
samples[name] = torch.stack(series, dim=-2 if self.is_regional else -1)

def _transition_bwd(self, params, prev, curr, t):
"""
Helper to collect probabilty factors from .transition() conditioned on
previous and current enumerated states.
"""
# Run .transition() conditioned on computed flows.
flows = self.compute_flows(prev, curr, t)
with poutine.condition(data=flows):
state = prev.copy()
self.transition(params, state, t) # Mutates state.

# Validate that .transition() matches .compute_flows().
if is_validation_enabled():
for key in self.compartments:
if not (state[key] - curr[key]).eq(0).all():
raise ValueError("Incorrect state['{}'] update in .transition(), "
"check that .transition() matches .compute_flows()."
.format(key))

def _generative_model(self, forecast=0):
"""
Forward generative model used for simulation and forecasting.
Expand All @@ -495,7 +530,9 @@ def _generative_model(self, forecast=0):

# Sequentially transition.
for t in range(self.duration + forecast):
self.transition_fwd(params, state, t)
for name in self.approximate:
state[name + "_approx"] = state[name]
self.transition(params, state, t)
with self.region_plate:
for name in self.compartments:
pyro.deterministic("{}_{}".format(name, t), state[name])
Expand Down Expand Up @@ -537,7 +574,7 @@ def _sequential_model(self):
if name in self.approximate:
curr[name + "_approx"] = aux
prev.setdefault(name + "_approx", prev[name])
self.transition_bwd(params, prev, curr, t)
self._transition_bwd(params, prev, curr, t)

self._clear_plates()

Expand Down Expand Up @@ -600,8 +637,8 @@ def enum_reshape(tensor, position):
# Record transition factors.
with poutine.block(), poutine.trace() as tr:
with pyro.plate("time", T, dim=-1 - self.max_plate_nesting):
t = slice(None) # Used to slice data tensors.
self.transition_bwd(params, prev, curr, t)
t = slice(0, T, 1) # Used to slice data tensors.
self._transition_bwd(params, prev, curr, t)
tr.trace.compute_log_prob()
for name, site in tr.trace.nodes.items():
if site["type"] == "sample":
Expand Down Expand Up @@ -641,8 +678,19 @@ def init(self, state):
def step(self, state):
with poutine.block(), poutine.condition(data=state):
params = self.model.global_model()

with poutine.trace() as tr:
self.model.transition_fwd(params, state, self.t)
# Temporarily extend state with approximations.
extended_state = dict(state)
for name in self.model.approximate:
extended_state[name + "_approx"] = state[name]

self.model.transition(params, extended_state, self.t)

for name in self.model.approximate:
del extended_state[name + "_approx"]
state.update(extended_state)

for name, site in tr.trace.nodes.items():
if site["type"] == "sample" and not site["is_observed"]:
state[name] = site["value"]
Expand Down
Loading