Accumulators, stage 1 #885
* AbstractPPL 0.11; change prefixing behaviour
* Use DynamicPPL.prefix rather than overloading
* Unify {Untyped,Typed}{Vector,}VarInfo constructors
* Update invocations
* NTVarInfo
* Fix tests
* More fixes
* Fixes
* Fixes
* Fixes
* Use lowercase functions, don't deprecate VarInfo
* Rewrite VarInfo docstring
* Fix methods
* Fix methods (really)
Benchmark Report for Commit efc7c53

Computer Information
Benchmark Results
Not reviewing actual code, just one high-level thought that struck me.
src/abstract_varinfo.jl (outdated)
```julia
function setlogp!!(vi::AbstractVarInfo, logp)
    vi = setlogprior!!(vi, zero(logp))
    vi = setloglikelihood!!(vi, logp)
    return vi
end
```
I was thinking about this the other day and thought I may as well post now. The `...logp()` family of functions are no longer well-defined in a world where everything is cleanly split into prior and likelihood (only `getlogp` and `resetlogp` still make sense). I think last time we chatted about it the decision was to maybe forward the others to the likelihood methods, but I was wondering if it's actually safer to remove them (or make them error informatively) and force people to use likelihood or prior, as otherwise it risks introducing subtle bugs. Backward compatibility is important, but if it comes at the cost of correctness I feel kinda uneasy.
My hope was that we could deprecate them but provide the same functionality through the new functions, like above. It's a good question as to whether there are edge cases where they do not provide the same functionality. I think this is helped by the fact that PriorContext and LikelihoodContext won't exist, and hence one can't be running code where the expectation would be that `...logp()` would be referring to logprior or loglikelihood in particular. And I think as long as one expects to get the logjoint out of `...logp()`, we can do things like above, shoving things into likelihood, and get the same results. Do you think that solves it and lets us use deprecations rather than straight-up removals, or do you see other edge cases?
Something like this is a case where `setlogp` is ill-defined:
DynamicPPL.jl/src/test_utils/varinfo.jl
Lines 47 to 62 in c7bdc3f
```julia
lp = getlogp(vi_typed_metadata)
varinfos = map((
    vi_untyped_metadata,
    vi_untyped_vnv,
    vi_typed_metadata,
    vi_typed_vnv,
    svi_typed,
    svi_untyped,
    svi_vnv,
    svi_typed_ref,
    svi_untyped_ref,
    svi_vnv_ref,
)) do vi
    # Set them all to the same values.
    DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end
```
The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.
Of course, we can fix this on our end - we would get and set logprior and loglikelihood manually, and we can grep the codebase to make sure that there are no other ill-defined calls to setlogp. We can't guarantee that other people will be similarly careful, though (and us or anyone being careful also doesn't guarantee that everything will be fixed correctly).
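To make the inconsistency concrete, here is a minimal, self-contained sketch (toy types, not DynamicPPL's actual implementation) of the proposed forwarding: `getlogp` users see no difference, but `getlogprior` silently changes after a `setlogp!!` round trip.

```julia
# Toy stand-in for a varinfo that stores log-prior and log-likelihood
# separately. Names mirror the accessors discussed above; the types are
# hypothetical, for illustration only.
struct ToyVarInfo
    logprior::Float64
    loglikelihood::Float64
end

setlogprior!!(vi::ToyVarInfo, x) = ToyVarInfo(x, vi.loglikelihood)
setloglikelihood!!(vi::ToyVarInfo, x) = ToyVarInfo(vi.logprior, x)
getlogp(vi::ToyVarInfo) = vi.logprior + vi.loglikelihood
getlogprior(vi::ToyVarInfo) = vi.logprior

# The proposed deprecation path: zero the prior, push everything into
# the likelihood.
function setlogp!!(vi::ToyVarInfo, logp)
    vi = setlogprior!!(vi, zero(logp))
    return setloglikelihood!!(vi, logp)
end

vi = ToyVarInfo(-1.5, -2.5)        # true prior = -1.5, likelihood = -2.5
vi = setlogp!!(vi, getlogp(vi))    # round-trip through setlogp!!

getlogp(vi)      # -4.0: unchanged, so getlogp users are fine
getlogprior(vi)  # 0.0 rather than -1.5: inconsistent with the varinfo
```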
While looking for other uses of `setlogp`, I encountered this: `AdvancedHMC.Transition` only contains a single notion of log density, so it's not obvious to me how we're going to extract the prior and likelihood components from it 😓 This might require upstream changes to AdvancedHMC. Since the contexts will be removed, I suspect `LogDensityFunction` also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

(For the record, I'd be quite happy with making all of these changes!)
> The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.

It is inconsistent, but as long as the user only uses `getlogp`, they would never see the difference, right? If some of logprior is accidentally stored in loglikelihood or vice versa, as long as one is using `getlogp` and `DefaultContext` that should be undetectable. What would be trouble is if someone mixes using e.g. `setlogp!!` and `getlogprior`, which would require adding calls to `getlogprior` after upgrading to a version that has deprecated `setlogp!!`, but probably people would end up doing that. Maybe the deprecation warning could say something about this?
> Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

Yeah, this sort of stuff will come up (and is coming up) in multiple places. Anything that explicitly uses PriorContext or LikelihoodContext would need to be changed to use LogPrior and LogLikelihood accumulators instead. I'm currently doing this for `pointwiselogdensities`.
```julia
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
```
How do we deal with tempering of logpdf and such now that it happens in the leaf of the call stack?
In the past, we would do this by altering the `logpdf` coming from the `assume` higher up in the call tree.
What are the needs of tempering? What does it need to alter?
Granted, I've only spent about 30 minutes reading about it, but I don't see the need for tempering to have such fine-grained control over the logp emitted from each individual assume / observe call -- it seems sufficient to either globally modify the logprior or loglikelihood, which can still be done with the current approach (by evaluating the model and then extracting the logp components). I had a look through the Pigeons.jl codebase for example and it doesn't seem to need to hook into the tilde pipeline. I don't want to speak for MCMCTempering since Tor is here 😉 Happy to be corrected if I'm wrong though
Co-authored-by: Tor Erlend Fjelde <[email protected]>
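For what it's worth, the "globally modify the logprior or loglikelihood" approach can be sketched as follows. This is a hypothetical illustration (plain structs and made-up names standing in for a real model evaluation), not an API that exists:

```julia
# Stand-in for the result of one model evaluation, with the two log
# density components extracted separately as discussed above.
struct EvaluatedModel
    logprior::Float64
    loglikelihood::Float64
end

# Power tempering of the likelihood only, with inverse temperature beta,
# done entirely outside the tilde pipeline.
tempered_logdensity(m::EvaluatedModel, beta) = m.logprior + beta * m.loglikelihood

m = EvaluatedModel(-1.0, -10.0)
tempered_logdensity(m, 1.0)  # -11.0: the untempered log joint
tempered_logdensity(m, 0.1)  # -2.0: likelihood flattened by beta
```

This is the sense in which per-`assume`/`observe` hooks seem unnecessary for tempering: the scaling touches only the aggregated components.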
I wouldn't even call this 'request changes', just a bunch of thoughts. I haven't reviewed all the code yet, only the interface bits (abstract_varinfo.jl and accumulator.jl - helpfully they're first alphabetically!) but I didn't want to put like 50 comments at once. Happy to continue once we've worked some of these out.
Thanks both for the comments.
This isn't fully reflected in what's implemented in this PR, but I would like to get to a point where the meaning of If I actually manage to implement that, would you then be in favour of the renaming? I say "if", because there may be complications when interfacing with samplers that I don't see now, regarding what needs to happen before/after accumulation.
As for renaming DefaultContext, we chatted about it yesterday and yes I wouldn't mind calling it AccumulatorContext if there was some other context that didn't run the accumulators. (Although, I think the same effect could be achieved by just using an empty tuple of accumulators.) I think we decided to leave it to a later PR though.
Co-authored-by: Penelope Yong <[email protected]>
```julia
# TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
# for ThreadSafeVarInfo. In that one, if the map produces e.g. a ForwardDiff.Dual, but
# the accumulators in the VarInfo are plain floats, we error since we can't change the
```
I think we're almost there, just a few things more to look at.
I have stared at this for quite a while and I understand the issue with `map_accumulator!!`, but I don't get how this relates to `unflatten`?
This comes about in situations where you've made a regular VarInfo where the log probs are floats, and then you e.g. give it to a LogDensityFunction which wants to evaluate your model with values that are of some float-look-a-like type such as `ForwardDiff.Dual`. This will manifest itself by `unflatten` getting called with a vector of `Dual`s. Further down the line this will cause us to add stuff to log probs that is also `Dual`s.

Ideally I would like a situation where, at the point when I have e.g. a `LogPrior{Float64}` and I want to add to it a `LogPrior{Dual{Float64}}`, this will naturally result in a `LogPrior{Dual{Float64}}` (because we defer to the underlying type when adding `LogPrior`s, and they know how to handle `Float64 + Dual{Float64}`), and then that gets assigned to `vi.accs`, making use of the fact that since all of our functions for accs are `!!` we can make the assignment not-in-place if need be. How elegant!
But `ThreadSafeVarInfo` rains on this parade by being unable to change the types of the accumulators, because it needs to update each sub-AccumulatorTuple for each thread independently. Grrrr.
The solution is to do what we used to do, which is to say already in `unflatten`, before any splitting into multiple threads happens, that we force our log probs to be of the same element type as the vector `x` that was given to `unflatten`. This is crude, and for instance means that if you deliberately made your `x` be, say, `Float32`s and your log probs be `Float64`s, that's ruined here. The only other solution I could think of was to special-case in `unflatten` to only do this conversion if the element type is one of the AD tracer types, but then that would require enumerating them all and having all the AD packages as dependencies.
(I think the real solution is to make ThreadSafeVarInfo better, maybe with locks, but that's a different story, and may not be worth the effort.)
Does that help? I'll improve the comment once I know how to best improve it.
Several thoughts:

- If the problem behaviour is restricted to TSVI, could we keep the special-casing to TSVI, e.g. with a special method for `unflatten`?
- In the long run it seems that we need to have some recognition that `T = ForwardDiff.Dual{Float64}` (i.e. evaluating grad(logp)) is orthogonal to `T = Float32` (i.e. evaluating logp with different precision), and they could even be mix-and-matched, i.e. `T = ForwardDiff.Dual{Float32}`, which means evaluating grad(logp) with different precision. `convert_eltype` isn't right now clever enough to make this differentiation, but it appears to me that we indeed do need to (and should) hardcode this. The lowest-effort way seems to be to use extensions. I wonder if in the long term this is something that could go into DifferentiationInterface.
- If it can't be solved easily, could you open an issue to track this?
> If the problem behaviour is restricted to TSVI, could we keep the special-casing to TSVI e.g. with a special method for unflatten?

I suspect we might have situations where `unflatten` is called first on a regular VI and only then it's converted to TSVI.
See #906 for more on this.
I've looked through the rest of the files and, okay, I'll admit I haven't minutely checked every single line, but I did read and I don't think I saw anything that merited closer inspection.
Final comment is about the performance. Have you made any inroads into finding out why the small models are a lot worse? Does it matter? Should the mutable-accumulator investigation be a separate PR?
I figured out the slowness issue:
On v0.36.1:
The type instability was there already before, its consequences just got more severe with accumulators. I added a test that would have caught it. The whole usual benchmark suite, on this PR:
For contrast, copypasta of the results on current main from above:
Mildly curious about what's going on with LDA (it's a horrible model anyway), and I also still want to try the mutable accumulators thing, but mostly I think we are done with performance concerns.

Making accumulators mutable seemed to harm speed, as did most of the

For the renaming of
My only last thought was that user-facing docs would be v useful, although I'm guessing it's already on your list.
Thanks for the great feedback!
This is starting to take shape. It's too early for a review: Everything is undocumented, uncleaned, and some things are still broken. The base design is there though, and most tests pass (pointwiseloglikelihood and doctests being the exceptions), so @penelopeysm, @torfjelde, if you want to have an early look at where this is going, feel free. The most interesting files are accumulators.jl, abstract_varinfo.jl, and context_implementations.jl.
In addition to obvious things that still need doing (documentation, clean-up, new tests, adding deprecations, fixing pointwiseloglikelihood), a few things I have on my mind:

- Whether `getacc` and similar functions should take the type of the accumulator as the index, or rather the symbol returned by `accumulator_name`. Leaning towards the latter, but the former is what's currently implemented.
- Renaming `DefaultContext` to `AccumulationContext`. Or something else? I'm not fixated on the term "accumulator".
- Since the signature of `(tilde_)assume` and `(tilde_)observe` has changed (they no longer return `logp`), the whole stack of calls within `tilde_obssume!!` should be revisited. In particular, I'm thinking of splitting anything sampling-related into a call of `tilde_obssume` with `SamplingContext`, that then at the end calls `tilde_obssume` with `DefaultContext`. This might be a separate PR though.
- Whether `metadata.order` should be an accumulator as well. Probably needs to actually be in the same accumulator with `NumProduce`, since the two go together. Probably a separate PR though.
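On the first point, the symbol-indexed option could look roughly like this. A rough sketch with hypothetical helper names (`make_acc_tuple` and the toy accumulator structs are not the PR's actual implementation):

```julia
# Each accumulator type declares a name; getacc then looks accumulators
# up in a NamedTuple keyed by those names.
struct LogPriorAcc; logp::Float64; end
struct LogLikelihoodAcc; logp::Float64; end

accumulator_name(::LogPriorAcc) = :LogPrior
accumulator_name(::LogLikelihoodAcc) = :LogLikelihood

# Build a NamedTuple of accumulators keyed by their declared names.
make_acc_tuple(accs...) = NamedTuple{map(accumulator_name, accs)}(accs)

# Index by symbol rather than by type.
getacc(accs::NamedTuple, name::Symbol) = accs[name]

accs = make_acc_tuple(LogPriorAcc(-1.0), LogLikelihoodAcc(-2.0))
getacc(accs, :LogPrior).logp  # -1.0
```

One appeal of the symbol route is that callers don't need the concrete accumulator type (including its type parameters) in scope just to retrieve it.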