 """
-    ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
+    ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType}

 Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng,
 ::DynamicPPL.Model, spl)`.
@@ -14,45 +14,59 @@ $(TYPEDFIELDS)
 If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl
 models, there are two options:

-1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the
-   most powerful option and is what Turing.jl's in-house samplers do. Implementing this
-   means that you can directly call `sample(model, MySampler(), N)`.
+1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. That is to
+   say, implement `AbstractMCMC.step(rng::Random.AbstractRNG, model::DynamicPPL.Model,
+   sampler::MySampler; kwargs...)` and related methods. This is the most powerful option and
+   is what Turing.jl's in-house samplers do. Implementing this means that you can directly
+   call `sample(model, MySampler(), N)`.

-2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This
-   struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
+2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel` (the
+   same signature as above, except that `model::AbstractMCMC.LogDensityModel`). This struct
+   wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
    implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use
    this with Turing.jl, you will need to wrap your sampler: `sample(model,
    externalsampler(MySampler()), N)`.

 This section describes the latter.

-`MySampler` must implement the following methods:
+`MySampler` **must** implement the following methods:

 - `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is
-  documented in AbstractMCMC.jl)
-- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the
-  parameters from the transition returned by your sampler (i.e., the first return value of
-  `step`). There is a default implementation for this method, which is to return
-  `external_transition.θ`.
-
-  !!! note
-      In a future breaking release of Turing, this is likely to change to
-      `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method.
-      `Turing.Inference.getparams` is technically an internal method, so the aim here is to
-      unify the interface for samplers at a higher level.
+  documented in AbstractMCMC.jl). This function must return a tuple of two elements, a
+  'transition' and a 'state'.
+
+- `AbstractMCMC.getparams(external_state)`: How to extract the parameters from the **state**
+  returned by your sampler (i.e., the **second** return value of `step`). For your sampler
+  to work with Turing.jl, this function should return a `Vector` of parameter values. Note
+  that this function does not need to perform any linking or unlinking; Turing.jl will take
+  care of this for you. You should return the parameters *exactly* as your sampler sees them.
+
+- `AbstractMCMC.getstats(external_state)`: How to extract the sampler statistics for the
+  current iteration from the **state** returned by your sampler (i.e., the **second** return
+  value of `step`). For your sampler to work with Turing.jl, this function should return a
+  `NamedTuple`. If there are no statistics to return, return `NamedTuple()`.
+
+  Note that `getstats` should not include log-probabilities, as these will be recalculated
+  automatically by Turing for you.
+
+Notice that both of these functions take the **state** as input, not the **transition**. In
+other words, the transition is not used at all by the external sampler interface. This is
+in line with long-term plans to remove transitions from AbstractMCMC.jl and use only
+states.

 There are a few more optional functions which you can implement to improve the integration
 with Turing.jl:

-- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as
-  a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
-
-- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires
+- `AbstractMCMC.requires_unconstrained_space(::MySampler)`: If your sampler requires
   unconstrained space, you should return `true`. This tells Turing to perform linking on the
   VarInfo before evaluation, and ensures that the parameter values passed to your sampler
   will always be in unconstrained (Euclidean) space.
+
+- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want to prevent your sampler
+  from being used as a component in Turing's Gibbs sampler, you should make this evaluate
+  to `false`. Note that the default is `true`, so you should only need to implement this in
+  special cases.
 """
-struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
+struct ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType} <:
        AbstractSampler
     "the sampler to wrap"
     sampler::S
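As a concrete illustration of the interface documented above, here is a sketch of a toy random-walk Metropolis sampler implementing `step`, `getparams`, and `getstats`. All names (`ToyRWMH`, `ToyState`) and the fixed step size are hypothetical, not part of Turing.jl or AbstractMCMC.jl; this is a minimal sketch, not a production sampler.

```julia
# Hypothetical sketch of the external sampler interface; names are illustrative.
using AbstractMCMC, LogDensityProblems, Random

struct ToyRWMH <: AbstractMCMC.AbstractSampler end

struct ToyState
    params::Vector{Float64}
    accepted::Bool
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::ToyRWMH,
    state::Union{Nothing,ToyState}=nothing;  # `nothing` on the first step
    initial_params=nothing,
    kwargs...,
)
    ℓ = model.logdensity
    d = LogDensityProblems.dimension(ℓ)
    current = state === nothing ? something(initial_params, zeros(d)) : state.params
    proposal = current .+ 0.1 .* randn(rng, d)  # symmetric proposal, so plain MH ratio
    logratio =
        LogDensityProblems.logdensity(ℓ, proposal) - LogDensityProblems.logdensity(ℓ, current)
    accepted = log(rand(rng)) < logratio
    new_state = ToyState(accepted ? copy(proposal) : copy(current), accepted)
    # `step` must return (transition, state); Turing only inspects the state.
    return new_state, new_state
end

# The two accessors Turing.jl needs, both taking the state:
AbstractMCMC.getparams(state::ToyState) = state.params
AbstractMCMC.getstats(state::ToyState) = (accepted=state.accepted,)
```

Note that the sampler returns its parameters exactly as it sees them; any linking to unconstrained space is handled by Turing.jl outside this code.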
@@ -67,33 +81,20 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
     # Arguments
     - `sampler::AbstractSampler`: The sampler to wrap.
     - `adtype::ADTypes.AbstractADType`: The automatic differentiation (AD) backend to use.
-    - `unconstrained::Val=Val{true}()`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
+    - `unconstrained::Val`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
     """
     function ExternalSampler(
-        sampler::AbstractSampler,
-        adtype::ADTypes.AbstractADType,
-        (::Val{unconstrained})=Val(true),
+        sampler::AbstractSampler, adtype::ADTypes.AbstractADType, ::Val{unconstrained}
     ) where {unconstrained}
         if !(unconstrained isa Bool)
             throw(
                 ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}")
             )
         end
-        return new{typeof(sampler),typeof(adtype),unconstrained}(sampler, adtype)
+        return new{unconstrained,typeof(sampler),typeof(adtype)}(sampler, adtype)
     end
 end

-"""
-    requires_unconstrained_space(sampler::ExternalSampler)
-
-Return `true` if the sampler requires unconstrained space, and `false` otherwise.
-"""
-function requires_unconstrained_space(
-    ::ExternalSampler{<:Any,<:Any,Unconstrained}
-) where {Unconstrained}
-    return Unconstrained
-end
-
 """
-    externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true)
+    externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff())

@@ -106,10 +107,10 @@ Wrap a sampler so it can be used as an inference algorithm.
 - `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use.
-- `unconstrained::Bool=true`: Whether the sampler requires unconstrained space.
 """
-function externalsampler(
-    sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true
-)
-    return ExternalSampler(sampler, adtype, Val(unconstrained))
+function externalsampler(sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE)
+    return ExternalSampler(
+        sampler, adtype, Val(AbstractMCMC.requires_unconstrained_space(sampler))
+    )
 end

 # TODO(penelopeysm): Can't we clean this up somehow?
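To make the wrapper's use concrete, here is a hedged usage sketch. The model is illustrative, and it assumes `AbstractMCMC.requires_unconstrained_space` has a sensible fallback (or method) for AdvancedMH samplers; the actual behaviour depends on the package versions in use.

```julia
# Hypothetical usage sketch; model and sampler choices are illustrative.
using Turing
using AdvancedMH, Distributions, LinearAlgebra

@model function demo()
    s ~ Exponential(1)        # constrained; Turing handles any linking required
    return 1.5 ~ Normal(0, s)
end

# Wrap a random-walk MH sampler from AdvancedMH so Turing can drive it.
rwmh = AdvancedMH.RWMH(MvNormal(zeros(1), I))
chain = sample(demo(), externalsampler(rwmh), 500)
```

With the signature above, whether linking is performed is decided by the sampler itself via `requires_unconstrained_space`, rather than by a keyword at the call site.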
@@ -128,30 +129,22 @@ end
 get_varinfo(state::TuringState) = state.varinfo
 get_varinfo(state::AbstractVarInfo) = state

-getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
-function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
-    return getparams(model, state.transition)
-end
-getstats(transition::AdvancedHMC.Transition) = transition.stat
-
-getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
-
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    sampler_wrapper::ExternalSampler;
+    sampler_wrapper::ExternalSampler{unconstrained};
     initial_state=nothing,
     initial_params, # passed through from sample
     kwargs...,
-)
+) where {unconstrained}
     sampler = sampler_wrapper.sampler

     # Initialise varinfo with initial params and link the varinfo if needed.
     varinfo = DynamicPPL.VarInfo(model)
     _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)

-    if requires_unconstrained_space(sampler_wrapper)
+    if unconstrained
         varinfo = DynamicPPL.link(varinfo, model)
     end

@@ -166,16 +159,17 @@ function AbstractMCMC.step(
     )

     # Then just call `AbstractMCMC.step` with the right arguments.
-    if initial_state === nothing
-        transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = if initial_state === nothing
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler;
             initial_params=initial_params_vector,
             kwargs...,
         )
     else
-        transition_inner, state_inner = AbstractMCMC.step(
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler,
@@ -185,13 +179,12 @@ function AbstractMCMC.step(
         )
     end

-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end

@@ -206,16 +199,22 @@ function AbstractMCMC.step(
     f = state.ldf

     # Then just call `AbstractMCMC.step` with the right arguments.
-    transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = AbstractMCMC.step(
         rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
     )

-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
+
+# Implementation of the interface for AdvancedMH and AdvancedHMC. TODO: These should be
+# upstreamed to the respective packages; it is not done here to avoid having to run
+# CI against three separate PR branches.
+AbstractMCMC.getstats(state::AdvancedHMC.HMCState) = state.transition.stat
+# Note that for AdvancedMH, transition and state are equivalent (and both named `Transition`).
+AbstractMCMC.getstats(state::AdvancedMH.Transition) = (accepted=state.accepted,)
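The `getparams` definitions removed earlier in this diff suggest what the matching upstream methods would look like; the following is a sketch based on the accessors those old definitions used (`transition.z.θ` for AdvancedHMC, `transition.params` for AdvancedMH), assumed rather than confirmed to match the eventual upstreamed code.

```julia
# Sketch of the matching getparams methods, mirroring the removed
# Turing.Inference.getparams definitions; illustrative, not the upstream source.
AbstractMCMC.getparams(state::AdvancedHMC.HMCState) = state.transition.z.θ
AbstractMCMC.getparams(state::AdvancedMH.Transition) = state.params
```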