-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement a SplitReparam and use it in contrib.epidemiology #2495
Conversation
prototype_model = poutine.trace(InitMessenger(init_strategy)(model)) | ||
model_trace = prototype_model.get_trace(*model_args, **model_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fehiepsi could you please review changes in this file? The changes are needed to
- avoid calling
.sample()
in constructing the prototypemodel_trace
, and - avoid duplicating expensive initialization work by reusing that
model_trace
as the first trace in_find_valid_initial_params()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, then with e.g. init_to_uniform
strategy, we still call sample but won't for some other strategies? If so, the change looks great to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct, init_to_uniform
still calls sample. Thanks for reviewing!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! (didn't review mcmc refactor)
return torch.zeros(()).expand(batch_shape) | ||
|
||
def sample(self, sample_shape=torch.Size()): | ||
raise NotImplementedError("SplitReparam does not support sampling") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: _ImproperUniform
does not support sampling
I think we can move this to the main distributions
module. Probably with an sample_fn
arg to generate prototype samples for HMC. We might add a warning to the docs mention that sample_fn
does not actually generate uniform samples in the support, but only to generate protype values for the inference. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: _ImproperUniform does not support sampling
This user-facing exception mentions the user-facing class; _ImproperUniform
is an implementation detail.
I think we can move this to the main distributions module.
I feel like this is too bespoke for general use. Until we find another use case I'd prefer to keep it private. We could implement this other ways, e.g. Delta(nan).mask(False)
would raise a Nan error rather than a NotImplementedError
; I found this way helpful because it raised an exception early.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with an sample_fn arg to generate prototype samples for HMC.
I believe this functionality is already cleanly accomplished by InitMessenger
. Do you have another use case in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nevermind, I think dist.Foo(...).mask(False)
can serve the same purpose and initial values can be drawn directly from Foo
distribution. This class is useful when we don't know the support before-hand.
Addresses #2426
This adds a
SplitReparam
reparameterizer to split a sample site tensor into multiple other tensors, as suggested by @fehiepsi. The motivating use case (also implemented in this PR) is splitting Haar-reparameterized sample sites into low- and high-frequency parts, then adding the low frequency parts to thefull_mass
matrix in HMC.Note this reparameterizer is quite limited: it cannot generate samples because there is no standard way to split a distribution into multiple independent distribution. However in HMC and SVI require only
.log_prob()
to be implemented, not.sample()
. Actually this PR needed to change HMC internals to avoid calling.sample()
inadvertently during prototype tracing.This PR also rebalances some test since
unit
was taking 37 minutes vsintegration_batch_1
taking 15 minutes.Tested