Implement a regional SIR model with approximate inference #2466

Merged
merged 31 commits into from
May 10, 2020

Conversation

fritzo
Member

@fritzo fritzo commented May 7, 2020

Addresses #2426

This implements a discrete state SIR model with interaction among multiple regions. I'm pretty happy with the user-facing model code, but the inference code suffers from inscrutable dimension complexity.

Approximation

Because exact enumeration has cost exponential in the number of regions, this PR instead uses a point estimate for cross-region infections, namely the (-0.5, population+0.5)-bounded auxiliary variable, as suggested by @martinjankowiak. I believe this estimate is unbiased except near the bounds, where there are small edge effects.

I'd like to test accuracy after this merges. This PR does introduce an interface to access the approximation, prev["I_approx"], so we can later change the approximation under the hood without breaking user-facing model code.

Tensor dimension bookkeeping

This PR required reordering dimensions to ensure proper broadcasting in user-facing model code. The new ordering is:

enum dims | particle dims | time | region

Previously enum dims were on the left for sequential enumeration but on the right for vectorized enumeration. Also included in this change is the new t=slice(None) value for time in the vectorized model (previously t was a tuple with Ellipsis and Nones).
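
A minimal sketch of why this ordering broadcasts cleanly, using plain Python tuples as stand-ins for torch.Size. The helper broadcast_shape and all sizes (4 enumerated values, 200 particles, 60 time steps, 3 regions) are hypothetical and chosen only for illustration; they are not taken from the PR's code. The last lines show how a single indexing expression can serve both an integer t and t=slice(None).

```python
def broadcast_shape(*shapes):
    """Right-aligned broadcasting of shape tuples, following PyTorch/NumPy rules."""
    ndim = max(len(shape) for shape in shapes)
    result = [1] * ndim
    for shape in shapes:
        padded = (1,) * (ndim - len(shape)) + tuple(shape)
        for i, size in enumerate(padded):
            if size == 1:
                continue
            if result[i] not in (1, size):
                raise ValueError("incompatible shapes: {}".format(shapes))
            result[i] = size
    return tuple(result)


# Hypothetical sizes: 4 enumerated values, 200 particles, 60 time steps, 3 regions.
enum_shape = (4, 1, 1, 1)      # enum | particle | time | region
particle_shape = (200, 1, 1)   # particle | time | region
data_shape = (60, 3)           # time | region
region_shape = (3,)            # region only, e.g. a per-region parameter

full_shape = broadcast_shape(enum_shape, particle_shape, data_shape, region_shape)
print(full_shape)  # (4, 200, 60, 3)

# The same indexing expression serves both time conventions.
series = [1, 0, 1, 1]
one_step = series[2]             # sequential: t is an int
all_steps = series[slice(None)]  # vectorized: t = slice(None) selects every step
print(one_step, all_steps)  # 1 [1, 0, 1, 1]
```

Because region is rightmost and time sits just to its left, per-region and per-time tensors need no manual unsqueezing in model code; only the enum and particle dims grow to the left.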

Tested

  • unit tests for helpers
  • smoke tests
  • sanity check on an example:
$ python regional.py -r 3 -d 60 -f 30 --plot
Simulating from a RegionalSIRModel
Observed 395/691 infections:
1 0 1 1 0 1 0 1 2 1 2 1 2 1 3 1 3 3 0 1 1 1 2 1 1 1 2 2 1 1 4 4 3 4 7 3 5 10 10 10 8 5 4 8 4 12 8 7 5 4 7 10 5 6 7 5 13 3 6 7
INFO 	 Heuristically initializing...
INFO 	 Running inference...
Sample: 100%|███████████| 300/300 [04:10,  1.20it/s, step size=6.95e-02, acc. prob=0.926]

DEBUG reshaping auxiliary : torch.Size([200, 2, 60, 3]) -> torch.Size([200, 1, 2, 60, 3])

                       mean       std    median      5.0%     95.0%     n_eff     r_hat
            R0[0]      1.55      0.08      1.55      1.44      1.69     42.24      1.00
        rho_c1[0]      2.26      0.98      2.05      0.93      3.79      9.45      1.15
        rho_c0[0]      1.18      0.50      1.02      0.46      1.91      5.13      1.06
           rho[0]      0.85      0.03      0.85      0.81      0.90      4.01      1.69
           rho[1]      0.53      0.04      0.53      0.45      0.59     10.08      1.15
           rho[2]      0.91      0.05      0.92      0.85      0.98      3.21      2.47
 auxiliary[0,0,0]    997.98      0.56    998.01    997.06    998.80     33.82      1.08
 auxiliary[0,0,1]    999.69      0.37    999.78    999.11   1000.20      5.73      1.35
 auxiliary[0,0,2]    999.82      0.40    999.88    999.22   1000.45      7.89      1.03
 auxiliary[0,1,0]    998.00      0.70    998.10    996.56    998.82     13.63      1.08
 auxiliary[0,1,1]    999.29      0.41    999.27    998.72    999.99      3.41      1.59
 auxiliary[0,1,2]   1000.27      0.35   1000.43    999.75   1000.49      8.81      1.02
...


Comment on lines +55 to +78
@torch.no_grad()
def align_samples(samples, model, particle_dim):
    """
    Unsqueeze stacked samples such that their particle dim all aligns.
    This traces ``model`` to determine the ``event_dim`` of each site.
    """
    assert particle_dim < 0

    sample = {name: value[0] for name, value in samples.items()}
    with poutine.block(), poutine.trace() as tr, poutine.condition(data=sample):
        model()

    samples = samples.copy()
    for name, value in samples.items():
        event_dim = tr.trace.nodes[name]["fn"].event_dim
        pad = event_dim - particle_dim - value.dim()
        if pad < 0:
            raise ValueError("Cannot align samples, try moving particle_dim left")
        if pad > 0:
            shape = value.shape[:1] + (1,) * pad + value.shape[1:]
            print("DEBUG reshaping {} : {} -> {}".format(name, value.shape, shape))
            samples[name] = value.reshape(shape)

    return samples
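
The padding arithmetic in align_samples can be checked on bare shape tuples, without Pyro or a traced model. The sketch below is only an illustration: aligned_shape is a hypothetical helper, and event_dim=3 with particle_dim=-2 are assumed values picked because they reproduce the DEBUG line from the example run (torch.Size([200, 2, 60, 3]) -> torch.Size([200, 1, 2, 60, 3])); the actual values used in the PR are not stated in this conversation.

```python
def aligned_shape(shape, event_dim, particle_dim):
    """Insert singleton dims right after the leading sample dim.

    ``shape`` is the stacked shape (num_samples, *batch_shape, *event_shape).
    After padding, the sample dim sits at ``particle_dim`` counted from the
    right of the batch dims, i.e. to the left of the ``event_dim`` event dims.
    """
    assert particle_dim < 0
    pad = event_dim - particle_dim - len(shape)
    if pad < 0:
        raise ValueError("Cannot align samples, try moving particle_dim left")
    return shape[:1] + (1,) * pad + shape[1:]


# Assumed: a site with 3 event dims and particle_dim=-2, so pad = 3 + 2 - 4 = 1.
print(aligned_shape((200, 2, 60, 3), event_dim=3, particle_dim=-2))
# (200, 1, 2, 60, 3)

# Assumed: a scalar site (no batch or event dims), so pad = 0 + 2 - 1 = 1.
print(aligned_shape((200,), event_dim=0, particle_dim=-2))
# (200, 1)

# Assumed: a per-region site with one batch dim, so pad = 0 + 2 - 2 = 0.
print(aligned_shape((200, 3), event_dim=0, particle_dim=-2))
# (200, 3)
```

In each case the 200 stacked samples end up at position -2 among the batch dims, which is what lets a particle-vectorized model consume them.
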
Member Author

@fritzo fritzo May 7, 2020


@neerajprad I believe we can later refactor to make this generic, say as a new kwarg to MCMC.get_samples(), maybe align_samples (defaults to False) or particle_dim (defaults to None, must be negative).

Member


That sounds reasonable. IIUC, we can put an outermost particle dim with size=1 for a model (if align_samples=True), collect samples with our usual flow and concatenate the collected samples at dim 0 which should be the same as sampling from a vectorized model with an outermost plate dim. Would that work?

Member Author


I think it's safer to first collect samples and event_dims and afterwards reshape. That way we wouldn't force users to write vectorizable / broadcastable code.

Member

@neerajprad neerajprad May 13, 2020


Sorry, I missed your comment. That sounds reasonable, but would sample reshaping be useful if the model wasn't vectorizable? Just want to make sure that I understand the use case.

Member Author


Yes, sample reshaping would still be useful even if the model is not vectorizable. For example, in the CompartmentalModel in this PR, there are three models that are mathematically equivalent but have different computational complexity. We run HMC on one of those models that is not vectorizable over particles (it is instead vectorized over time, hence its name _vectorized_model()). We then stitch together multiple samples and poutine.condition the two other models, which are vectorized over particles but sequential over time (_sequential_model() and _generative_model()).

Member


Thanks for explaining, this makes sense. I'll go over the models and your reshaping utility.

pyro/contrib/epidemiology/compartmental.py

# Account for infections from all regions.
I_coupled = state["I"] @ self.coupling
pop_coupled = self.population @ self.coupling
Collaborator


shouldn't the coupling only operate on I?

Member Author

@fritzo fritzo May 7, 2020


I don't know, we should probably ask @lucymli what parameterization makes most sense. As this PR currently stands, coupling need not be normalized, and I have aimed for the following properties to hold (but I may be wrong 😖):

  1. coupling = torch.ones(R, R) replicates the behavior of a single region of size population.sum().
  2. If there is a single infectious individual among all regions, then the expected number of subsequent infections depends on R0 but not on coupling. This implies that the more time I spend infecting Oakland the less time I spend infecting San Francisco.

I guess an alternative parameterization is a pairwise R0 matrix, but this seems less plausible to me. What other parameterizations or properties seem sensible?
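
Property 1 and the effect of applying coupling to both I and population can be checked numerically. The sketch below is a rough illustration in plain Python, not the PR's code: vecmat and coupled_fraction are hypothetical helpers, the numbers are made up, and it assumes the infection rate depends on coupling only through the ratio I_coupled / pop_coupled, which is not confirmed in this conversation.

```python
import math


def vecmat(v, m):
    """Row vector times matrix, mirroring ``v @ m`` for a 1-D ``v``."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


def coupled_fraction(infectious, population, coupling):
    """Per-region infectious fraction after coupling both I and population."""
    i_coupled = vecmat(infectious, coupling)
    pop_coupled = vecmat(population, coupling)
    return [i / p for i, p in zip(i_coupled, pop_coupled)]


# Made-up state for R = 3 regions.
infectious = [5.0, 0.0, 1.0]
population = [1000.0, 2000.0, 3000.0]

# Property 1: all-ones coupling gives every region the pooled fraction,
# sum(I) / sum(population), as in a single region of size population.sum().
ones = [[1.0] * 3 for _ in range(3)]
pooled = coupled_fraction(infectious, population, ones)
print(pooled)  # each entry is 6 / 6000

# Coupling both I and population makes the fraction invariant to rescaling
# the coupling matrix, so the matrix need not be normalized.
coupling = [[1.0, 0.1, 0.0], [0.1, 1.0, 0.2], [0.0, 0.2, 1.0]]
scaled = [[7.0 * c for c in row] for row in coupling]
base = coupled_fraction(infectious, population, coupling)
rescaled = coupled_fraction(infectious, population, scaled)
print(all(math.isclose(a, b) for a, b in zip(base, rescaled)))  # True
```

If coupling were applied to I alone, rescaling the matrix would rescale the effective number of infectious individuals, so the matrix would then have to be normalized.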

@fritzo fritzo added awaiting review and removed WIP labels May 8, 2020
@fritzo fritzo requested a review from martinjankowiak May 8, 2020 00:36
martinjankowiak
martinjankowiak previously approved these changes May 9, 2020
Collaborator

@martinjankowiak martinjankowiak left a comment


does align_samples need tests?

pyro/contrib/epidemiology/compartmental.py (outdated)
pyro/contrib/epidemiology/sir.py
martinjankowiak
martinjankowiak previously approved these changes May 9, 2020
pyro/contrib/epidemiology/sir.py
pyro/contrib/epidemiology/sir.py (outdated)