Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements relaxed inference for compartmental models #2513

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def generate_data(args):
obs_sum = int(obs.sum())
new_I_sum = int(new_I[:args.duration].sum())
if obs_sum >= args.min_observations:
logging.info("Observed {:d}/{:d} infections:\n{}".format(
obs_sum, new_I_sum, " ".join(str(int(x)) for x in obs)))
logging.info("Observed {:d}/{:d} infections in population of {}:\n{}".format(
obs_sum, new_I_sum, args.population, " ".join(str(int(x)) for x in obs)))
return {"new_I": new_I, "obs": obs}

raise ValueError("Failed to generate {} observations. Try increasing "
Expand All @@ -76,6 +76,7 @@ def hook_fn(kernel, *unused):
num_quant_bins=args.num_bins,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
relax=args.relax,
hook_fn=hook_fn)

mcmc.summary()
Expand Down Expand Up @@ -246,8 +247,9 @@ def main(args):
parser.add_argument("-w", "--warmup-steps", default=100, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-a", "--arrowhead-mass", action="store_true")
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("-nb", "--num-bins", default=4, type=int)
parser.add_argument("--relax", action="store_true")
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("--double", action="store_true")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--verbose", action="store_true")
Expand Down
178 changes: 154 additions & 24 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pyro.infer.mcmc import ArrowheadMassMatrix
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
from pyro.infer.util import is_validation_enabled
from pyro.util import warn_if_nan

from .distributions import set_approx_log_prob_tol, set_approx_sample_thresh
Expand Down Expand Up @@ -182,13 +183,26 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):

# Select the most probable hypothesis.
i = int(smc.state._log_weights.max(0).indices)
init = {key: value[i] for key, value in smc.state.items()}
guess = {key: value[i] for key, value in smc.state.items()}

# Fill in sample site values.
init = self.generate(init)
aux = torch.stack([init[name] for name in self.compartments], dim=0)
init["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5)
return init
guess = self.generate(guess)
if self.relaxed:
with poutine.block(), poutine.condition(data=guess):
params = self.global_model()
init = self.initialize(params)
flux = {}
state = {}
for c1, c2 in zip(self.compartments, self.compartments[1:] + ("R",)):
flux[c1] = guess["{}2{}".format(c1, c2)]
state[c1] = guess[c1]
aux = _differentiate_flux(self.compartments, init, flux, state)
guess["auxiliary"] = torch.stack([aux[name] for name in self.compartments])
else:
aux = torch.stack([guess[name] for name in self.compartments])
guess["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5)

return guess

def global_model(self):
"""
Expand Down Expand Up @@ -313,6 +327,7 @@ def fit(self, **options):
"""
# Parse options, saving some for use in .predict().
self.num_quant_bins = options.pop("num_quant_bins", 4)
self.relaxed = options.pop("relax", False)
haar = options.pop("haar", False)
assert isinstance(haar, bool)
haar_full_mass = options.pop("haar_full_mass", 0)
Expand All @@ -322,6 +337,10 @@ def fit(self, **options):
haar = haar or (haar_full_mass > 0)

# Heuristically initialize to feasible latents.
if self.relaxed:
aux_constraint = constraints.unit_interval
else:
aux_constraint = constraints.interval(-0.5, self.population + 0.5)
heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
for k in list(options)
if k.startswith("heuristic_")}
Expand All @@ -335,7 +354,7 @@ def heuristic():
if haar:
# Also initialize Haar transformed coordinates.
x = init_values["auxiliary"]
x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
x = biject_to(aux_constraint).inv(x)
x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True)(x)
init_values["auxiliary_haar"] = x
if haar_full_mass:
Expand All @@ -355,7 +374,7 @@ def heuristic():
logger.info("Running inference...")
max_tree_depth = options.pop("max_tree_depth", 5)
full_mass = options.pop("full_mass", self.full_mass)
model = self._vectorized_model
model = self._relaxed_model if self.relaxed else self._vectorized_model
if haar:
rep = HaarReparam(dim=-2 if self.is_regional else -1, flip=True)
model = poutine.reparam(model, {"auxiliary": rep})
Expand Down Expand Up @@ -389,12 +408,13 @@ def heuristic():
# Transform back from Haar coordinates.
x = self.samples.pop("auxiliary_haar")
x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True).inv(x)
x = biject_to(constraints.interval(-0.5, self.population + 0.5))(x)
x = biject_to(aux_constraint)(x)
self.samples["auxiliary"] = x

# Unsqueeze samples to align particle dim for use in poutine.condition.
# TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
self.samples = align_samples(self.samples, self._vectorized_model,
model = self._relaxed_model if self.relaxed else self._vectorized_model
self.samples = align_samples(self.samples, model,
particle_dim=-1 - self.max_plate_nesting)
return mcmc # E.g. so user can run mcmc.summary().

Expand All @@ -420,24 +440,25 @@ def predict(self, forecast=0):
particle_plate = pyro.plate("particles", num_samples,
dim=-1 - self.max_plate_nesting)

# Sample discrete auxiliary variables conditioned on the continuous
# variables sampled by _vectorized_model. This samples only time steps
# [0:duration]. Here infer_discrete runs a forward-filter
# backward-sample algorithm.
logger.info("Predicting latent variables for {} time steps..."
.format(self.duration))
model = self._sequential_model
model = poutine.condition(model, samples)
model = particle_plate(model)
model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting)
trace = poutine.trace(model).get_trace()
samples = OrderedDict((name, site["value"])
for name, site in trace.nodes.items()
if site["type"] == "sample")
if not self.relaxed:
# Sample discrete auxiliary variables conditioned on the continuous
# variables sampled by _vectorized_model. This samples only time steps
# [0:duration]. Here infer_discrete runs a forward-filter
# backward-sample algorithm.
logger.info("Predicting latent variables for {} time steps..."
.format(self.duration))
model = self._sequential_model
model = poutine.condition(model, samples)
model = particle_plate(model)
model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting)
trace = poutine.trace(model).get_trace()
samples = OrderedDict((name, site["value"])
for name, site in trace.nodes.items()
if site["type"] == "sample")

# Optionally forecast with the forward _generative_model. This samples
# time steps [duration:duration+forecast].
if forecast:
if forecast or self.relaxed:
logger.info("Forecasting {} steps ahead...".format(forecast))
model = self._generative_model
model = poutine.condition(model, samples)
Expand Down Expand Up @@ -606,6 +627,115 @@ def enum_reshape(tensor, position):

self._clear_plates()

def _relaxed_model(self):
C = len(self.compartments)
T = self.duration
R_shape = getattr(self.population, "shape", ()) # Region shape.
if R_shape:
raise NotImplementedError("TODO")

# Sample global parameters.
params = self.global_model()

# Sample the continuous reparameterizing variable.
shape = (C, T) + R_shape
auxiliary = pyro.sample("auxiliary",
dist.Uniform(0, 1)
.mask(False).expand(shape).to_event())
assert auxiliary.shape == shape, "particle plates are not supported"
auxiliary = dict(zip(self.compartments, auxiliary))

# Integrate flux.
init = self.initialize(params)
flux, state = _integrate_flux(self.compartments, init, auxiliary)

# Truncate from left and right.
prev = {k: v[..., :-1] for k, v in state.items()}
curr = {k: v[..., 1:] for k, v in state.items()}
for name in self.compartments:
pyro.deterministic(name, curr[name])

# Record transition factors.
t = slice(None)
compartments = self.compartments + ("R",)
data = {"{}2{}_{}".format(c1, c2, t): flux[c1]
for c1, c2 in zip(compartments, compartments[1:])}
with poutine.condition(data=data):
self.transition_fwd(params, prev, t)

# Validate that .transition_fwd() correctly updated prev to curr.
if is_validation_enabled():
if set(prev.keys()) != set(self.compartments):
raise ValueError("\n".join([
"Incorrect state update keys in .transition_fwd():",
"Expected: {}".format(set(self.compartments)),
"Actual: {}".format(set(prev.keys())),
]))
for key in self.compartments:
if not torch.allclose(prev[key], curr[key]):
raise ValueError("Incorrect state['{}'] update in .transition_fwd()"
.format(key))

self._clear_plates()


class _DifferentiableFloor(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.floor()

@staticmethod
def backward(ctx, grad):
return grad


def _quantize(x, total_count):
return _DifferentiableFloor.apply(x * (1 + total_count))


def _dequantize(x, total_count):
return (x + 0.5) / (1 + total_count)


def _integrate(init, aux_flux, shift=None):
"""
flux[t] = _quantize(aux_flux[t], state[t])
state[0] = init
state[t+1] = state[t] - flux[t] + shift[t]
"""
if not isinstance(init, torch.Tensor):
init = torch.as_tensor(float(init))
flux = []
state = [init.expand_as(aux_flux[0])]
# TODO vectorize
for t in range(aux_flux.size(-1)):
flux.append(_quantize(aux_flux[t], state[-1]))
curr = state[-1] - flux[-1]
if shift is not None:
curr = curr + shift[t]
state.append(curr)
flux = torch.stack(flux, dim=-1)
state = torch.stack(state, dim=-1)
return flux, state


def _integrate_flux(compartments, init, aux_flux):
flux = {}
state = {}
for c, name in enumerate(compartments):
flux[name], state[name] = _integrate(init[name], aux_flux[name],
flux[compartments[c - 1]] if c else None)
return flux, state


@torch.no_grad()
def _differentiate_flux(compartments, init, flux, curr):
aux_flux = {}
for name in compartments:
prev = cat2(init[name], curr[name][:-1], dim=-1)
aux_flux[name] = _dequantize(flux[name], prev)
return aux_flux


class _SMCModel:
"""
Expand Down
22 changes: 12 additions & 10 deletions pyro/contrib/epidemiology/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def transition_fwd(self, params, state, t):
# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I, rho),
obs=self.data[t] if t < self.duration else None)
obs=self.data[t] if isinstance(t, slice) or t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
R0, tau, rho = params
Expand Down Expand Up @@ -172,7 +172,7 @@ def transition_fwd(self, params, state, t):
# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2E, rho),
obs=self.data[t] if t < self.duration else None)
obs=self.data[t] if isinstance(t, slice) or t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
R0, tau_e, tau_i, rho = params
Expand Down Expand Up @@ -286,7 +286,7 @@ def transition_fwd(self, params, state, t):
# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I, rho),
obs=self.data[t] if t < self.duration else None)
obs=self.data[t] if isinstance(t, slice) or t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
R0, k, tau, rho = params
Expand Down Expand Up @@ -417,10 +417,11 @@ def transition_fwd(self, params, state, t):
concentration=k))

# Condition on observations.
observed_t = isinstance(t, slice) or t < self.duration
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2E, rho),
obs=self.data[t] if t < self.duration else None)
if self.coal_likelihood is not None and t < self.duration:
obs=self.data[t] if observed_t else None)
if self.coal_likelihood is not None and observed_t:
R = R0 * state["S"] / self.population
coal_rate = R * (1. + 1. / k) / (tau_i * state["I"] + 1e-8)
pyro.factor("coalescent_{}".format(t),
Expand Down Expand Up @@ -534,8 +535,8 @@ def transition_fwd(self, params, state, t):
state["O"] = state["O"] + S2O

# Condition on cumulative observations.
mask_t = self.mask[t] if t < self.duration else False
data_t = self.data[t] if t < self.duration else None
mask_t = self.mask[t] if isinstance(t, slice) or t < self.duration else False
data_t = self.data[t] if isinstance(t, slice) or t < self.duration else None
pyro.sample("obs_{}".format(t),
dist.Delta(state["O"]).mask(mask_t),
obs=data_t)
Expand Down Expand Up @@ -666,8 +667,9 @@ def transition_fwd(self, params, state, t):

# In .transition_fwd() t will always be an integer but may lie outside
# of [0,self.duration) when forecasting.
rho_t = rho[..., t] if t < self.duration else rho[..., -1]
data_t = self.data[t] if t < self.duration else None
observed_t = isinstance(t, slice) or t < self.duration
rho_t = rho[..., t] if observed_t else rho[..., -1]
data_t = self.data[t] if observed_t else None

# Condition on observations.
pyro.sample("obs_{}".format(t),
Expand Down Expand Up @@ -831,7 +833,7 @@ def transition_fwd(self, params, state, t):
# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I, rho),
obs=self.data[t] if t < self.duration else None)
obs=self.data[t] if isinstance(t, slice) or t < self.duration else None)

def transition_bwd(self, params, prev, curr, t):
R0, tau, rho = params
Expand Down
Loading