Skip to content

Commit

Permalink
redo of TuringLang#41
Browse files Browse the repository at this point in the history
  • Loading branch information
luiarthur committed Dec 8, 2020
1 parent 35d310e commit 1e972aa
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
41 changes: 39 additions & 2 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,41 @@ function AbstractMCMC.step(
return transition, transition
end

"""
is_symmetric_proposal(proposal::P) where P
# Example:
```julia
using Distributions, AdvancedMH
# Model definition.
model = DensityModel(s -> logpdf(Normal(), s.x) + logpdf(Normal(5,.7), s.y))
# Set up the proposal.
p = (x=RandomWalkProposal(Normal(0,.5)), y=RandomWalkProposal(Normal(0,.5)))
# Implementing this will skip the computation of the Hastings ratio.
AdvancedMH.is_symmetric_proposal(proposal::typeof(p)) = true
# Sample from the posterior with initial parameters.
chain = sample(m1, MetropolisHastings(p), 100000; chain_type=Vector{NamedTuple})
```
"""
is_symmetric_proposal(proposal::P) where P = false

# The following univariate random walk proposals are symmetric.
is_symmetric_proposal(proposal::RandomWalkProposal{<:Normal}) = true
is_symmetric_proposal(proposal::RandomWalkProposal{<:MvNormal}) = true
is_symmetric_proposal(proposal::RandomWalkProposal{<:TDist}) = true
is_symmetric_proposal(proposal::RandomWalkProposal{<:Cauchy}) = true

# The following multivariate random walk proposals are symmetric.
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:Normal}}) = true
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:MvNormal}}) = true
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:TDist}}) = true
is_symmetric_proposal(proposal::RandomWalkProposal{AbstractArray{<:Cauchy}}) = true

# Define the other sampling steps.
# Return a 2-tuple consisting of the next sample and the the next state.
# In this case they are identical, and either a new proposal (if accepted)
Expand All @@ -206,8 +241,10 @@ function AbstractMCMC.step(
params = propose(rng, spl, model, params_prev)

# Calculate the log acceptance probability.
logα = logdensity(model, params) - logdensity(model, params_prev) +
q(spl, params_prev, params) - q(spl, params, params_prev)
logα = logdensity(model, params) - logdensity(model, params_prev)
if is_symmetric_proposal(spl.proposal)
logα += q(spl, params_prev, params) - q(spl, params, params_prev)
end

# Decide whether to return the previous params or the new one.
if -Random.randexp(rng) < logα
Expand Down
19 changes: 18 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,24 @@ using ForwardDiff

@test chain1[1].params == val
end


@testset "is_symmetric_proposal" begin
# Model definition.
m1 = DensityModel(s -> logpdf(Normal(), s.x) + logpdf(Normal(5,.7), s.y))

# Set up the proposal.
p1 = (x=RandomWalkProposal(Normal(0,.5)), y=RandomWalkProposal(Normal(0,.5)))
AdvancedMH.is_symmetric_proposal(proposal::typeof(p1)) = true

# Sample from the posterior with initial parameters.
chain1 = sample(m1, MetropolisHastings(p1), 100000; chain_type=Vector{NamedTuple})

@test mean(getindex.(chain1, :x)) 0 atol=0.05
@test mean(getindex.(chain1, :y)) 5 atol=0.05
@test std(getindex.(chain1, :x)) 1 atol=0.05
@test std(getindex.(chain1, :y)) .7 atol=0.05
end

@testset "MALA" begin

# Set up the sampler.
Expand Down

0 comments on commit 1e972aa

Please sign in to comment.