Implement a regional SIR model with approximate inference #2466
Conversation
import torch

import pyro.poutine as poutine


@torch.no_grad()
def align_samples(samples, model, particle_dim):
    """
    Unsqueeze stacked samples such that their particle dims all align.
    This traces ``model`` to determine the ``event_dim`` of each site.
    """
    assert particle_dim < 0

    sample = {name: value[0] for name, value in samples.items()}
    with poutine.block(), poutine.trace() as tr, poutine.condition(data=sample):
        model()

    samples = samples.copy()
    for name, value in samples.items():
        event_dim = tr.trace.nodes[name]["fn"].event_dim
        pad = event_dim - particle_dim - value.dim()
        if pad < 0:
            raise ValueError("Cannot align samples, try moving particle_dim left")
        if pad > 0:
            shape = value.shape[:1] + (1,) * pad + value.shape[1:]
            print("DEBUG reshaping {} : {} -> {}".format(name, value.shape, shape))
            samples[name] = value.reshape(shape)

    return samples
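For concreteness, a small illustrative example of the shape arithmetic this helper performs (hedged: the site names, shapes, and particle_dim value are made up, and the loop below only replicates the pad computation rather than calling the helper):

# Illustrative shape arithmetic for two hypothetical sites with 100 stacked posterior samples.
particle_dim = -2
for name, (shape, event_dim) in {"R0": ((100,), 0), "beta": ((100, 7), 1)}.items():
    pad = event_dim - particle_dim - len(shape)
    aligned = shape[:1] + (1,) * pad + shape[1:]
    print(name, shape, "->", aligned)
# R0   (100,)   -> (100, 1)     sample dim lands at plate dim -2
# beta (100, 7) -> (100, 1, 7)  same plate dim, just left of the single event dim

In other words, singleton dims are inserted after the leading sample dim so that, for every site, the sample dim sits at particle_dim to the left of the event dims, matching what a particle plate at that dim would produce.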
@neerajprad I believe we can later refactor to make this generic, say as a new kwarg to MCMC.get_samples(), maybe align_samples (defaults to False) or particle_dim (defaults to None, must be negative).
That sounds reasonable. IIUC, we can put an outermost particle dim with size=1 for a model (if align_samples=True), collect samples with our usual flow, and concatenate the collected samples at dim 0, which should be the same as sampling from a vectorized model with an outermost plate dim. Would that work?
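A rough sketch of the concatenation step being described, assuming each run's samples come back as a dict of tensors (the helper name is made up, not part of this PR):

import torch

def concat_runs(runs):
    """Concatenate per-run sample dicts along dim 0 so the result looks like
    draws from a model vectorized by an outermost particle plate.
    (Illustrative helper, not part of this PR.)"""
    return {name: torch.cat([run[name] for run in runs], dim=0) for name in runs[0]}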
I think it's safer to first collect samples and event_dims and afterwards reshape. That way we wouldn't force users to write vectorizable / broadcastable code.
Sorry, I missed your comment. That sounds reasonable, but would sample reshaping be useful if the model wasn't vectorizable? Just want to make sure that I understand the use case.
Yes, sample reshaping would still be useful even if the model is not vectorizable. For example, in the CompartmentalModel in this PR, there are three models that are mathematically equivalent but have different computational complexity. We run HMC on one of those models that is not vectorizable over particles (it is instead vectorized over time, hence its name _vectorized_model()). We then stitch together multiple samples and poutine.condition two other models that are vectorized over particles but sequential over time (_sequential_model() and _generative_model()).
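To make the stitching concrete, a hedged sketch of that flow (mcmc and model are placeholders for an MCMC object and a CompartmentalModel instance; exact call signatures and the plate name are illustrative, not the PR's actual code):

import pyro
import pyro.poutine as poutine

# 1. HMC runs on the time-vectorized, particle-sequential model (elsewhere),
#    so each site in `samples` has the number of samples as its leading dim.
samples = mcmc.get_samples()

# 2. Insert singleton dims so every site lines up with a particle plate at dim -2.
samples = align_samples(samples, model._vectorized_model, particle_dim=-2)

# 3. Replay a particle-vectorized, time-sequential model conditioned on them.
num_samples = len(next(iter(samples.values())))
with pyro.plate("particles", num_samples, dim=-2):
    with poutine.condition(data=samples):
        model._generative_model()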
Thanks for explaining, this makes sense. I'll go over the models and your reshaping utility.
# Account for infections from all regions.
I_coupled = state["I"] @ self.coupling
pop_coupled = self.population @ self.coupling
shouldn't the coupling only operate on I?
I don't know, we should probably ask @lucymli what parameterization makes most sense. As this PR currently stands, coupling need not be normalized, and I have aimed for the following properties to hold (but I may be wrong 😖):

- coupling = torch.ones(R, R) replicates the behavior of a single region of size population.sum().
- If there is a single infectious individual among all regions, then the expected number of subsequent infections depends on R0 but not on coupling. This implies that the more time I spend infecting Oakland, the less time I spend infecting San Francisco.

I guess an alternative parameterization is a pairwise R0 matrix, but this seems less plausible to me. What other parameterizations or properties seem sensible?
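A small numerical check of the first property (illustrative values; the names mirror the snippet above):

import torch

R = 3
I = torch.tensor([5., 0., 2.])               # stands in for state["I"]
population = torch.tensor([1e3, 2e3, 3e3])
coupling = torch.ones(R, R)

I_coupled = I @ coupling                      # tensor([7., 7., 7.]) == I.sum()
pop_coupled = population @ coupling           # tensor([6000., 6000., 6000.]) == population.sum()
# Every region sees the pooled infections and pooled population, i.e. the
# regions behave like a single region of size population.sum().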
does align_samples need tests?
Addresses #2426
This implements a discrete state SIR model with interaction among multiple regions. I'm pretty happy with the user-facing model code, but the inference code suffers from inscrutable dimension complexity.
Approximation
Because exact enumeration has exponential cost in the number of regions, this PR instead uses a point estimate for cross-region infections, namely the (-0.5, population+0.5)-bounded auxiliary variable suggested by @martinjankowiak. I believe this estimate is unbiased except at the edges, where there are small edge effects. I'd like to test accuracy after this merges. This PR introduces an interface to access the approximation, prev["I_approx"], so we can later change the approximation under the hood without breaking user-facing model code.
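As a rough illustration of the kind of auxiliary variable being described (a sketch only; the site name, duration, and shapes are made up, and the PR's actual code may differ):

import torch
import pyro
import pyro.distributions as dist

population = torch.tensor([1000., 2000., 3000.])   # per-region populations (illustrative)
duration = 30                                       # number of time steps (illustrative)

# Continuous relaxation of the latent integer counts, bounded per region in
# (-0.5, population + 0.5); rounding recovers approximate integer states.
auxiliary = pyro.sample(
    "auxiliary",
    dist.Uniform(-0.5, population + 0.5).expand([duration, 3]).to_event(2),
)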
Tensor dimension bookkeeping
This PR required reordering dimensions to ensure proper broadcasting in user-facing model code. The new ordering is:
Previously enum dims were on the left for sequential enumeration but on the right for vectorized enumeration. Also included in this change is the new t=slice(None) value for time in the vectorized model (previously t was a tuple with Ellipsis and Nones).
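For readers less familiar with the trick, slice(None) is just the programmatic form of a full slice, so the same indexing expression in user model code can serve both execution modes (a minimal illustration, with made-up data):

import torch

data = torch.arange(12).reshape(4, 3)    # (time, regions), illustrative
t = 2                                    # sequential model: one time step
step = data[t]                           # shape (3,)
t = slice(None)                          # vectorized model: all time steps
series = data[t]                         # shape (4, 3), same as data[:]
assert (series == data).all()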
Tested