Skip to content

Commit

Permalink
Implementing and correcting the base set of swap strategies, plus wra…
Browse files Browse the repository at this point in the history
…pping swap moves as part of steps (#143)

* Implementing and correcting base set of swap strategies
wrapping up swap move with step to guarantee number of samples requested == number of samples returned
cleaned up and generalised swapping.jl to allow non-neighbouring chain swaps

* reverting subscripts on i and j

* Update src/stepping.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Reverting sampler.adapt passing and swap step merge

* Reverting swap_every

* disallow random swaps and adaptation

* fix adaptation error in swapping.jl

* Fixing runtests.jl

* Fixing runtests, rejigging params a bit to make the turing ones more reasonable

* changing should_swap so that the first move is counted in the total and swaps happen regularly

* Small changes to runtests to remove a deprecation warning

* Update runtests.jl

* Update CI.yml

* Update runtests.jl

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
3 people authored Feb 23, 2023
1 parent a5dbee5 commit dcf7bc2
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.6'
- '1.7'
- '1'
os:
- ubuntu-latest
Expand Down
9 changes: 6 additions & 3 deletions src/MCMCTempering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ export tempered,
tempered_sample,
TemperedSampler,
make_tempered_model,
StandardSwap,
RandomPermutationSwap,
NonReversibleSwap
ReversibleSwap,
NonReversibleSwap,
SingleSwap,
SingleRandomSwap,
RandomSwap,
NoSwap

# TODO: Should we make this trait-based instead?
implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing
Expand Down
9 changes: 6 additions & 3 deletions src/ladders.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
get_scaling_val(N_it, swap_strategy)
get_scaling_val(N_it, <:AbstractSwapStrategy)
Calculates a scaling factor for polynomial step size between inverse temperatures.
"""
get_scaling_val(N_it, ::StandardSwap) = N_it - 1
get_scaling_val(N_it, ::ReversibleSwap) = 2
get_scaling_val(N_it, ::NonReversibleSwap) = 2
get_scaling_val(N_it, ::RandomPermutationSwap) = 1
get_scaling_val(N_it, ::SingleSwap) = N_it - 1
get_scaling_val(N_it, ::SingleRandomSwap) = N_it - 1
get_scaling_val(N_it, ::RandomSwap) = 1
get_scaling_val(N_it, ::NoSwap) = N_it - 1

"""
generate_inverse_temperatures(N_it, swap_strategy)
Expand Down
24 changes: 14 additions & 10 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
"""
tempered(sampler, inverse_temperatures; kwargs...)
OR
tempered(sampler, N_it; swap_strategy=StandardSwap(), kwargs...)
tempered(sampler, N_it; swap_strategy=ReversibleSwap(), kwargs...)
Return a tempered version of `sampler` using the provided `inverse_temperatures` or
inverse temperatures generated from `N_it` and the `swap_strategy`.
Expand All @@ -72,30 +72,33 @@ inverse temperatures generated from `N_it` and the `swap_strategy`.
- `N_it`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures`
# Keyword arguments
- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made
- `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains
- `swap_every::Integer` steps are carried out between each attempt at a swap
# See also
- [`TemperedSampler`](@ref)
- For more on the swap strategies:
- [`AbstractSwapStrategy`](@ref)
- [`StandardSwap`](@ref)
- [`NonReversibleSwap`](@ref)
- [`RandomPermutationSwap`](@ref)
- [`AbstractSwapStrategy`](@ref)
- [`ReversibleSwap`](@ref)
- [`NonReversibleSwap`](@ref)
- [`SingleSwap`](@ref)
- [`SingleRandomSwap`](@ref)
- [`RandomSwap`](@ref)
- [`NoSwap`](@ref)
"""
function tempered(
sampler::AbstractMCMC.AbstractSampler,
N_it::Integer;
swap_strategy::AbstractSwapStrategy=StandardSwap(),
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
kwargs...
)
return tempered(sampler, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy = swap_strategy, kwargs...)
end
function tempered(
sampler::AbstractMCMC.AbstractSampler,
inverse_temperatures::Vector{<:Real};
swap_strategy::AbstractSwapStrategy=StandardSwap(),
swap_every::Integer=2,
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
swap_every::Integer=10,
adapt::Bool=false,
adapt_target::Real=0.234,
adapt_stepsize::Real=1,
Expand All @@ -104,7 +107,8 @@ function tempered(
adapt_scale=defaultscale(adapt_schedule, inverse_temperatures),
kwargs...
)
swap_every >= 2 || error("This must be a positive integer greater than 1.")
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
swap_every > 1 || error("`swap_every` must take a positive integer value greater than 1.")
inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
adaptation_states = init_adaptation(
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
Expand Down
17 changes: 10 additions & 7 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
"""
tempered_sample([rng, ], model, sampler, N, inverse_temperatures; kwargs...)
OR
tempered_sample([rng, ], model, sampler, N, N_it; swap_strategy=StandardSwap(), kwargs...)
tempered_sample([rng, ], model, sampler, N, N_it; swap_strategy=SingleSwap(), kwargs...)
Generate `N` samples from `model` using a tempered version of the provided `sampler`.
Provide either `inverse_temperatures` or `N_it` (and a `swap_strategy`) to generate some
# Keyword arguments
- `N_burnin::Integer` burn-in steps will be carried out before any swapping between chains is attempted
- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made
- `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains
- `swap_every::Integer` steps are carried out between each attempt at a swap
# See also
- [`tempered`](@ref)
- [`TemperedSampler`](@ref)
- For more on the swap strategies:
- [`AbstractSwapStrategy`](@ref)
- [`StandardSwap`](@ref)
- [`NonReversibleSwap`](@ref)
- [`RandomPermutationSwap`](@ref)
- [`AbstractSwapStrategy`](@ref)
- [`ReversibleSwap`](@ref)
- [`NonReversibleSwap`](@ref)
- [`SingleSwap`](@ref)
- [`SingleRandomSwap`](@ref)
- [`RandomSwap`](@ref)
- [`NoSwap`](@ref)
"""
function tempered_sample(
model,
Expand All @@ -36,7 +39,7 @@ function tempered_sample(
sampler::AbstractMCMC.AbstractSampler,
N::Integer,
N_it::Integer;
swap_strategy::AbstractSwapStrategy = StandardSwap(),
swap_strategy::AbstractSwapStrategy = SingleSwap(),
kwargs...
)
return tempered_sample(model, sampler, N, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy=swap_strategy, kwargs...)
Expand Down
91 changes: 62 additions & 29 deletions src/stepping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Return `true` if a swap should happen at this iteration, and `false` otherwise.
"""
function should_swap(sampler::TemperedSampler, state::TemperedState)
return state.total_steps % sampler.swap_every == 0
return state.total_steps % sampler.swap_every == 1
end

get_init_params(x, _)= x
Expand Down Expand Up @@ -86,7 +86,7 @@ function AbstractMCMC.step(
state::TemperedState;
kwargs...
)
# Reset.
# Reset state
@set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios)

if should_swap(sampler, state)
Expand Down Expand Up @@ -134,7 +134,7 @@ end
"""
swap_step([strategy::AbstractSwapStrategy, ]rng, model, sampler, state)
Return new `state`, now with temperatures swapped according to `strategy`.
Return a new `state`, with temperatures possibly swapped according to `strategy`.
If no `strategy` is provided, the return-value of [`swapstrategy`](@ref) called on `sampler`
is used.
Expand All @@ -149,56 +149,89 @@ function swap_step(
end

function swap_step(
strategy::StandardSwap,
strategy::ReversibleSwap,
rng::Random.AbstractRNG,
model,
sampler::TemperedSampler,
state::TemperedState
)
L = numtemps(sampler) - 1
k = rand(rng, 1:L)
return swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps / L)
# Randomly select whether to attempt swaps between chains
# corresponding to odd or even indices of the temperature ladder
odd = rand([true, false])
for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))]
state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt)
end
return state
end

function swap_step(
strategy::RandomPermutationSwap,
strategy::NonReversibleSwap,
rng::Random.AbstractRNG,
model,
sampler::TemperedSampler,
state::TemperedState
)
L = numtemps(sampler) - 1
levels = Vector{Int}(undef, L)
Random.randperm!(rng, levels)

# Iterate through all levels and attempt swaps.
for k in levels
state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps)
# Alternate between attempting to swap chains corresponding
# to odd and even indices of the temperature ladder
odd = state.total_steps % (2 * sampler.swap_every) != 0
for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))]
state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt)
end
return state
end

function swap_step(
strategy::NonReversibleSwap,
strategy::SingleSwap,
rng::Random.AbstractRNG,
model,
sampler::TemperedSampler,
state::TemperedState
)
L = numtemps(sampler) - 1
# Alternate between swapping odds and evens.
levels = if state.total_steps % (2 * sampler.swap_every) == 0
1:2:L
else
2:2:L
end
# Randomly pick one index `k` of the temperature ladder and
# attempt a swap between the corresponding chain and its neighbour
k = rand(rng, 1:(numtemps(sampler) - 1))
return swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt)
end

function swap_step(
strategy::SingleRandomSwap,
rng::Random.AbstractRNG,
model,
sampler::TemperedSampler,
state::TemperedState
)
# Randomly pick two temperature ladder indices in order to
# attempt a swap between the corresponding chains
chains = Set(1:numtemps(sampler))
i = pop!(chains, rand(rng, chains))
j = pop!(chains, rand(rng, chains))
return swap_attempt(rng, model, sampler, state, i, j, sampler.adapt)
end

# Iterate through all levels and attempt swaps.
for k in levels
# TODO: For this swapping strategy, we should really be using the adaptation from Syed et. al. (2019),
# but with the current one: shouldn't we at least divide `state.total_steps` by 2 since it will
# take use two swap-attempts before we have tried swapping all of them (in expectation).
state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps)
function swap_step(
strategy::RandomSwap,
rng::Random.AbstractRNG,
model,
sampler::TemperedSampler,
state::TemperedState
)
# Iterate through all of temperature ladder indices, picking random
# pairs and attempting swaps between the corresponding chains
chains = Set(1:numtemps(sampler))
while length(chains) >= 2
i = pop!(chains, rand(rng, chains))
j = pop!(chains, rand(rng, chains))
state = swap_attempt(rng, model, sampler, state, i, j, sampler.adapt)
end
return state
end

function swap_step(
strategy::NoSwap,
rng::Random.AbstractRNG,
model,
sampler::TemperedSampler,
state::TemperedState
)
return state
end
Loading

0 comments on commit dcf7bc2

Please sign in to comment.