Implementation of Robust Adaptive Metropolis #106

Merged
merged 34 commits into from
Dec 11, 2024
Changes from 26 commits
85ec534
added an initial implementation of `RAM`
torfjelde Dec 4, 2024
de519a4
added proper docs for RAM
torfjelde Dec 4, 2024
40ebb7e
fixed doctest for `RAM` + added impls of `getparams` and `setparams!!`
torfjelde Dec 4, 2024
2dec18a
added DocStringExtensions as a dep
torfjelde Dec 4, 2024
045f8c5
bump patch version
torfjelde Dec 4, 2024
755a180
attempt at making the dcotest a bit more consistent
torfjelde Dec 4, 2024
5c1c6f5
a
torfjelde Dec 4, 2024
cddf8d1
added checks for eigenvalues according to p. 13 in Vihola (2012) (in
torfjelde Dec 4, 2024
29c9078
fixed default value for `eigenvalue_lower_bound`
torfjelde Dec 5, 2024
78a5f51
applied suggestions from @mhauru
torfjelde Dec 6, 2024
652a227
more doctesting of RAM + improved docstrings
torfjelde Dec 6, 2024
5eaff52
added docstring for `RAMState`
torfjelde Dec 6, 2024
d8688fa
added proper testing of RAM
torfjelde Dec 6, 2024
f5fc301
Update src/RobustAdaptiveMetropolis.jl
torfjelde Dec 6, 2024
56ec717
added compat entries to docs
torfjelde Dec 6, 2024
da431b4
apply suggestions from @devmotion
torfjelde Dec 6, 2024
f2889a0
Merge remote-tracking branch 'origin/torfjelde/RAM' into torfjelde/RAM
torfjelde Dec 6, 2024
9247281
renamed `RAM` to `RobostMetropolisHastings` + removed the separate mo…
Dec 10, 2024
4764120
formatting
Dec 10, 2024
11f3b64
made the docstring for RAM a bit nicer
Dec 10, 2024
df4feb1
fixed doctest
Dec 10, 2024
f784492
formatting
Dec 10, 2024
45820d2
minor improvement to docstring of RAM
Dec 10, 2024
7405a19
fused scalar operations
Dec 10, 2024
5dce265
added dimensionality check of the provided `S` matrix
Dec 10, 2024
5ee44e3
fixed typo
Dec 10, 2024
37a2189
Update docs/src/api.md
torfjelde Dec 10, 2024
5193119
use `randn` instead of `rand` for initialisation
Dec 10, 2024
d4a144e
added an explanation of the `min`
Dec 10, 2024
6295e78
Update test/RobustAdaptiveMetropolis.jl
torfjelde Dec 10, 2024
6f8fda4
use explicit `Cholesky` constructor for backwards compat
Dec 10, 2024
5815a9b
Fix typo: ```` -> ```
mhauru Dec 10, 2024
1b38ca6
formatted according to `blue`
Dec 10, 2024
f426d0d
Update src/RobustAdaptiveMetropolis.jl
torfjelde Dec 10, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,10 +1,11 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
-version = "0.8.4"
+version = "0.8.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -26,6 +27,7 @@ AdvancedMHStructArraysExt = "StructArrays"
AbstractMCMC = "5.6"
DiffResults = "1"
Distributions = "0.25"
DocStringExtensions = "0.9"
FillArrays = "1"
ForwardDiff = "0.10"
LinearAlgebra = "1.6"
10 changes: 10 additions & 0 deletions docs/Project.toml
@@ -1,5 +1,15 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Documenter = "1"
Distributions = "0.25"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
MCMCChains = "6.0.4"
Random = "1.6"
6 changes: 6 additions & 0 deletions docs/src/api.md
@@ -12,3 +12,9 @@ MetropolisHastings
```@docs
DensityModel
```

## Samplers

```@docs
RobustAdaptiveMetropolis
```
7 changes: 5 additions & 2 deletions src/AdvancedMH.jl
@@ -3,8 +3,9 @@ module AdvancedMH
# Import the relevant libraries.
using AbstractMCMC
using Distributions
-using LinearAlgebra: I
+using LinearAlgebra: LinearAlgebra, I
using FillArrays: Zeros
using DocStringExtensions: FIELDS

using LogDensityProblems: LogDensityProblems

@@ -22,7 +23,8 @@ export
SymmetricRandomWalkProposal,
Ensemble,
StretchProposal,
-    MALA
+    MALA,
RobustAdaptiveMetropolis

# Reexports
export sample, MCMCThreads, MCMCDistributed, MCMCSerial
@@ -159,5 +161,6 @@ include("proposal.jl")
include("mh-core.jl")
include("emcee.jl")
include("MALA.jl")
include("RobustAdaptiveMetropolis.jl")

end # module AdvancedMH
279 changes: 279 additions & 0 deletions src/RobustAdaptiveMetropolis.jl
@@ -0,0 +1,279 @@
# TODO: Should we generalise this to arbitrary symmetric proposals?
"""
    RobustAdaptiveMetropolis

Robust Adaptive Metropolis-Hastings (RAM).

This is a simple implementation of the RAM algorithm described in [^VIH12].

# Fields

$(FIELDS)

# Examples

The following demonstrates how to implement a simple Gaussian model and sample from it using the RAM algorithm.

```jldoctest ram-gaussian; setup=:(using Random; Random.seed!(1234);)
julia> using AdvancedMH, Distributions, MCMCChains, LogDensityProblems, LinearAlgebra

julia> # Define a Gaussian with zero mean and some covariance.
       struct Gaussian{A}
           Σ::A
       end

julia> # Implement the LogDensityProblems interface.
       LogDensityProblems.dimension(model::Gaussian) = size(model.Σ, 1)

julia> function LogDensityProblems.logdensity(model::Gaussian, x)
           d = LogDensityProblems.dimension(model)
           return logpdf(MvNormal(zeros(d), model.Σ), x)
       end

julia> LogDensityProblems.capabilities(::Gaussian) = LogDensityProblems.LogDensityOrder{0}()

julia> # Construct the model. We'll use a correlation of 0.5.
       model = Gaussian([1.0 0.5; 0.5 1.0]);

julia> # Number of samples we want in the resulting chain.
       num_samples = 10_000;

julia> # Number of warmup steps, i.e. the number of steps to adapt the covariance of the proposal.
       # Note that these are not included in the resulting chain, as `discard_initial=num_warmup`
       # by default in the `sample` call. To include them, pass `discard_initial=0` to `sample`.
       num_warmup = 10_000;

julia> # Sample!
       chain = sample(
           model,
           RobustAdaptiveMetropolis(),
           num_samples;
           chain_type=Chains, num_warmup, progress=false, initial_params=zeros(2)
       );

julia> isapprox(cov(Array(chain)), model.Σ; rtol = 0.2)
true
```

It's also possible to restrict the eigenvalues to avoid either too small or too large values. See p. 13 in [^VIH12].

```jldoctest ram-gaussian
julia> chain = sample(
           model,
           RobustAdaptiveMetropolis(eigenvalue_lower_bound=0.1, eigenvalue_upper_bound=2.0),
           num_samples;
           chain_type=Chains, num_warmup, progress=false, initial_params=zeros(2)
       );

julia> norm(cov(Array(chain)) - [1.0 0.5; 0.5 1.0]) < 0.2
true
````
Member

That looks like one backtick too many. PS. How do I make a GitHub suggestion for this?

Member

@penelopeysm Dec 10, 2024

````suggestion
```
````

if you have N backticks surround them with N+1 😄

Member

On second thought, I just pushed a fix for this myself.

Member

Thanks @penelopeysm. :) I should have guessed that the universe would have been organised in such a way that the fix to quadruple backticks would be more quadruple backticks.


# References
[^VIH12]: Vihola (2012) Robust adaptive Metropolis algorithm with coerced acceptance rate, Statistics and computing.
"""
Base.@kwdef struct RobustAdaptiveMetropolis{T,A<:Union{Nothing,AbstractMatrix{T}}} <:
                   AdvancedMH.MHSampler
    "target acceptance rate. Default: 0.234."
    α::T = 0.234
    "negative exponent of the adaptation decay rate. Default: `0.6`."
    γ::T = 0.6
    "initial lower-triangular Cholesky factor. If specified, should be convertible into a `LowerTriangular`. Default: `nothing`."
    S::A = nothing
    "lower bound on eigenvalues of the adapted Cholesky factor. Default: `0.0`."
    eigenvalue_lower_bound::T = 0.0
    "upper bound on eigenvalues of the adapted Cholesky factor. Default: `Inf`."
    eigenvalue_upper_bound::T = Inf
end

"""
    RobustAdaptiveMetropolisState

State of the Robust Adaptive Metropolis-Hastings (RAM) algorithm.

See also: [`RobustAdaptiveMetropolis`](@ref).

# Fields
$(FIELDS)
"""
struct RobustAdaptiveMetropolisState{T1,L,A,T2,T3}
    "current realization of the chain."
    x::T1
    "log density of `x` under the target model."
    logprob::L
    "current lower-triangular Cholesky factor."
    S::A
    "log acceptance ratio of the previous iteration (not necessarily of `x`)."
    logα::T2
    "current step size for adaptation of `S`."
    η::T3
    "current iteration."
    iteration::Int
    "whether the previous iteration was accepted."
    isaccept::Bool
end

AbstractMCMC.getparams(state::RobustAdaptiveMetropolisState) = state.x
AbstractMCMC.setparams!!(state::RobustAdaptiveMetropolisState, x) =
    RobustAdaptiveMetropolisState(
        x,
        state.logprob,
        state.S,
        state.logα,
        state.η,
        state.iteration,
        state.isaccept,
    )

function ram_step_inner(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::RobustAdaptiveMetropolis,
    state::RobustAdaptiveMetropolisState,
)
    # Extract the log-density function and its dimension.
    f = model.logdensity
    d = LogDensityProblems.dimension(f)

    # Sample the proposal.
    x = state.x
    U = randn(rng, eltype(x), d)
    x_new = muladd(state.S, U, x)

    # Compute the acceptance probability.
    lp = state.logprob
    lp_new = LogDensityProblems.logdensity(f, x_new)
    logα = min(lp_new - lp, zero(lp)) # `min` because we'll use this for updating
Member

Maybe only bound it in the update of S? It seems at least easier to read if the bounding is kept together with the part of the algorithm where it's actually needed.

Member Author

But IMO this makes it a bit strange if we then put the unbounded logα in the resulting state, since this is not the quantity used to update the S 😕

And for the purposes it is used here, it doesn't actually matter if it's bounded or not, right? As in, it's equivalent here, but not equivalent in the S update, hence it seems somewhat natural for me to just do it once and for all.

Member

> since this is not the quantity used to update the S

It is, isn't it? The update is just slightly different. Otherwise with the same reasoning you could also argue that one should only store α (or maybe the difference to the targeted α?) since only these are used to update S.

In the end I guess it doesn't matter as it's only used in these two places. It just felt strange conceptually to bound it here, in particular since it seemed you already felt the need to explain this decision with a comment.

Member Author

But alpha represents the acceptance probability, no? So clamping it like this is technically always what you should do, but most of the time we don't because it's unnecessary for sampling according to this probability.

However, if the user wants to actually look at the resulting acceptance probs, then it's a question of: do we want the user to do

mean(exp, getproperty.(states, :logα))

or

mean(exp, min.(0, getproperty.(states, :logα)))

In my head, the user expects to do the former, not the latter 🤷

Member

I can see your point. Even though I think neither of the two alternatives is particularly user-friendly, IMO a separate API for acceptance probabilities would be better.

In any case, I think I wouldn't even have commented on this line if the comment ``# `min` because we'll use this for updating`` had not been there. So maybe just remove the comment?

Member Author

Fair, but then I'm worried someone might come along later and go "wait, that's not needed; let's just remove this unnecessary min", not realizing that we'll use this for adaptation 😬

Member

Maybe add an `@assert logα <= zero(logα)` at the top of `ram_adapt`?

    isaccept = Random.randexp(rng) > -logα

    return x_new, lp_new, U, logα, isaccept
end
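
As a standalone illustration of the acceptance test in `ram_step_inner` (a sketch, not part of the PR): comparing `randexp(rng) > -logα` is equivalent to the more familiar `rand(rng) < exp(logα)`, because `-log(U)` is exponentially distributed when `U` is uniform on (0, 1). The snippet uses only `Random` and `Statistics`; the helper name `accept_fraction` and the numeric values are hypothetical. It also shows why the clamp at zero matters once log acceptance ratios are averaged, which is the point debated in the review thread.

```julia
using Random, Statistics

# Fraction of accepted draws using the exponential-variate test from `ram_step_inner`.
function accept_fraction(rng, logα, n)
    return count(_ -> randexp(rng) > -logα, 1:n) / n
end

rng = MersenneTwister(42)
logα = log(0.3)
frac = accept_fraction(rng, logα, 100_000)  # empirical acceptance rate, near 0.3

# Unclamped log ratios can be positive; clamping at zero keeps exp(logα) <= 1.
raw = [-1.2, 0.7, -0.1, 2.0]
clamped = min.(raw, 0.0)
mean_clamped = mean(exp, clamped)  # a valid average acceptance probability
mean_raw = mean(exp, raw)          # exceeds 1 here, so it is not a probability
```

With the clamp applied once in `ram_step_inner`, a plain `mean(exp, ...)` over stored `logα` values stays within [0, 1].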

function ram_adapt(
    sampler::RobustAdaptiveMetropolis,
    state::RobustAdaptiveMetropolisState,
    logα::Real,
    U::AbstractVector,
)
    Δα = exp(logα) - sampler.α
    S = state.S
    # TODO: Make this configurable by defining a more general path.
    η = state.iteration^(-sampler.γ)
    ΔS = (η * abs(Δα)) * S * U / LinearAlgebra.norm(U)
    # TODO: Maybe do in-place and then have the user extract it with a callback if they really want it.
    S_new = if sign(Δα) == 1
        # One rank update.
        LinearAlgebra.lowrankupdate(LinearAlgebra.Cholesky(S), ΔS).L
    else
        # One rank downdate.
        LinearAlgebra.lowrankdowndate(LinearAlgebra.Cholesky(S), ΔS).L
    end
    return S_new, η
end
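
The rank-one update that `ram_adapt` relies on can be checked in isolation with `LinearAlgebra` alone. The sketch below is not the PR's code: the variable names and numeric values are made up for illustration. Since `lowrankupdate(C, v)` factorises `C + v * v'`, the vector enters quadratically, so this sketch scales it by `sqrt(η * abs(Δα))` to reproduce the update `S (I + η Δα u u' / ‖u‖²) S'` from Vihola (2012). That square root is an assumption of the sketch and differs from the `η * abs(Δα)` scaling in the diff above.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(1)
d = 3
A = randn(rng, d, d)
C = cholesky(Symmetric(A * A' + I))  # factorisation of a positive-definite Σ
S = C.L                              # lower-triangular factor, Σ = S * S'
u = randn(rng, d)
η, Δα = 0.5, 0.2                     # illustrative step size and acceptance gap

# Vector for the rank-one update; note the square root (assumption of this sketch).
v = sqrt(η * abs(Δα)) * (S * u) / norm(u)

# Target matrices from the update formula, for Δα > 0 (update) and Δα < 0 (downdate).
Σ_up = S * (I + η * abs(Δα) * (u * u') / norm(u)^2) * S'
Σ_down = S * (I - η * abs(Δα) * (u * u') / norm(u)^2) * S'

L_up = lowrankupdate(C, v).L
L_down = lowrankdowndate(C, v).L
```

Because `η * abs(Δα) < 1` whenever `η <= 1`, the downdated matrix stays positive definite and `lowrankdowndate` should not throw.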

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::RobustAdaptiveMetropolis;
    initial_params = nothing,
Member

Has JuliaFormatter been run? Just wondering because I thought the style guide was to not have whitespace around kwarg defaults. Not important though, can be left as is.

    kwargs...,
)
    # This is the initial state.
    f = model.logdensity
    d = LogDensityProblems.dimension(f)

    # Initial parameter state.
    T = if initial_params === nothing
        eltype(sampler.γ)
    else
        Base.promote_type(eltype(sampler.γ), eltype(initial_params))
    end
    x = if initial_params === nothing
        rand(rng, T, d)
    else
        convert(AbstractVector{T}, initial_params)
    end
    # Initialize the Cholesky factor of the covariance matrix.
    S_data = if sampler.S === nothing
        LinearAlgebra.diagm(0 => ones(T, d))
    else
        # Check the dimensionality of the provided `S`.
        if size(sampler.S) != (d, d)
            throw(ArgumentError("The provided `S` has the wrong dimensionality."))
        end
        convert(AbstractMatrix{T}, sampler.S)
    end
    S = LinearAlgebra.LowerTriangular(S_data)

    # Construct the initial state.
    lp = LogDensityProblems.logdensity(f, x)
    state = RobustAdaptiveMetropolisState(x, lp, S, zero(T), 0, 1, true)

    return AdvancedMH.Transition(x, lp, true), state
end
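
The element-type handling in the initial step can be tried on its own. This is a small sketch with made-up inputs (`γ` and `initial_params` are illustrative values, not taken from the PR), showing how `promote_type` picks the working type and how the default identity factor is built.

```julia
using LinearAlgebra

γ = 0.6                                # stands in for `sampler.γ`
initial_params = Float32[0.1, -0.2]    # user-supplied starting point
d = length(initial_params)

# Promote so that the state works in the wider of the two element types.
T = promote_type(eltype(γ), eltype(initial_params))
x = convert(AbstractVector{T}, initial_params)

# Default initial factor: the identity, wrapped as lower-triangular.
S = LowerTriangular(diagm(0 => ones(T, d)))
```

Here `Float32` parameters are widened to `Float64` because `γ` is a `Float64`; passing a `Float32` sampler configuration would presumably keep everything in `Float32`.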

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::RobustAdaptiveMetropolis,
    state::RobustAdaptiveMetropolisState;
    kwargs...,
)
    # Take the inner step.
    x_new, lp_new, U, logα, isaccept = ram_step_inner(rng, model, sampler, state)
    # Accept / reject the proposal.
    state_new = RobustAdaptiveMetropolisState(
        isaccept ? x_new : state.x,
        isaccept ? lp_new : state.logprob,
        state.S,
        logα,
        state.η,
        state.iteration + 1,
        isaccept,
    )
    return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept),
    state_new
end

function valid_eigenvalues(S, lower_bound, upper_bound)
    # Short-circuit if the bounds are the default.
    (lower_bound == 0 && upper_bound == Inf) && return true
    # Note that this is just the diagonal when `S` is triangular.
    eigenvals = LinearAlgebra.eigvals(S)
    return all(x -> lower_bound <= x <= upper_bound, eigenvals)
end
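
A quick standalone check of the claim in the comment above, that the eigenvalues of a triangular matrix are its diagonal entries. The function name `eigenvalues_in_bounds` is hypothetical and only mirrors the bounds test for illustration.

```julia
using LinearAlgebra

# Hypothetical standalone version of the bounds check, for illustration only.
function eigenvalues_in_bounds(S, lower, upper)
    return all(λ -> lower <= λ <= upper, eigvals(S))
end

S = LowerTriangular([0.5 0.0; 0.3 1.5])
λs = eigvals(S)   # the diagonal entries 0.5 and 1.5
```

Note that the bounds apply to the eigenvalues of the factor `S`, which in general differ from those of the covariance `S * S'`.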

function AbstractMCMC.step_warmup(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::RobustAdaptiveMetropolis,
    state::RobustAdaptiveMetropolisState;
    kwargs...,
)
    # Take the inner step.
    x_new, lp_new, U, logα, isaccept = ram_step_inner(rng, model, sampler, state)
    # Adapt the proposal.
    S_new, η = ram_adapt(sampler, state, logα, U)
    # Check that `S_new` has eigenvalues in the desired range.
    if !valid_eigenvalues(
        S_new,
        sampler.eigenvalue_lower_bound,
        sampler.eigenvalue_upper_bound,
    )
        # In this case, we just keep the old `S` (p. 13 in Vihola, 2012).
        S_new = state.S
    end

    # Update state.
    state_new = RobustAdaptiveMetropolisState(
        isaccept ? x_new : state.x,
        isaccept ? lp_new : state.logprob,
        S_new,
        logα,
        η,
        state.iteration + 1,
        isaccept,
    )
    return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept),
    state_new
end
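
To see how the propose, accept/reject, and adapt steps fit together, here is a minimal dependency-free sketch of a RAM warmup loop on a two-dimensional Gaussian. It is not the PR's implementation: `logtarget` and `ram_sketch` are hypothetical names, the target is specified through a precision matrix to avoid `Distributions`, eigenvalue bounds are omitted, and the update vector is scaled by `sqrt(η * abs(Δα))` following Vihola (2012), which is an assumption of this sketch.

```julia
using LinearAlgebra, Random

# Unnormalised log density of a zero-mean Gaussian with precision matrix `P`.
logtarget(x, P) = -dot(x, P, x) / 2

function ram_sketch(rng, P; n_iters=20_000, α_target=0.234, γ=0.6)
    d = size(P, 1)
    x = zeros(d)
    lp = logtarget(x, P)
    C = cholesky(Matrix{Float64}(I, d, d))  # factorisation of the proposal covariance
    n_accept = 0
    for n in 1:n_iters
        # Propose x + S * u with u standard normal.
        u = randn(rng, d)
        x_new = x + C.L * u
        lp_new = logtarget(x_new, P)
        logα = min(lp_new - lp, 0.0)
        # Accept or reject.
        if randexp(rng) > -logα
            x, lp = x_new, lp_new
            n_accept += 1
        end
        # Adapt the factor towards the target acceptance rate.
        Δα = exp(logα) - α_target
        η = n^(-γ)
        v = sqrt(η * abs(Δα)) * (C.L * u) / norm(u)
        C = Δα > 0 ? lowrankupdate(C, v) : lowrankdowndate(C, v)
    end
    return C.L, n_accept / n_iters
end

# Precision matrix of the covariance [1.0 0.5; 0.5 1.0] used in the docstring example.
P = [4/3 -2/3; -2/3 4/3]
S_final, accept_rate = ram_sketch(MersenneTwister(1), P)
```

Over a long enough run the overall acceptance rate should drift towards the target of 0.234, and `S_final` remains a lower-triangular factor with a positive diagonal.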