Move external sampler interface to AbstractMCMC #2704
Open

penelopeysm wants to merge 10 commits into breaking from py/tochains

+106 −97
Commits (10) — changes from all commits
06752c4 Decouple external sampler interface from Turing (penelopeysm)
8fec487 Delete a todo note (penelopeysm)
46354fb Changelog (penelopeysm)
c7ffada Temp point to AbstractMCMC feature branch (penelopeysm)
eacffdb Fix tests (penelopeysm)
9b30bea Don't remove the keyword argument (penelopeysm)
0ebf0e8 Fix import order (penelopeysm)
c464b0c Upstream `getstats` definitions to AdvancedMH/AdvancedHMC (penelopeysm)
e99c154 remove sources (penelopeysm)
d8dcd11 Bump test deps (probably pointless, but eh) (penelopeysm)
@@ -1,5 +1,5 @@
 """
-    ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
+    ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType}
 
 Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng,
 ::DynamicPPL.Model, spl)`.
@@ -14,45 +14,59 @@ $(TYPEDFIELDS)
 If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl
 models, there are two options:
 
-1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the
-   most powerful option and is what Turing.jl's in-house samplers do. Implementing this
-   means that you can directly call `sample(model, MySampler(), N)`.
+1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. That is to
+   say, implement `AbstractMCMC.step(rng::Random.AbstractRNG, model::DynamicPPL.Model,
+   sampler::MySampler; kwargs...)` and related methods. This is the most powerful option and
+   is what Turing.jl's in-house samplers do. Implementing this means that you can directly
+   call `sample(model, MySampler(), N)`.
 
-2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This
-   struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
+2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel` (the
+   same signature as above except that `model::AbstractMCMC.LogDensityModel`). This struct
+   wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
    implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use
   this with Turing.jl, you will need to wrap your sampler: `sample(model,
   externalsampler(MySampler()), N)`.

[Review comment on lines -17 to 28: I'll write proper docs about this separately (in the docs repo).]

 
 This section describes the latter.
 
-`MySampler` must implement the following methods:
+`MySampler` **must** implement the following methods:
 
 - `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is
-  documented in AbstractMCMC.jl)
-- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the
-  parameters from the transition returned by your sampler (i.e., the first return value of
-  `step`). There is a default implementation for this method, which is to return
-  `external_transition.θ`.
-
-!!! note
-    In a future breaking release of Turing, this is likely to change to
-    `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method.
-    `Turing.Inference.getparams` is technically an internal method, so the aim here is to
-    unify the interface for samplers at a higher level.
+  documented in AbstractMCMC.jl). This function must return a tuple of two elements, a
+  'transition' and a 'state'.
+
+- `AbstractMCMC.getparams(external_state)`: How to extract the parameters from the **state**
+  returned by your sampler (i.e., the **second** return value of `step`). For your sampler
+  to work with Turing.jl, this function should return a Vector of parameter values. Note that
+  this function does not need to perform any linking or unlinking; Turing.jl will take care of
+  this for you. You should return the parameters *exactly* as your sampler sees them.
+
+- `AbstractMCMC.getstats(external_state)`: Extract sampler statistics corresponding to this
+  iteration from the **state** returned by your sampler (i.e., the **second** return value
+  of `step`). For your sampler to work with Turing.jl, this function should return a
+  `NamedTuple`. If there are no statistics to return, return `NamedTuple()`.
+
+Note that `getstats` should not include log-probabilities, as these will be recalculated by
+Turing automatically for you.
+
+Notice that both of these functions take the **state** as input, not the **transition**. In
+other words, the transition is completely unused by the external sampler interface. This is
+in line with long-term plans to remove transitions from AbstractMCMC.jl and use only
+states.
 
 There are a few more optional functions which you can implement to improve the integration
 with Turing.jl:
 
-- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as
-  a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
-
-- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires
+- `AbstractMCMC.requires_unconstrained_space(::MySampler)`: If your sampler requires
   unconstrained space, you should return `true`. This tells Turing to perform linking on the
   VarInfo before evaluation, and ensures that the parameter values passed to your sampler
   will always be in unconstrained (Euclidean) space.
+
+- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want to disallow your sampler
+  from being used as a component in Turing's Gibbs sampler, you should make this evaluate to
+  `false`. Note that the default is `true`, so you should only need to implement this in
+  special cases.
 """
-struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
+struct ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType} <:
        AbstractSampler
     "the sampler to wrap"
     sampler::S
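The interface laid out in the docstring diff above can be sketched end-to-end. Everything below is a hypothetical illustration, not code from this PR: `MySampler` and `MyState` are invented names, the random-walk proposal is a stand-in for a real algorithm, and `AbstractMCMC.getparams`/`getstats`/`requires_unconstrained_space` are assumed to exist in the AbstractMCMC feature branch this PR temporarily points to.

```julia
using AbstractMCMC, LogDensityProblems, Random

# Hypothetical sampler and state types, for illustration only.
struct MySampler <: AbstractMCMC.AbstractSampler end

struct MyState
    params::Vector{Float64}
    logp::Float64
end

# Initial step: Turing passes `initial_params` through from `sample`.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler;
    initial_params,
    kwargs...,
)
    logp = LogDensityProblems.logdensity(model.logdensity, initial_params)
    state = MyState(initial_params, logp)
    # Must return a (transition, state) tuple; Turing only uses the state.
    return state, state
end

# Subsequent steps: a toy random-walk Metropolis proposal as a stand-in.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler,
    state::MyState;
    kwargs...,
)
    proposal = state.params .+ 0.1 .* randn(rng, length(state.params))
    logp = LogDensityProblems.logdensity(model.logdensity, proposal)
    new_state = log(rand(rng)) < logp - state.logp ? MyState(proposal, logp) : state
    return new_state, new_state
end

# The accessors this PR moves to AbstractMCMC: both read the *state*.
AbstractMCMC.getparams(state::MyState) = state.params
AbstractMCMC.getstats(::MyState) = NamedTuple()  # no extra stats; no log-probs

# Opt in to linking: parameter values arrive in unconstrained space.
AbstractMCMC.requires_unconstrained_space(::MySampler) = true
```

With these methods defined, the sampler would be usable as `sample(model, externalsampler(MySampler()), N)` per the docstring.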
@@ -67,47 +81,42 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
     # Arguments
     - `sampler::AbstractSampler`: The sampler to wrap.
     - `adtype::ADTypes.AbstractADType`: The automatic differentiation (AD) backend to use.
-    - `unconstrained::Val=Val{true}()`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
+    - `unconstrained::Val`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
     """
     function ExternalSampler(
-        sampler::AbstractSampler,
-        adtype::ADTypes.AbstractADType,
-        (::Val{unconstrained})=Val(true),
+        sampler::AbstractSampler, adtype::ADTypes.AbstractADType, ::Val{unconstrained}
     ) where {unconstrained}
         if !(unconstrained isa Bool)
             throw(
                 ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}")
             )
         end
-        return new{typeof(sampler),typeof(adtype),unconstrained}(sampler, adtype)
+        return new{unconstrained,typeof(sampler),typeof(adtype)}(sampler, adtype)
     end
 end
 
-"""
-    requires_unconstrained_space(sampler::ExternalSampler)
-
-Return `true` if the sampler requires unconstrained space, and `false` otherwise.
-"""
-function requires_unconstrained_space(
-    ::ExternalSampler{<:Any,<:Any,Unconstrained}
-) where {Unconstrained}
-    return Unconstrained
-end
-
 """
-    externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true)
+    externalsampler(
+        sampler::AbstractSampler;
+        adtype=AutoForwardDiff(),
+        unconstrained=AbstractMCMC.requires_unconstrained_space(sampler),
+    )
 
 Wrap a sampler so it can be used as an inference algorithm.
 
 # Arguments
 - `sampler::AbstractSampler`: The sampler to wrap.
 
 # Keyword Arguments
-- `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use.
-- `unconstrained::Bool=true`: Whether the sampler requires unconstrained space.
+- `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation
+  (AD) backend to use.
+- `unconstrained::Bool=AbstractMCMC.requires_unconstrained_space(sampler)`: Whether the
+  sampler requires unconstrained space.
 """
 function externalsampler(
-    sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true
+    sampler::AbstractSampler;
+    adtype=Turing.DEFAULT_ADTYPE,
+    unconstrained::Bool=AbstractMCMC.requires_unconstrained_space(sampler),
 )
     return ExternalSampler(sampler, adtype, Val(unconstrained))
 end
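With the new default, a wrapped sampler that declares `requires_unconstrained_space` no longer needs the keyword spelled out. A hedged usage sketch (assuming an AdvancedMH version that implements the new AbstractMCMC accessors, per the `c464b0c` commit above; the model is an arbitrary example):

```julia
using Turing
using AdvancedMH  # assumed here to implement the new AbstractMCMC interface

@model function demo()
    x ~ Normal()
end

# `unconstrained` now defaults to AbstractMCMC.requires_unconstrained_space(sampler)
# instead of `true`; it can still be overridden explicitly if needed.
spl = externalsampler(AdvancedMH.RWMH(Normal(0, 1)))
chain = sample(demo(), spl, 100)
```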
@@ -128,30 +137,21 @@ end
 get_varinfo(state::TuringState) = state.varinfo
 get_varinfo(state::AbstractVarInfo) = state
 
-getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
-function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
-    return getparams(model, state.transition)
-end
-getstats(transition::AdvancedHMC.Transition) = transition.stat
-
-getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
-
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    sampler_wrapper::ExternalSampler;
+    sampler_wrapper::ExternalSampler{unconstrained};
     initial_state=nothing,
     initial_params, # passed through from sample
     kwargs...,
-)
+) where {unconstrained}
     sampler = sampler_wrapper.sampler
 
     # Initialise varinfo with initial params and link the varinfo if needed.
     varinfo = DynamicPPL.VarInfo(model)
     _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)
 
-    if requires_unconstrained_space(sampler_wrapper)
+    if unconstrained
         varinfo = DynamicPPL.link(varinfo, model)
     end
@@ -166,16 +166,17 @@ function AbstractMCMC.step(
     )
 
     # Then just call `AbstractMCMC.step` with the right arguments.
-    if initial_state === nothing
-        transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = if initial_state === nothing
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler;
             initial_params=initial_params_vector,
             kwargs...,
         )
     else
-        transition_inner, state_inner = AbstractMCMC.step(
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler,

@@ -185,13 +186,12 @@
         )
     end
 
-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
@@ -206,16 +206,15 @@ function AbstractMCMC.step(
     f = state.ldf
 
     # Then just call `AdvancedMCMC.step` with the right arguments.
-    transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = AbstractMCMC.step(
         rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
     )
 
-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
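As an aside, the `_, state_inner = if ... end` pattern introduced in both `step` methods relies on `if` being an expression in Julia, so the whole conditional can be destructured once and the unwanted transition discarded. A minimal standalone illustration (toy values, not from this PR):

```julia
# An `if` block is an expression: both branches return a (transition, state)
# tuple, and destructuring the whole expression discards the transition once.
coin = true
_, state = if coin
    ("transition-a", 1)
else
    ("transition-b", 2)
end
@assert state == 1
```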
[Review comment: Moved this type parameter earlier so that we can dispatch on it more easily.]
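The benefit that comment describes can be sketched as follows (hypothetical helper, not code from this PR): with `Unconstrained` as the first type parameter, a method can match on it without naming the sampler and AD-type parameters.

```julia
# Dispatch on the leading Unconstrained parameter only; the remaining
# parameters (sampler type, AD type) are left free.
needs_linking(::ExternalSampler{true}) = true
needs_linking(::ExternalSampler{false}) = false
```

Under the old ordering `ExternalSampler{S,AD,Unconstrained}`, the same dispatch would require the more awkward `::ExternalSampler{<:Any,<:Any,true}`.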