Add step method state and make step results deterministic with respect to it #7508
8782ed3
to
8c44fd0
Compare
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #7508 +/- ##
==========================================
+ Coverage 92.43% 92.69% +0.25%
==========================================
Files 103 104 +1
Lines 17109 17402 +293
==========================================
+ Hits 15814 16130 +316
+ Misses 1295 1272 -23
c84b4d1
to
cac126e
Compare
I think that this PR now has a good coverage for the first subgoal that I talked about in the description. All of pymc step methods now expose a
I'll ping some people to get some feedback on the current situation, because I've invested quite a lot of time into getting it right (and also into keeping a clean git history), and I would like to get reviews before moving on. Pinging @ricardoV94, @michaelosthege, @Armavica, @ColCarroll and @junpenglao.
Hi Luciano, had a quick glance. Looks alright but I wonder if we can make it a tad simpler by using dataclasses instead of reinventing something that seems similar?
I left other questions as comments.
Thanks for the initiative, making step methods less opaque will be great for customizing/controlling sampling
pymc/step_methods/hmc/base_hmc.py
Outdated
    _num_divs_sample: int


class BaseHMC(GradientSharedStep, WithSamplingState):
Why do we need a separate class? Why not be part of the baseclass?
Which baseclass? `WithSamplingState`? Or `BaseHMCState`? If you mean the latter, it's because the `BaseHMC` step method has properties that are different from other step methods, so I need to represent its state differently. If you mean `WithSamplingState`, that's because `WithSamplingState` provides the `sampling_state` property accessors to the step method. Let me know if you meant something else.
I mean all step methods should have whatever WithSamplingState implements in the base class
Oh, I think you’re right. It must be a leftover from a previous state of my commits.
I would consider a mixin to be a cleaner design, because it doesn't make third party step methods inherit `WithSamplingState` even if they aren't compatible. It should also make things easier to test by not introducing cross-dependencies.
I don't see any compatibility issues if a step sampler doesn't make use of the functionality in the base class.
Oh, I see what @michaelosthege is saying. The current design is more in line with what @ricardoV94 said: I already have `BlockedStep` as a subclass of `WithSamplingState`, so including it in the `BaseHMC` ancestors explicitly is pointless. The base `StepMethodState` only tries to access the `rng` property. I could provide a default value factory for that property, and that way there won't be any problem for third party libraries that define their own step methods. They would simply get a useless random number `Generator` object. They could use it if they wanted to, though.
pymc/step_methods/state.py
Outdated
import numpy as np


class MetaDataClassState(type):
What is this doing?
I already said it somewhere else but I'll add it here too. I wanted to make the step method states simple `dataclass`-wrapped classes. I ran into problems revolving around `__eq__` and then also around positional arguments being defined after arguments with default values. To avoid these problems, I had to add some arguments to the `dataclass` decorator and I would sometimes forget them, leading to errors down the line. The simplest solution that didn't involve having to always write down boilerplate code with every `State` definition was to add this metaclass. It simply creates the subclass type and then wraps that using the `dataclass(eq=False, kw_only=True)` decorator. That way, you just need to inherit from `DataClassState` and you'll get the guarantee of working with a `dataclass` that has the specially crafted `__eq__` method.
pymc/step_methods/state.py
Outdated
        return dataclass(eq=False, kw_only=True)(super().__new__(cls, name, bases, namespace))


class DataClassState(metaclass=MetaDataClassState):
Why is this needed? Can we do simpler than this?
Can we just use a vanilla dataclass?
I already answered this above. This solution avoids having to add boilerplate code all around the `State` definitions. The main problem is that if you have a class that inherits from another class that uses the `@dataclass` decorator, the subclass won't be a "`dataclass`". With the metaclass approach, the subclasses will also be "`dataclass`" types and they will use the nice `__eq__` method that I had to write here.
            kwargs[field.name] = deepcopy(_val)
        return state_class(**kwargs)

    @sampling_state.setter
When is the setter used?
At the moment, it's only used in the tests. Calling `step_method.sampling_state = state` will set the step method's attributes to the proper values represented by the provided `state`. The goal is that we'll use it when we want to jump start the sampler to some past sampling state. The workflow I have in mind is this:
- Have the model already built
- Have the samplers already built
- If someone provides a past sampling state from which to start, set the sampler's state to that
- Start sampling using the sampler's set state and collect the results
I prefer an explicit method for that functionality? Something like `step_method.set_state(state)`?
It's 100% subjective opinion though
I agree to disagree on this one. The `set_state` method looks like any other method call and it’s not explicitly saying that it won’t have any return value. The attribute assignment syntax, on the other hand, is much more explicit about its intent.
So why is there a `set_rng`? The same argument you used against `set_state` would apply? Also, in my experience properties have always been a PITA because at some point we figure out we would like kwargs/customization and there's no way to refactor a property into a method with back_compat.
I added a `set_rng` method because the subclasses are supposed to overload it. The HMC step methods have their own `rng` and their `potential` also has a spawned `rng`. Setting the `rng` had to work differently than with the rest of the step methods, so I decided to make it a method instead of a `property`. Anyway, I still prefer the `sampling_state` as a property mixin. If we eventually realize that we don't want to, we can always add a new `set_state` method and issue a `DeprecationWarning` or `RuntimeError` in the property setter.
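The migration path mentioned in this comment could look roughly like the following sketch. Everything here is hypothetical (`ToyStep`, the dict-based state, the warning text). It only illustrates how a property setter can keep working while steering callers toward an explicit `set_state` method.

```python
import warnings


class ToyStep:
    def __init__(self):
        self._scaling = 1.0

    def set_state(self, state):
        # The explicit method that would become the preferred entry point.
        self._scaling = state["scaling"]

    @property
    def sampling_state(self):
        # Return a fresh dict so callers cannot mutate the internals.
        return {"scaling": self._scaling}

    @sampling_state.setter
    def sampling_state(self, state):
        # Keep the old assignment syntax working, but flag it as deprecated.
        warnings.warn(
            "Assigning to sampling_state is deprecated; call set_state(state).",
            DeprecationWarning,
            stacklevel=2,
        )
        self.set_state(state)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    step = ToyStep()
    step.sampling_state = {"scaling": 2.0}
```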
Let me try one last argument. Setting a property (which looks like an attribute) does not make me intuitively think that it will actually affect the sampler. Especially in this case, where the property (`sampler.sampling_state`) is actually a read-only copy of the internal state, not the state itself.
It's like if you have a PyMC model, which has the attribute `model.datalogp`. I wouldn't expect `model.datalogp = 0` to be a valid way of overriding the model logp. Of course if I read the source code I'll be able to figure it out, but from just reading user code I wouldn't think it would do what I want it to do.
I find it much more obvious that `sampler.set_rng()` will actually affect the rng used, and that `sampler.set_state()` will actually affect the state used.
A final argument I found, but don't necessarily care about, has to do with inheritance. Calling `super()` on a property is clumsy if you want to combine the effects of the base class method and some tweaking in the inherited class.
def get_random_generator(seed: RandomGenerator = None, copy=True) -> np.random.Generator:
    if copy:
        seed = deepcopy(seed)
When/why is copying the seed needed?
When `seed` is a `numpy.random.Generator`. If you don't do that, `numpy.random.default_rng(seed)` will return `seed`. If you provide a `BitGenerator`, the `Generator` object itself will be new, but its `BitGenerator` will be the same object that you passed in, making it potentially shared with another `Generator`. To ensure that those two scenarios won't happen, I `deepcopy` the `seed` by default.
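The two sharing scenarios described in this comment can be reproduced with plain NumPy. This is a small illustration, not the PR's `get_random_generator`; the seed value and variable names are arbitrary.

```python
from copy import deepcopy

import numpy as np

rng = np.random.default_rng(42)

# Passing a Generator returns that very same object, not a copy.
alias = np.random.default_rng(rng)

# Passing a BitGenerator builds a new Generator around the same BitGenerator.
bit_generator = np.random.PCG64(42)
wrapper = np.random.default_rng(bit_generator)

# Deep-copying first yields an independent Generator with an identical state.
independent = np.random.default_rng(deepcopy(rng))
```

Drawing from `independent` does not advance `rng`, whereas drawing from `alias` does, since both names refer to one object.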
Add an explicit instance check for Generator? And/or comment?
I’ll add a comment but I found that numpy does all of the hard type checking work and it felt like a waste to repeat
This is a desired feature, not a problem. There's an old issue that you can link to / close if we get rid of the global RNG with this PR
How? Don't they always require interleaving steps (i.e. conditioned on a valid state)?
Thanks for the review @ricardoV94! I am using
Ouch, it looks like I'm a bit of a broken record... I opened #5797 more than 2 years ago.
At the moment, step methods need their own global random state to be able to run concurrently. That means that the samplers were limited to using different processes with their own random state to work. I'm not sure how
Why do we need eq?
For convenience. If we need to assert equality, it’s much better to have this method
Why not wait until we see a need then?
I did need it in all of the tests I wrote
That's more an argument for a test utility than code we need to strictly maintain. Checking numpy array and random generator equality shows up in other scenarios
Great work! The refactoring of sampler RNGs could be extracted into its own PR and merged first?
Item 4. of your description sounds unrelated.
If I understand correctly, your approach with the mixin class has the nice benefit of not adding overhead to every iteration!
Regarding item 2. (where to dump traces) I previously pointed at the `stats`, because they already contain some/many state fields. In McBackend the stats can be sparse, so one could emit a `sampler_state` state every 100 iterations or so.
- Last time I checked ArviZ/xarray didn't support saving sparse arrays to disk. But that's for sure workaroundable at the ArviZ level.
- `ClickHouseBackend` can persist sparse stats already.
@@ -292,7 +295,7 @@ def test_step_discrete(self):
         unc = np.diag(C) ** 0.5
         check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
         with model:
-            step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal)
+            step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal, rng=123456)
Why is this seed different?
Seeding these tests was a PITA. I kept running into sporadic errors and flakiness while I was polishing the step method detachment from the global random state, and some rngs were left with the weird intermediate seeds.
@@ -36,6 +36,8 @@
 from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
 from tests.models import mv_simple, mv_simple_discrete, simple_categorical

+SEED = sum(ord(c) for c in "test_metropolis")
but why 😂
pymc/step_methods/state.py
Outdated
        this_fields = set([f.name for f in fields(self)])
        other_fields = set([f.name for f in fields(other)])
inner list comprehension is unnecessary
Why not? The `fields` function will return `Field` objects, which have a bunch of extra dataclass-specific attributes (e.g. `type`, `default`, `default_factory`). I just want to check that the names are the same and use those names later for `getattr`.
Oh! I think I understand what you're saying. A `set` of `Field` objects should already be enough to test if `this_fields == other_fields` because all of the other attributes should also match.
actually that was not my point (but you might be right about it)
`set(generator)` aka `set(a for a in "ABCD")` works. You can leave out creating the inner list (and then iterating it again when creating the set)
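A short illustration of the reviewer's point, using a hypothetical `Point` dataclass: all three spellings build the same set of field names, but only the first one materializes an intermediate list.

```python
from dataclasses import dataclass, fields


@dataclass
class Point:
    # Hypothetical dataclass, used only to have some fields to inspect.
    x: int
    y: int


# Builds a list first, then iterates it again to build the set.
names_via_list = set([f.name for f in fields(Point)])

# Feeds the set constructor directly from a generator expression.
names_via_generator = set(f.name for f in fields(Point))

# Equivalent set comprehension.
names_via_comprehension = {f.name for f in fields(Point)}
```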
        return v1 == v2


class WithSamplingState:
Can you add docstrings for the three new classes to explain how they fit together?
For `WithSamplingState` I understand that it's a mixin adding a `sampling_state` property which, upon access, returns a new container object of a `DataClassState` subtype. This container holds copies of field values of the `WithSamplingState` object. (?)
That's exactly right. I'll add comments and docstrings.
What overhead? There's nothing computed every iteration
cac126e
to
fb80ed5
Compare
Thanks @michaelosthege! I think that it could be extracted. I would need to add docstring entries for
Thanks for the pointers. I'll try to use that once I arrive at point 2
I think that Michael means that I'm not building a
The stats are just
fb80ed5
to
d02a5b7
Compare
aa4f007
to
92d0845
Compare
137ad63
to
525f58a
Compare
@michaelosthege, @ricardoV94. I may have to pivot to doing other stuff for some time. I wouldn't want this PR to become stale and difficult to rebase onto a future state of main. I think that what has been implemented so far is good enough to merge into main. The main points are:
All of this work is a major change from the current state of affairs in
I'll update this PR's description and mark it as ready for review and, if you guys agree, we'll continue with the full checkpoint support in the near future.
Sounds good to me, can you add a more informative PR title?
BTW I still favor going with dataclasses instead of the custom new classes and having the equality as a detached test utility. There is no functionality that depends on equality right now or in the foreseeable future, unless I missed something. Complexity for testing purposes seems backwards to me.
pymc/step_methods/state.py
Outdated
            return False
        if isinstance(v1, (list, tuple)):  # noqa: UP038
            return len(v1) == len(v2) and all(
                DataClassState.compare_values(v1i, v2i) for v1i, v2i in zip(v1, v2)
You could use zip(..., strict=True) and avoid explicitly comparing the length
This actually wasn’t good because `strict=True` raises a `ValueError`, and I just want it to return `False`. I’ll keep the length comparison as it was
I think that I can have a workaround for part of this, but it really depends on what you mean by dataclasses. Is the problem that I'm using a metaclass? Is it that I'm defining an If your problem is that I'm relying on metaclasses, I think that I can do something differently to avoid them. If the problem is the
Yes my problem is the implementation of
8df720e
to
112862f
Compare
@ricardoV94, I changed the code to avoid the
112862f
to
de72375
Compare
@lucianopaz sounds good to me. Besides my personal preference for methods vs setter, I have one last suggestion and one question. Rename What's the deal with frozen fields? Why do we need them / to worry about them? To be clear, I'm happy with the state and I am not blocking the merge after the rebase.
Good point, I'll do that.
The step methods have a bunch of extra information in them that gets set when they are created. My very first idea was to try to have some kind of The long term goal was to be able to set the step method to a state where it could continue sampling as it had been doing before. Since I wouldn't be able to rebuild the full step method from what I save to disk, I needed to add some way to determine that the stored state was compatible with the step method that was being modified. That's why I decided to include some step information that doesn't change during sampling as
de72375
to
af74f2c
Compare
PRs pymc-devs#7508 and pymc-devs#7492 introduced incompatible changes but were not tested simultaneously. Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals.
Description
The original intent of this PR was to fully address #7503. That proved to be a very long task, so the current PR focuses only on closing #5797 and adding `sampling_state` to all pymc step methods. I'll leave the past description further down.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7508.org.readthedocs.build/en/7508/
Original intent (OUTDATED)
This will be a long PR, but I want to open the discussion of the design in its early phases. The overall goal is to provide the ability to pause and later resume sampling that is based on pymc step methods. Once this PR is finished, I hope that we'll get into the problem of adding this ability when sampling with `blackjax`, `nutpie` and `numpyro`.
There will be 4 subgoals in this PR. I'll write them down and list the tasks that I'll do in each: