Skip to content

Commit 1e972aa

Browse files
committed
redo of TuringLang#41
1 parent 35d310e commit 1e972aa

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

src/mh-core.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,41 @@ function AbstractMCMC.step(
191191
return transition, transition
192192
end
193193

194+
"""
195+
is_symmetric_proposal(proposal::P) where P
196+
197+
# Example:
198+
199+
```julia
200+
using Distributions, AdvancedMH
201+
202+
# Model definition.
203+
model = DensityModel(s -> logpdf(Normal(), s.x) + logpdf(Normal(5,.7), s.y))
204+
205+
# Set up the proposal.
206+
p = (x=RandomWalkProposal(Normal(0,.5)), y=RandomWalkProposal(Normal(0,.5)))
207+
208+
# Implementing this will skip the computation of the Hastings ratio.
209+
AdvancedMH.is_symmetric_proposal(proposal::typeof(p)) = true
210+
211+
# Sample from the posterior with initial parameters.
212+
chain = sample(m1, MetropolisHastings(p), 100000; chain_type=Vector{NamedTuple})
213+
```
214+
"""
215+
is_symmetric_proposal(proposal::P) where P = false
216+
217+
# The following univariate random walk proposals are symmetric.
218+
is_symmetric_proposal(proposal::RandomWalkProposal{<:Normal}) = true
219+
is_symmetric_proposal(proposal::RandomWalkProposal{<:MvNormal}) = true
220+
is_symmetric_proposal(proposal::RandomWalkProposal{<:TDist}) = true
221+
is_symmetric_proposal(proposal::RandomWalkProposal{<:Cauchy}) = true
222+
223+
# The following multivariate random walk proposals are symmetric.
224+
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:Normal}}) = true
225+
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:MvNormal}}) = true
226+
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:TDist}}) = true
227+
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:Cauchy}}) = true
228+
194229
# Define the other sampling steps.
195230
# Return a 2-tuple consisting of the next sample and the the next state.
196231
# In this case they are identical, and either a new proposal (if accepted)
@@ -206,8 +241,10 @@ function AbstractMCMC.step(
206241
params = propose(rng, spl, model, params_prev)
207242

208243
# Calculate the log acceptance probability.
209-
logα = logdensity(model, params) - logdensity(model, params_prev) +
210-
q(spl, params_prev, params) - q(spl, params, params_prev)
244+
logα = logdensity(model, params) - logdensity(model, params_prev)
245+
if is_symmetric_proposal(spl.proposal)
246+
logα += q(spl, params_prev, params) - q(spl, params, params_prev)
247+
end
211248

212249
# Decide whether to return the previous params or the new one.
213250
if -Random.randexp(rng) < logα

test/runtests.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,24 @@ using ForwardDiff
104104

105105
@test chain1[1].params == val
106106
end
107-
107+
108+
@testset "is_symmetric_proposal" begin
109+
# Model definition.
110+
m1 = DensityModel(s -> logpdf(Normal(), s.x) + logpdf(Normal(5,.7), s.y))
111+
112+
# Set up the proposal.
113+
p1 = (x=RandomWalkProposal(Normal(0,.5)), y=RandomWalkProposal(Normal(0,.5)))
114+
AdvancedMH.is_symmetric_proposal(proposal::typeof(p1)) = true
115+
116+
# Sample from the posterior with initial parameters.
117+
chain1 = sample(m1, MetropolisHastings(p1), 100000; chain_type=Vector{NamedTuple})
118+
119+
@test mean(getindex.(chain1, :x)) 0 atol=0.05
120+
@test mean(getindex.(chain1, :y)) 5 atol=0.05
121+
@test std(getindex.(chain1, :x)) 1 atol=0.05
122+
@test std(getindex.(chain1, :y)) .7 atol=0.05
123+
end
124+
108125
@testset "MALA" begin
109126

110127
# Set up the sampler.

0 commit comments

Comments
 (0)