-
-
Notifications
You must be signed in to change notification settings - Fork 988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support HaarReparam for non-compartmental variables #2523
Conversation
@@ -868,3 +849,77 @@ def init(self, state): | |||
def step(self, state): | |||
with poutine.block(hide_types=["observe"]): | |||
super().step(state.copy()) | |||
|
|||
|
|||
class _HaarSplitReparam: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fehiepsi I found this class provides useful helpers to convert between user <--> aux coordinates, in addition to our usual .reparam()
functionality. These extra helpers seem necessary for poutine.reparam()
to play well with init_to_value()
. That is, if we reparameterize a model, then a user's custom init_to_value()
function will need to be converted to different coordinates. WDYT of generalizing these to be part of the Reparam
interface, say as a pair of methods (.aux_to_user()
, .user_to_aux()
) or similar? (in follow-up PRs to Pyro and NumPyro).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These extra helpers seem necessary for poutine.reparam() to play well with init_to_value().
I see. I agree that those helper methods would simplify much of the user code for this init strategy.
"HeterogeneousSIRModel", | ||
"OverdispersedSEIRModel", | ||
"OverdispersedSIRModel", | ||
"RegionalSIRModel", | ||
"SimpleSEIRModel", | ||
"SimpleSIRModel", | ||
"SparseSIRModel", | ||
"SuperspreadingSEIRModel", | ||
"SuperspreadingSIRModel", | ||
"UnknownStartSIRModel", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing these because they are inteded more as examples than as reusable components.
dist.Uniform(-0.5, self.population + 0.5) | ||
.mask(False).expand(shape).to_event()) | ||
assert auxiliary.shape == shape, "particle plates are not supported" | ||
# Split tenors into current state. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
# Transform to Haar coordinates. | ||
config = {} | ||
for name, dim in self.dims.items(): | ||
config[name] = HaarReparam(dim=dim, flip=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make flip
an arg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's unnecessary since we don't expose any plumbing to set that, and since _HaarSplitReparam
is merely a minimal internal helper.
Thanks for reviewing! |
Addresses #2426
This follows up #2517 to support
HaarReparam
andSplitReparam
for non-compartmental latent time series. After this PRhaar_full_mass
will trigger learning of covariance among the low frequency components of all time series, hopefully improving mixing in long-duration datasets.Summary:
haar
andhaar_full_mass
logic into a_HaarSplitReparam
class.._sample_auxiliary()
method.HeterogeneousRegionalSIRModel
(ported from Add epidemiology tutorial with a regional SEIR model #2518) to exercise new code.Re
toRt
which is more standard.pytest_cov
breakage.From @martinjankowiak's comment in #2517 (review):
I hope the refactoring in this PR helps to reduce complexity.
Tested
HeterogeneousRegionalSIRModel
with test