Commit 06752c4

Decouple external sampler interface from Turing
1 parent be007f3 commit 06752c4

File tree

5 files changed: +91 −84 lines

HISTORY.md

Lines changed: 16 additions & 0 deletions

@@ -1,5 +1,21 @@
 # 0.42.0

+## External sampler interface
+
+The interface for defining an external sampler has been reworked.
+In general, implementations of external samplers should no longer need to depend on Turing,
+because the required interface functions have been moved upstream to AbstractMCMC.jl.
+
+In particular, you now only need to define the following functions:
+
+- `AbstractMCMC.step(rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, ::MySampler; kwargs...)` (plus a method taking a `state` argument, and the corresponding `step_warmup` methods if needed)
+- `AbstractMCMC.getparams(::MySamplerState) -> Vector{<:Real}`
+- `AbstractMCMC.getstats(::MySamplerState) -> NamedTuple`
+- `AbstractMCMC.requires_unconstrained_space(::MySampler) -> Bool` (defaults to `true`)
+
+This means that you only need to depend on AbstractMCMC.jl.
+As long as the above functions are defined correctly, Turing will be able to use your external sampler.
+
 # 0.41.0

 ## DynamicPPL 0.38
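
The four functions listed in the changelog entry above can be summarised as a minimal conforming sampler. This is an illustrative sketch rather than code from the commit: `MySampler`, `MyState`, and the step logic are hypothetical, and it assumes AbstractMCMC ≥ 5.9 for `getparams`/`getstats`/`requires_unconstrained_space`.

```julia
using AbstractMCMC, LogDensityProblems, Random

struct MySampler <: AbstractMCMC.AbstractSampler end
struct MyState
    params::Vector{Float64}
end

# Initial step (no state argument).
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler;
    initial_params=nothing,
    kwargs...,
)
    d = LogDensityProblems.dimension(model.logdensity)
    params = initial_params === nothing ? randn(rng, d) : initial_params
    # The transition (first return value) is ignored by Turing's wrapper.
    return nothing, MyState(params)
end

# Subsequent steps (with state argument).
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler,
    state::MyState;
    kwargs...,
)
    # A real sampler would propose and accept/reject here.
    return nothing, MyState(state.params)
end

# Turing reads parameters and statistics off the *state*.
AbstractMCMC.getparams(s::MyState) = s.params
AbstractMCMC.getstats(::MyState) = NamedTuple()
AbstractMCMC.requires_unconstrained_space(::MySampler) = true
```

With these definitions in place, `sample(model, externalsampler(MySampler()), N)` should work without your package depending on Turing.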

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ TuringOptimExt = ["Optim", "AbstractPPL"]

 [compat]
 ADTypes = "1.9"
-AbstractMCMC = "5.5"
+AbstractMCMC = "5.9"
 AbstractPPL = "0.11, 0.12, 0.13"
 Accessors = "0.1"
 AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"

src/mcmc/external_sampler.jl

Lines changed: 67 additions & 68 deletions

@@ -1,5 +1,5 @@
 """
-    ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
+    ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType}

 Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng,
 ::DynamicPPL.Model, spl)`.
@@ -14,45 +14,59 @@ $(TYPEDFIELDS)
 If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl
 models, there are two options:

-1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the
-   most powerful option and is what Turing.jl's in-house samplers do. Implementing this
-   means that you can directly call `sample(model, MySampler(), N)`.
+1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. That is to
+   say, implement `AbstractMCMC.step(rng::Random.AbstractRNG, model::DynamicPPL.Model,
+   sampler::MySampler; kwargs...)` and related methods. This is the most powerful option and
+   is what Turing.jl's in-house samplers do. Implementing this means that you can directly
+   call `sample(model, MySampler(), N)`.

-2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This
-   struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
+2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel` (the
+   same signature as above except that `model::AbstractMCMC.LogDensityModel`). This struct
+   wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
    implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use
    this with Turing.jl, you will need to wrap your sampler: `sample(model,
    externalsampler(MySampler()), N)`.

 This section describes the latter.

-`MySampler` must implement the following methods:
+`MySampler` **must** implement the following methods:

 - `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is
-  documented in AbstractMCMC.jl)
-- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the
-  parameters from the transition returned by your sampler (i.e., the first return value of
-  `step`). There is a default implementation for this method, which is to return
-  `external_transition.θ`.
-
-  !!! note
-      In a future breaking release of Turing, this is likely to change to
-      `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method.
-      `Turing.Inference.getparams` is technically an internal method, so the aim here is to
-      unify the interface for samplers at a higher level.
+  documented in AbstractMCMC.jl). This function must return a tuple of two elements, a
+  'transition' and a 'state'.
+
+- `AbstractMCMC.getparams(external_state)`: How to extract the parameters from the **state**
+  returned by your sampler (i.e., the **second** return value of `step`). For your sampler
+  to work with Turing.jl, this function should return a Vector of parameter values. Note that
+  this function does not need to perform any linking or unlinking; Turing.jl will take care of
+  this for you. You should return the parameters *exactly* as your sampler sees them.
+
+- `AbstractMCMC.getstats(external_state)`: Extract sampler statistics corresponding to this
+  iteration from the **state** returned by your sampler (i.e., the **second** return value
+  of `step`). For your sampler to work with Turing.jl, this function should return a
+  `NamedTuple`. If there are no statistics to return, return `NamedTuple()`.
+
+  Note that `getstats` should not include log-probabilities, as these will be recalculated
+  by Turing automatically for you.
+
+Notice that both of these functions take the **state** as input, not the **transition**. In
+other words, the transition is not used at all by the external sampler interface. This is
+in line with long-term plans for removing transitions from AbstractMCMC.jl and only using
+states.

 There are a few more optional functions which you can implement to improve the integration
 with Turing.jl:

-- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as
-  a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
-
-- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires
+- `AbstractMCMC.requires_unconstrained_space(::MySampler)`: If your sampler requires
   unconstrained space, you should return `true`. This tells Turing to perform linking on the
   VarInfo before evaluation, and ensures that the parameter values passed to your sampler
   will always be in unconstrained (Euclidean) space.
+
+- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want to prevent your sampler from
+  being used as a component in Turing's Gibbs sampler, you should make this evaluate to
+  `false`. Note that the default is `true`, so you should only need to implement this in
+  special cases.
 """
-struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
+struct ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType} <:
        AbstractSampler
     "the sampler to wrap"
     sampler::S
@@ -67,33 +81,20 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
     # Arguments
     - `sampler::AbstractSampler`: The sampler to wrap.
     - `adtype::ADTypes.AbstractADType`: The automatic differentiation (AD) backend to use.
-    - `unconstrained::Val=Val{true}()`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
+    - `unconstrained::Val`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
     """
     function ExternalSampler(
-        sampler::AbstractSampler,
-        adtype::ADTypes.AbstractADType,
-        (::Val{unconstrained})=Val(true),
+        sampler::AbstractSampler, adtype::ADTypes.AbstractADType, ::Val{unconstrained}
     ) where {unconstrained}
         if !(unconstrained isa Bool)
             throw(
                 ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}")
             )
         end
-        return new{typeof(sampler),typeof(adtype),unconstrained}(sampler, adtype)
+        return new{unconstrained,typeof(sampler),typeof(adtype)}(sampler, adtype)
     end
 end

-"""
-    requires_unconstrained_space(sampler::ExternalSampler)
-
-Return `true` if the sampler requires unconstrained space, and `false` otherwise.
-"""
-function requires_unconstrained_space(
-    ::ExternalSampler{<:Any,<:Any,Unconstrained}
-) where {Unconstrained}
-    return Unconstrained
-end
-
 """
     externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true)
@@ -106,10 +107,10 @@ Wrap a sampler so it can be used as an inference algorithm.
 - `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use.
 - `unconstrained::Bool=true`: Whether the sampler requires unconstrained space.
 """
-function externalsampler(
-    sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true
-)
-    return ExternalSampler(sampler, adtype, Val(unconstrained))
+function externalsampler(sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE)
+    return ExternalSampler(
+        sampler, adtype, Val(AbstractMCMC.requires_unconstrained_space(sampler))
+    )
 end

 # TODO(penelopeysm): Can't we clean this up somehow?
@@ -128,30 +129,22 @@ end
 get_varinfo(state::TuringState) = state.varinfo
 get_varinfo(state::AbstractVarInfo) = state

-getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
-function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
-    return getparams(model, state.transition)
-end
-getstats(transition::AdvancedHMC.Transition) = transition.stat
-
-getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
-
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    sampler_wrapper::ExternalSampler;
+    sampler_wrapper::ExternalSampler{unconstrained};
     initial_state=nothing,
     initial_params, # passed through from sample
     kwargs...,
-)
+) where {unconstrained}
     sampler = sampler_wrapper.sampler

     # Initialise varinfo with initial params and link the varinfo if needed.
     varinfo = DynamicPPL.VarInfo(model)
     _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)

-    if requires_unconstrained_space(sampler_wrapper)
+    if unconstrained
         varinfo = DynamicPPL.link(varinfo, model)
     end
@@ -166,16 +159,17 @@ function AbstractMCMC.step(
     )

     # Then just call `AbstractMCMC.step` with the right arguments.
-    if initial_state === nothing
-        transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = if initial_state === nothing
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler;
             initial_params=initial_params_vector,
             kwargs...,
         )
     else
-        transition_inner, state_inner = AbstractMCMC.step(
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler,
@@ -185,13 +179,12 @@ function AbstractMCMC.step(
         )
     end

-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
@@ -206,16 +199,22 @@ function AbstractMCMC.step(
     f = state.ldf

     # Then just call `AbstractMCMC.step` with the right arguments.
-    transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = AbstractMCMC.step(
         rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
     )

-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
+
+# Implementation of the interface for AdvancedMH and AdvancedHMC. TODO: These should be
+# upstreamed to the respective packages; they are only here to avoid having to run CI
+# against three separate PR branches.
+AbstractMCMC.getstats(state::AdvancedHMC.HMCState) = state.transition.stat
+# Note that for AdvancedMH, transition and state are equivalent (and both named Transition)
+AbstractMCMC.getstats(state::AdvancedMH.Transition) = (accepted=state.accepted,)
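
Given the reworked `externalsampler` above, call sites no longer pass `unconstrained` manually; it is read from the sampler via `AbstractMCMC.requires_unconstrained_space`. A hedged usage sketch, not taken from the commit (the `demo` model and the `RWMH` proposal are illustrative, assuming AdvancedMH and Distributions are installed):

```julia
using Turing, AdvancedMH
using Distributions, LinearAlgebra

@model function demo()
    return x ~ Normal()
end

# Random-walk Metropolis-Hastings from AdvancedMH, wrapped for use with Turing.
# Whether linking is needed is now determined from the sampler itself.
spl = externalsampler(AdvancedMH.RWMH(MvNormal(zeros(1), I)))
chain = sample(demo(), spl, 100)
```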

src/mcmc/gibbs.jl

Lines changed: 2 additions & 6 deletions

@@ -3,9 +3,9 @@

 Return a boolean indicating whether `spl` is a valid component for a Gibbs sampler.

-Defaults to `false` if no method has been defined for a particular algorithm type.
+Defaults to `true` if no method has been defined for a particular sampler.
 """
-isgibbscomponent(::AbstractSampler) = false
+isgibbscomponent(::AbstractSampler) = true

 isgibbscomponent(::ESS) = true
 isgibbscomponent(::HMC) = true
@@ -15,11 +15,7 @@ isgibbscomponent(::MH) = true
 isgibbscomponent(::PG) = true

 isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler)
-
 isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
-isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
-isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
-isgibbscomponent(spl) = false

 function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
     return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
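
Because the default for `isgibbscomponent` is flipped to `true` in the diff above, a sampler that cannot run inside Gibbs now has to opt out explicitly. A hypothetical sketch (`MyGlobalOnlySampler` is illustrative, not part of the commit):

```julia
using Turing, AbstractMCMC

struct MyGlobalOnlySampler <: AbstractMCMC.AbstractSampler end

# Opt out of Gibbs: the default is now `true`, so only samplers that genuinely
# cannot act as a Gibbs component need to define this method.
Turing.Inference.isgibbscomponent(::MyGlobalOnlySampler) = false
```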

test/mcmc/external_sampler.jl

Lines changed: 5 additions & 9 deletions

@@ -20,16 +20,11 @@ using Turing.Inference: AdvancedHMC
 # Turing declares an interface for external samplers (see docstring for
 # ExternalSampler). We should check that implementing this interface
 # and only this interface allows us to use the sampler in Turing.
-struct MyTransition{V<:AbstractVector}
-    params::V
-end
-# Samplers need to implement `Turing.Inference.getparams`.
-Turing.Inference.getparams(::DynamicPPL.Model, t::MyTransition) = t.params
-# State doesn't matter (but we need to carry the params through to the next
-# iteration).
 struct MyState{V<:AbstractVector}
     params::V
 end
+AbstractMCMC.getparams(s::MyState) = s.params
+AbstractMCMC.getstats(s::MyState) = (param_length=length(s.params),)

 # externalsamplers must accept LogDensityModel inside their step function.
 # By default Turing gives the externalsampler a LDF constructed with
@@ -58,7 +53,7 @@ using Turing.Inference: AdvancedHMC
     lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, initial_params)
     @test lp isa Real
     @test grad isa AbstractVector{<:Real}
-    return MyTransition(initial_params), MyState(initial_params)
+    return nothing, MyState(initial_params)
 end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
@@ -75,7 +70,7 @@ using Turing.Inference: AdvancedHMC
     lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
     @test lp isa Real
     @test grad isa AbstractVector{<:Real}
-    return MyTransition(params), MyState(params)
+    return nothing, MyState(params)
 end

 @model function test_external_sampler()
@@ -96,6 +91,7 @@ using Turing.Inference: AdvancedHMC
     @test all(chn[:lp] .== expected_logpdf)
     @test all(chn[:logprior] .== expected_logpdf)
     @test all(chn[:loglikelihood] .== 0.0)
+    @test all(chn[:param_length] .== 2)
 end

 function initialize_nuts(model::DynamicPPL.Model)

0 commit comments