Add example of sparsely observed SIR model #2457
mask_t = self.mask[t] if t < self.duration else False
data_t = self.data[t] if t < self.duration else None
pyro.sample("obs_{}".format(t),
            dist.Delta(state["O"]).mask(mask_t),
why is the auxiliary necessary? or does this just play better with the structure of CompartmentalModel?
tl;dr The auxiliary variable is needed to preserve Markov structure.
The observations in this model are aggregated over intervals: obs=S2I[t_prev+1:t_curr+1].sum(), where t_prev is the time of the last observation and t_curr is the time of the current observation. In our enumeration strategy, this would couple all (t_curr - t_prev)-many enumeration variables, with cost growing exponentially in the number of coupled variables. While the non-parallel-scan enumeration strategy could handle this without erroring, it would be prohibitively expensive and would not allow large gaps in sensor data (e.g. when a government shuts down or runs out of tests). The trick we're using is to add an auxiliary variable for the entire cumulative observation trajectory (with the same likelihood as in the usual SIR models), and then Delta-clamp that auxiliary to the true observations at a few sparse time steps. This makes more work for HMC, adds one enumeration variable per time step, and increases the complexity of variable elimination by a constant factor of Q**2, but crucially this factor is independent of gap size.
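A minimal sketch of the auxiliary-variable idea, in plain Python rather than Pyro. Everything here is illustrative: `simulate`, `cumulative`, `is_consistent`, and the toy flow distribution are hypothetical names and choices, not code from this PR. It only shows that a running total O[t] = O[t-1] + S2O[t] is a one-step (Markov) update, and that an interval-aggregated observation can be recovered from O at the two sparse observation times.

```python
import random

def simulate(duration, rho, seed=0):
    """Toy flows: S2I[t] new infections, S2O[t] of those that get observed."""
    rng = random.Random(seed)
    S2I = [rng.randint(0, 5) for _ in range(duration)]
    # Each new infection is observed independently with probability rho.
    S2O = [sum(rng.random() < rho for _ in range(n)) for n in S2I]
    return S2I, S2O

def cumulative(flows):
    """Auxiliary state O[t] = O[t-1] + flows[t]; depends only on the previous step."""
    O, total = [], 0
    for f in flows:
        total += f
        O.append(total)
    return O

def is_consistent(O, data, mask):
    """Plain-Python analogue of a masked clamp: O must equal data where mask is True."""
    return all(o == d for o, d, m in zip(O, data, mask) if m)

duration = 10
S2I, S2O = simulate(duration, rho=0.5)
O = cumulative(S2O)

# Sparse observation times (assumed for illustration).
obs_times = [3, 9]
mask = [t in obs_times for t in range(duration)]
data = [O[t] if mask[t] else None for t in range(duration)]

# The aggregate over the gap equals a difference of the cumulative state,
# so no factor needs to touch every time step inside the gap.
t_prev, t_curr = obs_times
assert sum(S2O[t_prev + 1:t_curr + 1]) == O[t_curr] - O[t_prev]
assert is_consistent(O, data, mask)
```

The check at the end holds for any gap length, which is the point of the construction: the gap size never enters the per-step update.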
I had been struggling with this issue for a while, since Lucy's model simulates 4 times per day but is observed only once. The only alternative I could see was to do parallel-scan variable elimination where each DiscreteHMM state covered the joint distribution over an entire day (four time steps), resulting in complexity Q**(2 * 4 * 2) for an SIR model or Q**(3 * 4 * 2) for an SEIR model.
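To make the comparison concrete, here is a rough back-of-the-envelope count in Python. It rests on assumptions not stated in the comment: the exponent is read as (enumerated compartments) * (time steps per observation) * 2 for the previous/current state pair, the baseline per-step cost is taken to be Q**(compartments * 2), and Q = 4 is an arbitrary choice. Constants and parallel-scan log factors are ignored, and the function names are hypothetical.

```python
def joint_state_cost(Q, compartments, steps_per_obs):
    """One HMM state spans every time step between two observations."""
    return Q ** (compartments * steps_per_obs * 2)

def auxiliary_cost(Q, compartments, steps_per_obs):
    """Assumed per-step pairwise cost, times Q**2 for the auxiliary variable,
    accumulated over the steps between two observations."""
    per_step = Q ** (compartments * 2) * Q ** 2
    return steps_per_obs * per_step

Q = 4  # arbitrary number of quantization bins, for illustration only

sir_joint = joint_state_cost(Q, compartments=2, steps_per_obs=4)   # Q**16 = 4294967296
seir_joint = joint_state_cost(Q, compartments=3, steps_per_obs=4)  # Q**24
sir_aux = auxiliary_cost(Q, compartments=2, steps_per_obs=4)       # 4 * Q**6 = 16384

print(sir_joint, seir_joint, sir_aux)
```

Under these assumptions the auxiliary-variable cost grows linearly with the number of steps between observations, while the joint-state cost grows exponentially.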
thanks for the explanation. to clarify though: if all you had was occasional missing data you wouldn't need this construction. this is really for the cumulative case
Correct. It appears the cumulative case is more common in epidemiology.
Addresses #2426
This adds an example model with sparsely observed cumulative infections. This model is interesting because it preserves Markov structure by adding an auxiliary variable for fully observed cumulative infections (as discussed with @eb8680).
@eb8680 two notes:
examples/contrib/epidemiology/epi_phy.py that has a complete complex model and script. That is, I think it would be most educational to illustrate each model feature independently in the pyro.contrib.epidemiology module, and then combine them in examples/.
S2O flow looks a little like a chemical reaction with multiple products.
Tested