From dcf7bc29c2089d6eface905fb11d526edab63148 Mon Sep 17 00:00:00 2001 From: Harrison Wilde Date: Thu, 23 Feb 2023 10:38:31 +0000 Subject: [PATCH] Implementing and correcting the base set of swap strategies, plus wrapping swap moves as part of steps (#143) * Implementing and correcting base set of swap strategies wrapping up swap move with step to guarantee number of samples requested == number of samples returned cleaned up and generalised swapping.jl to allow non-neighbouring chain swaps * reverting subscripts on i and j * Update src/stepping.jl Co-authored-by: Tor Erlend Fjelde * Reverting sampler.adapt passing and swap step merge * Reverting swap_every * disallow random swaps and adaptation * fix adaptation error in swapping.jl * Fixing runtests.jl * Fixing runtests, rejigging params a bit to make the turing ones more reasonable * changing should_swap so that the first move is counted in the total and swaps happen regularly * Small changes to runtests to remove a deprecation warning * Update runtests.jl * Update CI.yml * Update runtests.jl --------- Co-authored-by: Tor Erlend Fjelde Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- .github/workflows/CI.yml | 2 +- src/MCMCTempering.jl | 9 ++- src/ladders.jl | 9 ++- src/sampler.jl | 24 ++++--- src/sampling.jl | 17 +++-- src/stepping.jl | 91 +++++++++++++++++--------- src/swapping.jl | 134 ++++++++++++++++++++++++--------------- test/runtests.jl | 44 +++++++------ 8 files changed, 206 insertions(+), 124 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 621e0b5..45d546c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.6' + - '1.7' - '1' os: - ubuntu-latest diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 62c91fa..a8a084d 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -27,9 +27,12 @@ export tempered, tempered_sample, TemperedSampler, make_tempered_model, - StandardSwap, - RandomPermutationSwap, - NonReversibleSwap + ReversibleSwap, + NonReversibleSwap, + SingleSwap, + SingleRandomSwap, + RandomSwap, + NoSwap # TODO: Should we make this trait-based instead? implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing diff --git a/src/ladders.jl b/src/ladders.jl index 2cccab5..0ebf615 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -1,11 +1,14 @@ """ - get_scaling_val(N_it, swap_strategy) + get_scaling_val(N_it, <:AbstractSwapStrategy) Calculates a scaling factor for polynomial step size between inverse temperatures. """ -get_scaling_val(N_it, ::StandardSwap) = N_it - 1 +get_scaling_val(N_it, ::ReversibleSwap) = 2 get_scaling_val(N_it, ::NonReversibleSwap) = 2 -get_scaling_val(N_it, ::RandomPermutationSwap) = 1 +get_scaling_val(N_it, ::SingleSwap) = N_it - 1 +get_scaling_val(N_it, ::SingleRandomSwap) = N_it - 1 +get_scaling_val(N_it, ::RandomSwap) = 1 +get_scaling_val(N_it, ::NoSwap) = N_it - 1 """ generate_inverse_temperatures(N_it, swap_strategy) diff --git a/src/sampler.jl b/src/sampler.jl index e719e76..08f157c 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -59,7 +59,7 @@ end """ tempered(sampler, inverse_temperatures; kwargs...) OR - tempered(sampler, N_it; swap_strategy=StandardSwap(), kwargs...) + tempered(sampler, N_it; swap_strategy=ReversibleSwap(), kwargs...) Return a tempered version of `sampler` using the provided `inverse_temperatures` or inverse temperatures generated from `N_it` and the `swap_strategy`. @@ -72,21 +72,24 @@ inverse temperatures generated from `N_it` and the `swap_strategy`. - `N_it`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures` # Keyword arguments -- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made +- `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains - `swap_every::Integer` steps are carried out between each attempt at a swap # See also - [`TemperedSampler`](@ref) - For more on the swap strategies: - - [`AbstractSwapStrategy`](@ref) - - [`StandardSwap`](@ref) - - [`NonReversibleSwap`](@ref) - - [`RandomPermutationSwap`](@ref) + - [`AbstractSwapStrategy`](@ref) + - [`ReversibleSwap`](@ref) + - [`NonReversibleSwap`](@ref) + - [`SingleSwap`](@ref) + - [`SingleRandomSwap`](@ref) + - [`RandomSwap`](@ref) + - [`NoSwap`](@ref) """ function tempered( sampler::AbstractMCMC.AbstractSampler, N_it::Integer; - swap_strategy::AbstractSwapStrategy=StandardSwap(), + swap_strategy::AbstractSwapStrategy=ReversibleSwap(), kwargs... ) return tempered(sampler, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy = swap_strategy, kwargs...) @@ -94,8 +97,8 @@ end function tempered( sampler::AbstractMCMC.AbstractSampler, inverse_temperatures::Vector{<:Real}; - swap_strategy::AbstractSwapStrategy=StandardSwap(), - swap_every::Integer=2, + swap_strategy::AbstractSwapStrategy=ReversibleSwap(), + swap_every::Integer=10, adapt::Bool=false, adapt_target::Real=0.234, adapt_stepsize::Real=1, @@ -104,7 +107,8 @@ function tempered( adapt_scale=defaultscale(adapt_schedule, inverse_temperatures), kwargs... ) - swap_every >= 2 || error("This must be a positive integer greater than 1.") + !(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.") + swap_every > 1 || error("`swap_every` must take a positive integer value greater than 1.") inverse_temperatures = check_inverse_temperatures(inverse_temperatures) adaptation_states = init_adaptation( adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize diff --git a/src/sampling.jl b/src/sampling.jl index 733158b..9f51ec0 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,24 +1,27 @@ """ tempered_sample([rng, ], model, sampler, N, inverse_temperatures; kwargs...) OR - tempered_sample([rng, ], model, sampler, N, N_it; swap_strategy=StandardSwap(), kwargs...) + tempered_sample([rng, ], model, sampler, N, N_it; swap_strategy=SingleSwap(), kwargs...) Generate `N` samples from `model` using a tempered version of the provided `sampler`. Provide either `inverse_temperatures` or `N_it` (and a `swap_strategy`) to generate some # Keyword arguments - `N_burnin::Integer` burn-in steps will be carried out before any swapping between chains is attempted -- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made +- `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains - `swap_every::Integer` steps are carried out between each attempt at a swap # See also - [`tempered`](@ref) - [`TemperedSampler`](@ref) - For more on the swap strategies: - - [`AbstractSwapStrategy`](@ref) - - [`StandardSwap`](@ref) - - [`NonReversibleSwap`](@ref) - - [`RandomPermutationSwap`](@ref) + - [`AbstractSwapStrategy`](@ref) + - [`ReversibleSwap`](@ref) + - [`NonReversibleSwap`](@ref) + - [`SingleSwap`](@ref) + - [`SingleRandomSwap`](@ref) + - [`RandomSwap`](@ref) + - [`NoSwap`](@ref) """ function tempered_sample( model, @@ -36,7 +39,7 @@ function tempered_sample( sampler::AbstractMCMC.AbstractSampler, N::Integer, N_it::Integer; - swap_strategy::AbstractSwapStrategy = StandardSwap(), + swap_strategy::AbstractSwapStrategy = SingleSwap(), kwargs... ) return tempered_sample(model, sampler, N, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy=swap_strategy, kwargs...) diff --git a/src/stepping.jl b/src/stepping.jl index fb465e5..83d35fa 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -4,7 +4,7 @@ Return `true` if a swap should happen at this iteration, and `false` otherwise. """ function should_swap(sampler::TemperedSampler, state::TemperedState) - return state.total_steps % sampler.swap_every == 0 + return state.total_steps % sampler.swap_every == 1 end get_init_params(x, _)= x @@ -86,7 +86,7 @@ function AbstractMCMC.step( state::TemperedState; kwargs... ) - # Reset. + # Reset state @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) if should_swap(sampler, state) @@ -134,7 +134,7 @@ end """ swap_step([strategy::AbstractSwapStrategy, ]rng, model, sampler, state) -Return new `state`, now with temperatures swapped according to `strategy`. +Return a new `state`, with temperatures possibly swapped according to `strategy`. If no `strategy` is provided, the return-value of [`swapstrategy`](@ref) called on `sampler` is used. @@ -149,56 +149,89 @@ function swap_step( end function swap_step( - strategy::StandardSwap, + strategy::ReversibleSwap, rng::Random.AbstractRNG, model, sampler::TemperedSampler, state::TemperedState ) - L = numtemps(sampler) - 1 - k = rand(rng, 1:L) - return swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps / L) + # Randomly select whether to attempt swaps between chains + # corresponding to odd or even indices of the temperature ladder + odd = rand([true, false]) + for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))] + state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + end + return state end function swap_step( - strategy::RandomPermutationSwap, + strategy::NonReversibleSwap, rng::Random.AbstractRNG, model, sampler::TemperedSampler, state::TemperedState ) - L = numtemps(sampler) - 1 - levels = Vector{Int}(undef, L) - Random.randperm!(rng, levels) - - # Iterate through all levels and attempt swaps. - for k in levels - state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) + # Alternate between attempting to swap chains corresponding + # to odd and even indices of the temperature ladder + odd = state.total_steps % (2 * sampler.swap_every) != 0 + for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))] + state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) end return state end function swap_step( - strategy::NonReversibleSwap, + strategy::SingleSwap, rng::Random.AbstractRNG, model, sampler::TemperedSampler, state::TemperedState ) - L = numtemps(sampler) - 1 - # Alternate between swapping odds and evens. - levels = if state.total_steps % (2 * sampler.swap_every) == 0 - 1:2:L - else - 2:2:L - end + # Randomly pick one index `k` of the temperature ladder and + # attempt a swap between the corresponding chain and its neighbour + k = rand(rng, 1:(numtemps(sampler) - 1)) + return swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) +end + +function swap_step( + strategy::SingleRandomSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState +) + # Randomly pick two temperature ladder indices in order to + # attempt a swap between the corresponding chains + chains = Set(1:numtemps(sampler)) + i = pop!(chains, rand(rng, chains)) + j = pop!(chains, rand(rng, chains)) + return swap_attempt(rng, model, sampler, state, i, j, sampler.adapt) +end - # Iterate through all levels and attempt swaps. - for k in levels - # TODO: For this swapping strategy, we should really be using the adaptation from Syed et. al. (2019), - # but with the current one: shouldn't we at least divide `state.total_steps` by 2 since it will - # take use two swap-attempts before we have tried swapping all of them (in expectation). - state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) +function swap_step( + strategy::RandomSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState +) + # Iterate through all of temperature ladder indices, picking random + # pairs and attempting swaps between the corresponding chains + chains = Set(1:numtemps(sampler)) + while length(chains) >= 2 + i = pop!(chains, rand(rng, chains)) + j = pop!(chains, rand(rng, chains)) + state = swap_attempt(rng, model, sampler, state, i, j, sampler.adapt) end return state end + +function swap_step( + strategy::NoSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState +) + return state +end \ No newline at end of file diff --git a/src/swapping.jl b/src/swapping.jl index 5412dfa..578aa9c 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -8,59 +8,94 @@ A concrete subtype is expected to implement the method [`swap_step`](@ref). abstract type AbstractSwapStrategy end """ - StandardSwap <: AbstractSwapStrategy + ReversibleSwap <: AbstractSwapStrategy -At every swap step taken, this strategy samples a single chain index `i` and proposes -a swap between chains `i` and `i + 1`. +Stochastically attempt either even- or odd-indexed swap moves between chains. -This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05] +See [^SYED19] for more on this approach, referred to as SEO in their paper. # References -[^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). +[^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). """ -struct StandardSwap <: AbstractSwapStrategy end +struct ReversibleSwap <: AbstractSwapStrategy end """ - RandomPermutationSwap <: AbstractSwapStrategy + NonReversibleSwap <: AbstractSwapStrategy + +At every swap step taken, this strategy _deterministically_ traverses +first the odd chain indices, proposing swaps between neighbors, and +then in the _next_ swap step taken traverses even chain indices, proposing +swaps between neighbors. -At every swap step taken, this strategy randomly shuffles all the chain indices -and then iterates through them, proposing swaps for neighboring chains. +See [^SYED19] for more on this approach, referred to as DEO in their paper. + +# References +[^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). """ -struct RandomPermutationSwap <: AbstractSwapStrategy end +struct NonReversibleSwap <: AbstractSwapStrategy end + +""" + SingleSwap <: AbstractSwapStrategy +At every swap step taken, this strategy samples a single chain index +`i` and proposes a swap between chains `i` and `i + 1`. + +This approach goes under a number of names, e.g. Parallel Tempering +(PT) MCMC and Replica-Exchange MCMC.[^PTPH05] + +# References +[^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). +""" +struct SingleSwap <: AbstractSwapStrategy end """ - NonReversibleSwap <: AbstractSwapStrategy + SingleRandomSwap <: AbstractSwapStrategy -At every swap step taken, this strategy _deterministically_ traverses first the -odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step -taken traverses even chain indices, proposing swaps between neighbors. +At every swap step taken, this strategy samples two chain indices +`i` and 'j' and proposes a swap between the two corresponding chains. -See [^SYED19] for more on this approach. +This approach is shown to be effective for certain models in [^1]. # References -[^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). +[^1]: Malcolm Sambridge, A Parallel Tempering algorithm for probabilistic sampling and multimodal optimization, Geophysical Journal International, Volume 196, Issue 1, January 2014, Pages 357–374, https://doi.org/10.1093/gji/ggt342 """ -struct NonReversibleSwap <: AbstractSwapStrategy end +struct SingleRandomSwap <: AbstractSwapStrategy end + +""" + RandomSwap <: AbstractSwapStrategy + +This strategy randomly shuffles all the chain indices to produce +`floor(numptemps(sampler)/2)` pairs of random (not necessarily +neighbouring) chain indices to attempt to swap +""" +struct RandomSwap <: AbstractSwapStrategy end + +""" + NoSwap <: AbstractSwapStrategy + +Mainly useful for debugging or observing each chain independently, +this overrides and disables all swapping functionality. +""" +struct NoSwap <: AbstractSwapStrategy end """ - swap_betas!(chain_to_process, process_to_chain, k) + swap_betas!(chain_to_process, process_to_chain, i, j) -Swaps the `k`th and `k + 1`th temperatures in place. +Swaps the `i`th and `j`th temperatures in place. """ -function swap_betas!(chain_to_process, process_to_chain, k) +function swap_betas!(chain_to_process, process_to_chain, i, j) # TODO: Use BangBang's `@set!!` to also support tuples? # Extract the process index for each of the chains. - process_for_chain_k, process_for_chain_kp1 = chain_to_process[k], chain_to_process[k + 1] + process_for_chain_i, process_for_chain_j = chain_to_process[i], chain_to_process[j] # Switch the mapping of the `chain → process` map. - # The temperature for the k-th chain will now be moved from its current process - # to the process for the (k + 1)-th chain, and vice versa. - chain_to_process[k], chain_to_process[k + 1] = process_for_chain_kp1, process_for_chain_k + # The temperature for the i-th chain will now be moved from its current process + # to the process for the (j)-th chain, and vice versa. + chain_to_process[i], chain_to_process[j] = process_for_chain_j, process_for_chain_i # Swap the mapping of the `process → chain` map. - # The process that used to have the k-th chain, now has the (k+1)-th chain, and vice versa. - process_to_chain[process_for_chain_k], process_to_chain[process_for_chain_kp1] = k + 1, k + # The process that used to have the i-th chain, now has the (i+1)-th chain, and vice versa. + process_to_chain[process_for_chain_i], process_to_chain[process_for_chain_j] = j, i return chain_to_process, process_to_chain end @@ -89,61 +124,60 @@ function compute_tempered_logdensities( end """ - swap_acceptance_pt(logπk, logπkp1) + swap_acceptance_pt(logπi, logπj) Calculates and returns the swap acceptance ratio for swapping the temperature -of two chains. Using tempered likelihoods `logπk` and `logπkp1` at the chains' +of two chains. Using tempered likelihoods `logπi` and `logπj` at the chains' current state parameters. """ -function swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) - return (logπkp1_θk + logπk_θkp1) - (logπk_θk + logπkp1_θkp1) +function swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) + return (logπjθi + logπiθj) - (logπiθi + logπjθj) end """ - swap_attempt(rng, model, sampler, state, k, adapt) + swap_attempt(rng, model, sampler, state, i, j) Attempt to swap the temperatures of two chains by tempering the densities and calculating the swap acceptance ratio; then swapping if it is accepted. """ -function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) - # TODO: Allow arbitrary `k` rather than just `k + 1`. +function swap_attempt(rng, model, sampler, state, i, j, adapt) # Extract the relevant transitions. - samplerk = sampler_for_chain(sampler, state, k) - samplerkp1 = sampler_for_chain(sampler, state, k + 1) - transitionk = transition_for_chain(state, k) - transitionkp1 = transition_for_chain(state, k + 1) - statek = state_for_chain(state, k) - statekp1 = state_for_chain(state, k + 1) - βk = beta_for_chain(state, k) - βkp1 = beta_for_chain(state, k + 1) + sampler_i = sampler_for_chain(sampler, state, i) + sampler_j = sampler_for_chain(sampler, state, j) + transition_i = transition_for_chain(state, i) + transition_j = transition_for_chain(state, j) + state_i = state_for_chain(state, i) + state_j = state_for_chain(state, j) + β_i = beta_for_chain(state, i) + β_j = beta_for_chain(state, j) # Evaluate logdensity for both parameters for each tempered density. - logπk_θk, logπk_θkp1 = compute_tempered_logdensities( - model, samplerk, samplerkp1, transitionk, transitionkp1, statek, statekp1, βk, βkp1 + logπiθi, logπiθj = compute_tempered_logdensities( + model, sampler_i, sampler_j, transition_i, transition_j, state_i, state_j, β_i, β_j ) - logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities( - model, samplerkp1, samplerk, transitionkp1, transitionk, statekp1, statek, βkp1, βk + logπjθj, logπjθi = compute_tempered_logdensities( + model, sampler_j, sampler_i, transition_j, transition_i, state_j, state_i, β_j, β_i ) # If the proposed temperature swap is accepted according `logα`, # swap the temperatures for future steps. - logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) + logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) should_swap = -Random.randexp(rng) ≤ logα if should_swap - swap_betas!(state.chain_to_process, state.process_to_chain, k) + swap_betas!(state.chain_to_process, state.process_to_chain, i, j) end # Keep track of the (log) acceptance ratios. - state.swap_acceptance_ratios[k] = logα + state.swap_acceptance_ratios[i] = logα # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is # adapted before a new `inverse_temperatures` is generated and returned. if adapt ρs = adapt!!( - state.adaptation_states, state.inverse_temperatures, k, min(one(logα), exp(logα)) + state.adaptation_states, state.chain_to_beta, i, min(one(logα), exp(logα)) ) @set! state.adaptation_states = ρs - @set! state.inverse_temperatures = update_inverse_temperatures(ρs, state.inverse_temperatures) + @set! state.chain_to_beta = update_inverse_temperatures(ρs, state.chain_to_beta) end return state end diff --git a/test/runtests.jl b/test/runtests.jl index 38cebd8..e13411b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,7 +50,7 @@ function test_and_sample_model( model, sampler, inverse_temperatures, - swap_strategy=MCMCTempering.StandardSwap(); + swap_strategy=MCMCTempering.SingleSwap(); mean_swap_rate_bound=0.1, compare_mean_swap_rate=≥, num_iterations=2_000, @@ -64,7 +64,7 @@ function test_and_sample_model( kwargs... ) # TODO: Remove this when no longer necessary. - num_iterations_tempered = Int(ceil(num_iterations * swap_every ÷ (swap_every - 1))) + num_iterations_tempered = Int(ceil(num_iterations * swap_every / (swap_every - 1))) # Make the tempered sampler. sampler_tempered = tempered( @@ -149,7 +149,7 @@ function test_and_sample_model( # Compare the tempered sampler to the untempered sampler. state_tempered = states_tempered[end] chain_tempered = AbstractMCMC.bundle_samples( - samples_tempered, + samples_tempered[findall((!).(getproperty.(states_tempered, :is_swap)))], MCMCTempering.maybe_wrap_model(model), sampler_tempered.sampler, MCMCTempering.state_for_chain(state_tempered), @@ -190,12 +190,10 @@ function compare_chains( if compare_ess ess = MCMCChains.ess_rhat(chain).nt.ess ess_tempered = MCMCChains.ess_rhat(chain_tempered).nt.ess - # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would - # actually do better than the internal sampler. if isbroken - @test_broken all(ess .≥ ess_tempered .* 0.5) + @test_broken all(ess_tempered .≥ ess) else - @test all(ess .≥ ess_tempered .* 0.5) + @test all(ess_tempered .≥ ess) end end end @@ -215,7 +213,7 @@ end chain_to_beta = [1.0, 0.75, 0.5, 0.25] # Make swap chain 1 (now on process 1) ↔ chain 2 (now on process 2) - MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1) + MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1, 2) # Expected result: chain 1 is now on process 2, chain 2 is now on process 1. target_process_to_chain = [2, 1, 3, 4] @test process_to_chain[chain_to_process] == 1:length(process_to_chain) @@ -231,7 +229,7 @@ end end # Make swap chain 2 (now on process 1) ↔ chain 3 (now on process 3) - MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2) + MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2, 3) # Expected result: chain 3 is now on process 1, chain 2 is now on process 3. target_process_to_chain = [3, 1, 2, 4] @test process_to_chain[chain_to_process] == 1:length(process_to_chain) @@ -271,6 +269,7 @@ end # `atol` is fairly high because we haven't run this for "too" long. @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.2 end + @testset "GMM 1D" begin num_iterations = 10_000 model = DistributionLogDensity( @@ -325,9 +324,11 @@ end # Different swap strategies to test. swapstrategies = [ - MCMCTempering.StandardSwap(), - MCMCTempering.RandomPermutationSwap(), - MCMCTempering.NonReversibleSwap() + MCMCTempering.ReversibleSwap(), + MCMCTempering.NonReversibleSwap(), + MCMCTempering.SingleSwap(), + MCMCTempering.SingleRandomSwap(), + MCMCTempering.RandomSwap() ] @testset "$(swapstrategy)" for swapstrategy in swapstrategies @@ -354,7 +355,7 @@ end # Get the parameter names. param_names = map(Symbol, DynamicPPL.TestUtils.varnames(model_dppl)) # Get bijector so we can get back to unconstrained space afterwards. - b = inv(Turing.bijector(model_dppl)) + b = inverse(Turing.bijector(model_dppl)) # Construct the `LogDensityFunction` which supports LogDensityProblems.jl-interface. model = ADgradient(:ForwardDiff, DynamicPPL.LogDensityFunction(model_dppl, vi)) @@ -367,14 +368,15 @@ end end @testset "AdvancedHMC.jl" begin - num_iterations = 1_000 + num_iterations = 2_000 # Set up HMC smpler. initial_ϵ = 0.1 integrator = AdvancedHMC.Leapfrog(initial_ϵ) proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator) metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model)) - sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric) + adaptor = AdvancedHMC.StanHMCAdaptor(AdvancedHMC.MassMatrixAdaptor(metric), AdvancedHMC.StepSizeAdaptor(0.8, integrator)) + sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, adaptor) # Sample using HMC. samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false) @@ -388,8 +390,8 @@ end chain_tempered = test_and_sample_model( model, sampler_hmc, - [1, 0.9, 0.75, 0.5, 0.25, 0.1], - swap_strategy=MCMCTempering.NonReversibleSwap(), + [1, 0.25, 0.1, 0.01], + swap_strategy=MCMCTempering.ReversibleSwap(), num_iterations=num_iterations, swap_every=10, adapt=false, @@ -428,7 +430,7 @@ end model, sampler_mh, [1, 0.9, 0.75, 0.5, 0.25, 0.1], - swap_strategy=MCMCTempering.StandardSwap(), + swap_strategy=MCMCTempering.ReversibleSwap(), num_iterations=num_iterations, swap_every=2, adapt=false, @@ -437,9 +439,9 @@ end param_names=param_names ) map_parameters!(b, chain_tempered) - - # TODO: Make it not broken, i.e. produce reasonable results. - compare_chains(chain_mh, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false) + + # Need a large atol as MH is not great on its own + compare_chains(chain_mh, chain_tempered, atol=0.4, compare_std=false, compare_ess=true, isbroken=false) end end end