ENH: Add checkpoints during sampling #7503
Comments
It seems to me that the important thing is a `pm.sample` that can resume from given trace/state info. I'm not sure `pm.sample` should carry the extra burden of checkpoints; that's something the user could easily cook up in an outer loop (and that we could offer as a utility) if the functionality to resume were there. I'm not sure how to interact with external samplers: `pm.sample` there is basically just a gateway to the external samplers and doesn't do anything itself once those are launched. A first step might be for our samplers to return their internal state at the last step and to allow them to resume from an externally stored internal state.
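A minimal sketch of the outer-loop idea, assuming a sampler that returns its final state and accepts one to resume from. `sample_chunk` and `sample_with_checkpoints` are hypothetical names, and the toy random-walk Metropolis uses only the standard library; none of this is PyMC API.

```python
import math
import pickle
import random
from pathlib import Path


def logp(x):
    # Standard normal log-density, up to an additive constant.
    return -0.5 * x * x


def sample_chunk(n_draws, state=None, seed=0):
    """Toy random-walk Metropolis that returns draws plus a resumable state."""
    rng = random.Random(seed)
    x = 0.0
    if state is not None:
        # Resume: restore the last position and the exact RNG state.
        x = state["x"]
        rng.setstate(state["rng"])
    draws = []
    for _ in range(n_draws):
        proposal = x + rng.gauss(0.0, 1.0)
        # 1 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < logp(proposal) - logp(x):
            x = proposal
        draws.append(x)
    return draws, {"x": x, "rng": rng.getstate()}


def sample_with_checkpoints(n_draws, chunk, path):
    """Outer loop: sample in chunks and write a checkpoint after each one."""
    path = Path(path)
    draws, state = [], None
    if path.exists():
        # Pick up from the last checkpoint if one is present.
        draws, state = pickle.loads(path.read_bytes())
    while len(draws) < n_draws:
        n = min(chunk, n_draws - len(draws))
        new_draws, state = sample_chunk(n, state)
        draws.extend(new_draws)
        # Write to a temporary file first so a crash cannot leave a torn checkpoint.
        tmp = path.with_suffix(".tmp")
        tmp.write_bytes(pickle.dumps((draws, state)))
        tmp.replace(path)
    return draws
```

Because the state in this toy setup carries both the last position and the RNG state, a run that is stopped and restarted from the checkpoint file produces the same draws as one uninterrupted run.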
Yes, the first step is to think of some kind of standardised mechanism that samplers should expose. From my perspective, we need:
Once that high-level interface is defined, we can try to get external samplers to conform to it as well. Later, we can also add an easy-access utility to help people orchestrate this by saving the state somewhere.
@ricardoV94, while looking into this, I ran into a potential problem. PyMC step methods are intertwined with the model object, its value variables and some compiled logp functions. Serializing the step methods might be possible, but it looks like it might be very hard to do without also serializing the model and getting everything to work on a sort of copy of the model. This might not be a problem at all, but I'll have to think it through a bit longer.
The step samplers could (they kind of do already) take the model / logp function as input.
My original idea for points 1 and 2 (dump and load the sampler's state) was to rely on the standard pickle approach. The problem is that if you dump the pickled sampler, you also have to dump the pickled model and whatever compiled function the step method holds as an attribute. I'm almost certain that this approach will have a high memory footprint. I'm also not sure whether unpickling a step method would cause problems, since the step method would then point to a cloned version of the model instead of the actual model that should be used. I would have loved for the step method to have weaker references (I don't mean the Since I'm a bit afraid that using the pickle approach might be bad, I'll try to add methods to already instantiated step methods that can return or set their sampling state.
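One way the return/set approach could look, as a sketch only: the step method keeps its references to the model and the compiled function, and only the small tuning state is exported and restored. `ToyMetropolis`, `sampling_state_fields`, `get_sampling_state` and `set_sampling_state` are hypothetical names, not existing PyMC attributes.

```python
import json


class ToyModel:
    """Stand-in for a model; it is never serialized in this sketch."""


class ToyMetropolis:
    # Hypothetical list of the attributes that make up the resumable state.
    sampling_state_fields = ("scaling", "tune", "steps_until_tune", "accepted")

    def __init__(self, model):
        self.model = model                     # reference only, not exported
        self.logp_fn = lambda x: -0.5 * x * x  # stand-in for a compiled logp
        self.scaling = 1.0
        self.tune = True
        self.steps_until_tune = 100
        self.accepted = 0

    def get_sampling_state(self):
        # Export only plain values; the model and logp_fn stay behind.
        return {name: getattr(self, name) for name in self.sampling_state_fields}

    def set_sampling_state(self, state):
        unknown = set(state) - set(self.sampling_state_fields)
        if unknown:
            raise ValueError(f"unexpected state entries: {sorted(unknown)}")
        for name, value in state.items():
            setattr(self, name, value)


model = ToyModel()
step = ToyMetropolis(model)
step.scaling, step.tune, step.accepted = 0.37, False, 412  # pretend tuning ran

saved = json.dumps(step.get_sampling_state())  # small, model-free payload

resumed = ToyMetropolis(model)                 # same model object, fresh step
resumed.set_sampling_state(json.loads(saved))
```

The restored step method is built against the actual model rather than an unpickled clone, and the saved payload contains only a handful of scalars.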
I think that's much cleaner anyway.
What if the trace always had access to the last state of each step sampler in a dictionary (no need to even dump, only load)?
Before
No response
After
Context for the issue:
If one has models that take very long to sample, it would be great to have a way to store the information of the `steppers` in a checkpoint file so that if something happens and sampling stops, we could pick up from where we left off. This is a very old feature request that is related to #292, #143 and #3661.

Those issues talk about `iter_sample`, which works as a generator that one could simply pause and resume later. The problem with that is that there is no access to the stepper's state. I think that we need two things to get the samplers warm started:

Currently, most samplers and step methods provide some way to get 1, but we never have access to 2.

The current pymc samplers have a bunch of `KeyboardInterrupt` catches (here, here, here, and here). We could add a handling call there to also store the step method's state. `nutpie` has non-blocking sampling with an `abort` function call when `KeyboardInterrupt` gets hit. We could maybe add similar state recording there. `blackjax` has its progress bar conditional steps, which we could try to mimic to get the same effect. `numpyro` has a similar thing going with the progress bar, but it looks like it's way deeper than with `blackjax`.

All of this to say that I think we need to define some kind of standard way for the samplers to provide their state information. The specific samplers would then have to conform to the standard using whatever internals they need. For `pymc` samplers it would be some way to recreate the step methods (maybe using some kind of `__setstate__` and `__getstate__`), for `nutpie` it would have to be some new datatype that could be sent into Rust, and for `blackjax` it could be the kernel and random keys. I think the important thing is to settle on the standard that samplers should conform to; once we have that, we could build support for checkpoints and for restarting sampling from them later.
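To make the "standard way" concrete, here is one possible shape for it, sketched with `typing.Protocol`. Everything named here (`Checkpointable`, `get_state`, `set_state`, the two toy samplers and the dump/load helpers) is an assumption for illustration; no sampler library currently exposes this interface.

```python
import pickle
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Checkpointable(Protocol):
    """Hypothetical contract: a sampler can export and restore its own state."""

    def get_state(self) -> Dict[str, Any]: ...

    def set_state(self, state: Dict[str, Any]) -> None: ...


class StepMethodSampler:
    """Stand-in for a step-method-style sampler with tuned attributes."""

    def __init__(self):
        self.scaling = 1.0
        self.n_tuned = 0

    def get_state(self):
        return {"scaling": self.scaling, "n_tuned": self.n_tuned}

    def set_state(self, state):
        self.scaling = state["scaling"]
        self.n_tuned = state["n_tuned"]


class KeyedKernelSampler:
    """Stand-in for a functional sampler: kernel parameters plus a random key."""

    def __init__(self):
        self.step_size = 0.1
        self.key = (0, 42)

    def get_state(self):
        return {"step_size": self.step_size, "key": self.key}

    def set_state(self, state):
        self.step_size = state["step_size"]
        self.key = state["key"]


def dump_checkpoint(samplers):
    """Serialize the state of every sampler, keyed by a user-chosen name."""
    for name, sampler in samplers.items():
        if not isinstance(sampler, Checkpointable):
            raise TypeError(f"{name} does not expose get_state/set_state")
    return pickle.dumps({name: s.get_state() for name, s in samplers.items()})


def load_checkpoint(blob, samplers):
    """Push previously saved states back into freshly built samplers."""
    states = pickle.loads(blob)
    for name, sampler in samplers.items():
        sampler.set_state(states[name])
```

The checkpoint helpers only rely on the two protocol methods, so each sampler is free to decide what its state dictionary contains.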