From c4c16417234f023c90e3fcf3f45862e3acd7ed72 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 11 Mar 2023 14:27:05 +0000 Subject: [PATCH] Introduction of compositions and products of samplers and models (#151) * added CompositionSampler, RepeatedSampler, MultiSampler together with additional methods for meta-type samplers * added LinearAlgebra as dep * big update but now everything finally works * added additional pass-on-methods for meta-samplers and moved the bundle_samples to a more appropriate place * renamed state_from_state to state_from and changed the ordering of the args to be more reasonable * added some missing methods and fixed a typo * added model_for_chain and model_for_process similar to other utility methods for interacting with the tempered state, etc. * added todo * moved bundling back to ordering of defintions * added missing test dep * increase number of steps for one of the tests * specialize step for combination of RepeatedSampler and MultiSampler * Update src/sampler.jl Co-authored-by: Harrison Wilde * Introduction of `SwapSampler` + make `TemperedSampler` a fancy version of `CompositionSampler` (#152) * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! 
for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! * moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! 
that should never be hit with the current codebase * removed expected_order * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed unnecessary variable in tests * Update src/sampler.jl Co-authored-by: Harrison Wilde * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! 
* moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase * removed expected_order * removed unnecessary variable in tests * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * Apply suggestions from code review Co-authored-by: Harrison Wilde * added CompositionTransition + quite a few bundle_samples with a `bundle_resolve_swaps` kwarg to allow converting into chains more easily * more samples * reduce requirement for ess comparison for AHMC a bit * significant improvements to the simple Gaussian example, now testing using MCSE to get tolerances, etc. and small improvements to the rest of the tests * trying to debug these tests * more debug * fixed typy * reduce significance even further --------- Co-authored-by: Harrison Wilde --------- Co-authored-by: Harrison Wilde --- Project.toml | 2 + src/MCMCTempering.jl | 194 ++++++++++++++++++++++++++++- src/abstractmcmc.jl | 106 ++++++++++++++++ src/adaptation.jl | 2 +- src/ladders.jl | 4 +- src/sampler.jl | 108 ++++++++++++---- src/samplers/composition.jl | 130 ++++++++++++++++++++ src/samplers/multi.jl | 210 ++++++++++++++++++++++++++++++++ src/samplers/repeated.jl | 81 ++++++++++++ src/state.jl | 154 ++++++++++++++--------- src/stepping.jl | 225 ++++++++++++++-------------------- src/swapping.jl | 68 +++-------- src/swapsampler.jl | 224 ++++++++++++++++++++++++++++++++++ src/utils.jl | 34 ++++++ test/Project.toml | 8 +- test/abstractmcmc.jl | 186 ++++++++++++++++++++++++++++ test/compat.jl | 28 ++++- test/runtests.jl | 237 ++++++++++++++++++++---------------- test/setup.jl | 22 ++++ test/simple_gaussian.jl | 74 +++++++++++ test/test_utils.jl | 126 +++++++++++++++++++ 21 files changed, 1835 insertions(+), 388 deletions(-) create mode 
100644 src/abstractmcmc.jl create mode 100644 src/samplers/composition.jl create mode 100644 src/samplers/multi.jl create mode 100644 src/samplers/repeated.jl create mode 100644 src/swapsampler.jl create mode 100644 src/utils.jl create mode 100644 test/abstractmcmc.jl create mode 100644 test/setup.jl create mode 100644 test/simple_gaussian.jl create mode 100644 test/test_utils.jl diff --git a/Project.toml b/Project.toml index b369ba1..c6be239 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,9 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index a8a084d..cb658d5 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -9,19 +9,24 @@ using ProgressLogging: ProgressLogging using ConcreteStructs: @concrete using Setfield: @set, @set! +using MCMCChains: MCMCChains + using InverseFunctions using DocStringExtensions include("logdensityproblems.jl") +include("abstractmcmc.jl") include("adaptation.jl") include("swapping.jl") include("state.jl") +include("swapsampler.jl") include("sampler.jl") include("sampling.jl") include("ladders.jl") include("stepping.jl") include("model.jl") +include("utils.jl") export tempered, tempered_sample, @@ -39,16 +44,199 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model +# Bundling. 
+# Bundling of non-tempered samples. +function bundle_nontempered_samples( + ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, + model::AbstractMCMC.AbstractModel, + sampler::TemperedSampler, + state::TemperedState, + ::Type{T}; + kwargs... +) where {T} + # Create the same model and sampler as we do in the initial step for `TemperedSampler`. + multimodel = MultiModel([ + make_tempered_model(sampler, model, sampler.chain_to_beta[i]) + for i in 1:numtemps(sampler) + ]) + multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)]) + multitransitions = [ + MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions)) + for t in ts + ] + + return AbstractMCMC.bundle_samples( + multitransitions, + multimodel, + multisampler, + MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)), + T + ) +end + function AbstractMCMC.bundle_samples( - ts::AbstractVector, + ts::Vector{<:MultipleTransitions}, + model::MultiModel, + sampler::MultiSampler, + state::MultipleStates, + # TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.? + ::Type{Vector{MCMCChains.Chains}}; + kwargs... +) + return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state + AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...) + end +end + +# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118 +function AbstractMCMC.bundle_samples( + ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, model::AbstractMCMC.AbstractModel, sampler::TemperedSampler, state::TemperedState, - chain_type::Type; + ::Type{Vector{T}}; + bundle_resolve_swaps::Bool=false, + kwargs... +) where {T} + if bundle_resolve_swaps + return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...) + end + + # TODO: Do better? 
+ return ts +end + +function AbstractMCMC.bundle_samples( + ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, + model::AbstractMCMC.AbstractModel, + sampler::TemperedSampler, + state::TemperedState, + ::Type{MCMCChains.Chains}; kwargs... ) + # The transitions are ordered according to processes, so extract the ones belonging to the first chain. + ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts] + return AbstractMCMC.bundle_samples( + ts_actual, + model, + sampler_for_chain(sampler, state, 1), + state_for_chain(state, 1), + MCMCChains.Chains; + kwargs... + ) +end + +function AbstractMCMC.bundle_samples( + ts::AbstractVector, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler, + state::CompositionState, + ::Type{T}; + kwargs... +) where {T} + # In the case of `!saveall(sampler)`, the transitions are not `CompositionTransition`s so we just propagate + # the transitions to the `bundle_samples` for the outer sampler. Otherwise, we flatten the transitions. + ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts + # TODO: Should we really always default to outer sampler? + return AbstractMCMC.bundle_samples( + ts_actual, model, sampler.sampler_outer, state.state_outer, T; + kwargs... + ) +end + +# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118 +function AbstractMCMC.bundle_samples( + ts::Vector, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler, + state::CompositionState, + ::Type{Vector{T}}; + kwargs... +) where {T} + if !saveall(sampler) + # In this case, we just use the `outer` for everything since these are the only + # transitions we're keeping around. + return AbstractMCMC.bundle_samples( + ts, model, sampler.sampler_outer, state.state_outer, Vector{T}; + kwargs... + ) + end + + # Otherwise, we don't know what to do. 
+ return ts +end + +function AbstractMCMC.bundle_samples( + ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}}, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler{<:MultiSampler,<:SwapSampler}, + state::CompositionState{<:MultipleStates,<:SwapState}, + ::Type{T}; + bundle_resolve_swaps::Bool=false, + kwargs... +) where {T} + !bundle_resolve_swaps && return ts + + # Resolve the swaps (using the already implemented resolution in `composition_transition` + # for this particular sampler but without `saveall`). + sampler_without_saveall = @set sampler.saveall = Val(false) + ts_actual = map(ts) do t + composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t)) + end + AbstractMCMC.bundle_samples( - ts, maybe_wrap_model(model), sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type; + ts_actual, model, sampler.sampler_outer, state.state_outer, T; + kwargs... + ) +end + +# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118 +function AbstractMCMC.bundle_samples( + ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}}, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler{<:MultiSampler,<:SwapSampler}, + state::CompositionState{<:MultipleStates,<:SwapState}, + ::Type{Vector{T}}; + bundle_resolve_swaps::Bool=false, + kwargs... +) where {T} + !bundle_resolve_swaps && return ts + + # Resolve the swaps (using the already implemented resolution in `composition_transition` + # for this particular sampler but without `saveall`). + sampler_without_saveall = @set sampler.saveall = Val(false) + ts_actual = map(ts) do t + composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t)) + end + + return AbstractMCMC.bundle_samples( + ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T}; + kwargs... + ) +end + +function AbstractMCMC.bundle_samples( + ts::AbstractVector, + model::AbstractMCMC.AbstractModel, + sampler::RepeatedSampler, + state, + ::Type{MCMCChains.Chains}; + kwargs... 
+) + return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...) +end + +# Flatten in the case of `SequentialTransitions`. +function AbstractMCMC.bundle_samples( + ts::AbstractVector{<:SequentialTransitions}, + model::AbstractMCMC.AbstractModel, + sampler::RepeatedSampler, + state::SequentialStates, + ::Type{MCMCChains.Chains}; + kwargs... +) + ts_actual = [t for tseq in ts for t in tseq.transitions] + return AbstractMCMC.bundle_samples( + ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains; kwargs... ) end diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl new file mode 100644 index 0000000..716bb2c --- /dev/null +++ b/src/abstractmcmc.jl @@ -0,0 +1,106 @@ +using Setfield +using AbstractMCMC: AbstractMCMC + +import LinearAlgebra: × + +""" + getparams([model, ]state) + +Get the parameters from the `state`. + +Default implementation uses [`getparams_and_logprob`](@ref). +""" +getparams(state) = first(getparams_and_logprob(state)) +getparams(model, state) = first(getparams_and_logprob(model, state)) + +""" + getlogprob([model, ]state) + +Get the log probability of the `state`. + +Default implementation uses [`getparams_and_logprob`](@ref). +""" +getlogprob(state) = last(getparams_and_logprob(state)) +getlogprob(model, state) = last(getparams_and_logprob(model, state)) + +""" + getparams_and_logprob([model, ]state) + +Return the parameters and the log probability of the `state`. + +See also: [`setparams_and_logprob!!`](@ref). +""" +getparams_and_logprob(model, state) = getparams_and_logprob(state) + +""" + setparams_and_logprob!!([model, ]state, params, logprob) + +Set the parameters and log probability in the state to `params` and `logprob`, possibly mutating if it makes sense. + +See also: [`getparams_and_logprob`](@ref). 
+""" +setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob) + +""" + state_from(model, state_target, state_source[, transition_target, transition_source]) + +Return a new state similar to `state_target` but updated from `state_source`, which could be +a different type of state. +""" +function state_from(model, state_target, state_source, transition_target, transition_source) + return state_from(model, state_target, state_source) +end +function state_from(model, state_target, state_source) + params, logp = getparams_and_logprob(model, state_source) + return setparams_and_logprob!!(model, state_target, params, logp) +end + +""" + SequentialTransitions + +A `SequentialTransitions` object is a container for a sequence of transitions. +""" +struct SequentialTransitions{A} + transitions::A +end + +# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the +# last transition/state. +getparams_and_logprob(transitions::SequentialTransitions) = getparams_and_logprob(transitions.transitions[end]) +function getparams_and_logprob(model, transitions::SequentialTransitions) + return getparams_and_logprob(model, transitions.transitions[end]) +end + +function setparams_and_logprob!!(transitions::SequentialTransitions, params, logprob) + return @set transitions.transitions[end] = setparams_and_logprob!!(transitions.transitions[end], params, logprob) +end +function setparams_and_logprob!!(model, transitions::SequentialTransitions, params, logprob) + return @set transitions.transitions[end] = setparams_and_logprob!!(model, transitions.transitions[end], params, logprob) +end + +""" + SequentialStates + +A `SequentialStates` object is a container for a sequence of states. +""" +struct SequentialStates{A} + states::A +end + +# Since it's a _sequence_ of states, the parameters and logprobs are the ones of the +# last transition/state. 
+getparams_and_logprob(state::SequentialStates) = getparams_and_logprob(state.states[end]) +getparams_and_logprob(model, state::SequentialStates) = getparams_and_logprob(model, state.states[end]) + +function setparams_and_logprob!!(state::SequentialStates, params, logprob) + return @set state.states[end] = setparams_and_logprob!!(state.states[end], params, logprob) +end +function setparams_and_logprob!!(model, state::SequentialStates, params, logprob) + return @set state.states[end] = setparams_and_logprob!!(model, state.states[end], params, logprob) +end + +# Includes. +include("samplers/composition.jl") +include("samplers/repeated.jl") +include("samplers/multi.jl") + diff --git a/src/adaptation.jl b/src/adaptation.jl index 096ca13..25d640b 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -19,7 +19,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and """ struct Geometric end -defaultscale(::Geometric, Δ) = eltype(Δ)(0.9) +defaultscale(::Geometric, Δ) = float(eltype(Δ))(0.9) """ InverselyAdditive diff --git a/src/ladders.jl b/src/ladders.jl index 0ebf615..28305d1 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -37,9 +37,7 @@ end Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0` """ function check_inverse_temperatures(Δ) - if length(Δ) <= 1 - error("More than one inverse temperatures must be provided.") - end + !isempty(Δ) || error("Inverse temperatures array is empty.") if !all(zero.(Δ) .≤ Δ .≤ one.(Δ)) error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].") end diff --git a/src/sampler.jl b/src/sampler.jl index 08f157c..7ecd097 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,3 +1,20 @@ +""" + TemperedState + +A state for a tempered sampler. 
+ +# Fields +$(FIELDS) +""" +@concrete struct TemperedState + "state for swap-sampler" + swapstate + "state for the main sampler" + state + "inverse temperature for each of the chains" + chain_to_beta +end + """ TemperedSampler <: AbstractMCMC.AbstractSampler @@ -7,44 +24,39 @@ A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Temp $(FIELDS) """ -@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler +Base.@kwdef struct TemperedSampler{SplT,A,SwapT,Adapt} <: AbstractMCMC.AbstractSampler "sampler(s) used to target the tempered distributions" - sampler + sampler::SplT "collection of inverse temperatures β; β[i] correponds i-th tempered model" - inverse_temperatures - "number of steps of `sampler` to take before proposing swaps" - swap_every - "the swap strategy that will be used when proposing swaps" - swap_strategy - # TODO: This should be replaced with `P` just being some `NoAdapt` type. + chain_to_beta::A + "strategy to use for swapping" + swapstrategy::SwapT=ReversibleSwap() + # TODO: Remove `adapt` and just consider `adaptation_states=nothing` as no adaptation. "boolean flag specifying whether or not to adapt" - adapt + adapt=false "adaptation parameters" - adaptation_states + adaptation_states::Adapt=nothing end -swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy +TemperedSampler(sampler, chain_to_beta; kwargs...) = TemperedSampler(; sampler, chain_to_beta, kwargs...) + +swapsampler(sampler::TemperedSampler) = SwapSampler(sampler.swapstrategy) +# TODO: Do we need this now? getsampler(samplers, I...) = getindex(samplers, I...) getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...) -""" - numsteps(sampler::TemperedSampler) - -Return number of inverse temperatures used by `sampler`. -""" -numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures) +chain_to_process(state::TemperedState, I...) 
= chain_to_process(state.swapstate, I...) +process_to_chain(state::TemperedState, I...) = process_to_chain(state.swapstate, I...) """ - sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) + sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) Return the sampler corresponding to the chain indexed by `I...`. -If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned. """ -sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, chain_to_process(state, I...)) + return sampler_for_process(sampler, state, chain_to_process(state, I...)) end """ @@ -53,9 +65,51 @@ end Return the sampler corresponding to the process indexed by `I...`. """ function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, I...) + return _sampler_for_process_temper(sampler.sampler, state, I...) end +# If `sampler` is a `MultiSampler`, we assume it's ordered according to chains. +_sampler_for_process_temper(sampler::MultiSampler, state, I...) = sampler.samplers[process_to_chain(state, I...)] +# Otherwise, we just use the same sampler for everything. +_sampler_for_process_temper(sampler, state, I...) = sampler + +# Defer extracting the corresponding state to the `swapstate`. +state_for_process(state::TemperedState, I...) = state_for_process(state.swapstate, I...) + +# Here we make the model(s) using the temperatures. +function model_for_process(sampler::TemperedSampler, model, state::TemperedState, I...) + return make_tempered_model(sampler, model, beta_for_process(state, I...)) +end + +""" + beta_for_chain(state[, I...]) + +Return the β corresponding to the chain indexed by `I...`. +If `I...` is not specified, the β corresponding to `β=1.0` will be returned. 
+""" +beta_for_chain(state::TemperedState) = beta_for_chain(state, 1) +beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...) +# NOTE: Array impl. is useful for testing. +beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...] + +""" + beta_for_process(state, I...) + +Return the β corresponding to the process indexed by `I...`. +""" +beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.swapstate.process_to_chain, I...) +# NOTE: Array impl. is useful for testing. +function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...) + return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...)) +end + +""" + numtemps(sampler::TemperedSampler) + +Return the number of inverse temperatures used by `sampler`. +""" +numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta) + """ tempered(sampler, inverse_temperatures; kwargs...) OR @@ -98,7 +152,8 @@ function tempered( sampler::AbstractMCMC.AbstractSampler, inverse_temperatures::Vector{<:Real}; swap_strategy::AbstractSwapStrategy=ReversibleSwap(), - swap_every::Integer=10, + # TODO: Change `swap_every` to something like `number_of_iterations_per_swap`. + steps_per_swap::Integer=1, adapt::Bool=false, adapt_target::Real=0.234, adapt_stepsize::Real=1, @@ -108,10 +163,13 @@ function tempered( kwargs... 
) !(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.") - swap_every > 1 || error("`swap_every` must take a positive integer value greater than 1.") + steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.") inverse_temperatures = check_inverse_temperatures(inverse_temperatures) adaptation_states = init_adaptation( adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize ) - return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states) + # NOTE: We just make a repeated sampler for `sampler_inner`. + # TODO: Generalize. Allow passing in a `MultiSampler`, etc. + sampler_inner = sampler^steps_per_swap + return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states) end diff --git a/src/samplers/composition.jl b/src/samplers/composition.jl new file mode 100644 index 0000000..b8fb153 --- /dev/null +++ b/src/samplers/composition.jl @@ -0,0 +1,130 @@ +""" + CompositionSampler <: AbstractMCMC.AbstractSampler + +A `CompositionSampler` is a container for a sequence of samplers. 
+ +# Fields +$(FIELDS) + +# Examples +```julia +composed_sampler = sampler_outer ∘ sampler_inner # or `CompositionSampler(sampler_outer, sampler_inner, Val(true))` +AbstractMCMC.step(rng, model, composed_sampler) # one step of `sampler_inner`, and one step of `sampler_outer` +``` +""" +struct CompositionSampler{S1,S2,SaveAll} <: AbstractMCMC.AbstractSampler + "The outer sampler" + sampler_outer::S1 + "The inner sampler" + sampler_inner::S2 + "Whether to save all the transitions or just the last one" + saveall::SaveAll +end + +CompositionSampler(sampler_outer, sampler_inner) = CompositionSampler(sampler_outer, sampler_inner, Val(true)) + +Base.:∘(s_outer::AbstractMCMC.AbstractSampler, s_inner::AbstractMCMC.AbstractSampler) = CompositionSampler(s_outer, s_inner) + +""" + saveall(sampler) + +Return whether the sampler saves all the transitions or just the last one. +""" +saveall(sampler::CompositionSampler) = sampler.saveall +saveall(::CompositionSampler{<:Any,<:Any,Val{SaveAll}}) where {SaveAll} = SaveAll + +""" + CompositionState + +A `CompositionState` is a container for a sequence of states. 
+ +# Fields +$(FIELDS) +""" +struct CompositionState{S1,S2} + "The outer state" + state_outer::S1 + "The inner state" + state_inner::S2 +end + +getparams_and_logprob(state::CompositionState) = getparams_and_logprob(state.state_outer) +getparams_and_logprob(model, state::CompositionState) = getparams_and_logprob(model, state.state_outer) + +function setparams_and_logprob!!(state::CompositionState, params, logprob) + return @set state.state_outer = setparams_and_logprob!!(state.state_outer, params, logprob) +end +function setparams_and_logprob!!(model, state::CompositionState, params, logprob) + return @set state.state_outer = setparams_and_logprob!!(model, state.state_outer, params, logprob) +end + +struct CompositionTransition{S1,S2} + "The outer transition" + transition_outer::S1 + "The inner transition" + transition_inner::S2 +end + +# Useful functions for interacting with composition sampler and states. +inner_sampler(sampler::CompositionSampler) = sampler.sampler_inner +outer_sampler(sampler::CompositionSampler) = sampler.sampler_outer + +inner_state(state::CompositionState) = state.state_inner +outer_state(state::CompositionState) = state.state_outer + +inner_transition(transition::CompositionTransition) = transition.transition_inner +outer_transition(transition::CompositionTransition) = transition.transition_outer +outer_transition(transition) = transition # in case we don't have `saveall` + +# TODO: We really don't need to use `SequentialStates` here, do we? +composition_state(sampler, state_inner, state_outer) = CompositionState(state_outer, state_inner) +function composition_transition(sampler, transition_inner, transition_outer) + return if saveall(sampler) + CompositionTransition(transition_outer, transition_inner) + else + transition_outer + end +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler; + kwargs... 
+) + state_inner_initial = last(AbstractMCMC.step(rng, model, inner_sampler(sampler); kwargs...)) + state_outer_initial = last(AbstractMCMC.step(rng, model, outer_sampler(sampler); kwargs...)) + + # Create the composition state, and take a full step. + state = composition_state(sampler, state_inner_initial, state_outer_initial) + return AbstractMCMC.step(rng, model, sampler, state; kwargs...) +end + +# TODO: Do we even need two versions? We could technically use `SequentialStates` +# in place of `CompositionState` and just have one version. +# The annoying part here is that we'll have to check `saveall` on every `step` +# rather than just for the initial step. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler, + state; + kwargs... +) + state_inner_prev, state_outer_prev = inner_state(state), outer_state(state) + + # Update the inner state. + current_state_inner = state_from(model, state_inner_prev, state_outer_prev) + + # Take a step in the inner sampler. + transition_inner, state_inner = AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...) + + # Take a step in the outer sampler. + current_state_outer = state_from(model, state_outer_prev, state_inner) + transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...) + + return ( + composition_transition(sampler, transition_inner, transition_outer), + composition_state(sampler, state_inner, state_outer) + ) +end diff --git a/src/samplers/multi.jl b/src/samplers/multi.jl new file mode 100644 index 0000000..f007e00 --- /dev/null +++ b/src/samplers/multi.jl @@ -0,0 +1,210 @@ +# Multiple independent samplers. +combine(x::Tuple, y::Tuple) = (x..., y...) +combine(x::Tuple, y) = (x..., y) +combine(x, y::Tuple) = (x, y...) 
+combine(x::AbstractArray, y::AbstractArray) = vcat(x, y) +combine(x::AbstractArray, y) = vcat(x, y) +combine(x, y::AbstractArray) = vcat(x, y) +combine(x, y) = Iterators.flatten((x, y)) + + +""" + MultiSampler <: AbstractMCMC.AbstractSampler + +A `MultiSampler` is a container for multiple samplers. + +See also: [`MultiModel`](@ref). + +# Fields +$(FIELDS) + +# Examples +```julia +# `sampler1` targets `model1`, `sampler2` targets `model2`, etc. +multi_model = model1 × model2 × model3 # or `MultiModel((model1, model2, model3))` +multi_sampler = sampler1 × sampler2 × sampler3 # or `MultiSampler((sampler1, sampler2, sampler3))` +# Target the joint model. +AbstractMCMC.step(rng, multi_model, multi_sampler) +``` +""" +struct MultiSampler{S} <: AbstractMCMC.AbstractSampler + "The samplers" + samplers::S +end + +×(sampler1::AbstractMCMC.AbstractSampler, sampler2::AbstractMCMC.AbstractSampler) = MultiSampler((sampler1, sampler2)) +×(sampler1::MultiSampler, sampler2::AbstractMCMC.AbstractSampler) = MultiSampler(combine(sampler1.samplers, sampler2)) +×(sampler1::AbstractMCMC.AbstractSampler, sampler2::MultiSampler) = MultiSampler(combine(sampler1, sampler2.samplers)) +×(sampler1::MultiSampler, sampler2::MultiSampler) = MultiSampler(combine(sampler1.samplers, sampler2.samplers)) + +""" + MultiModel <: AbstractMCMC.AbstractModel + +A `MultiModel` is a container for multiple models. + +See also: [`MultiSampler`](@ref). 
+ +# Fields +$(FIELDS) +""" +struct MultiModel{M} <: AbstractMCMC.AbstractModel + "The models" + models::M +end + +×(model1::AbstractMCMC.AbstractModel, model2::AbstractMCMC.AbstractModel) = MultiModel((model1, model2)) +×(model1::MultiModel, model2::AbstractMCMC.AbstractModel) = MultiModel(combine(model1.models, model2)) +×(model1::AbstractMCMC.AbstractModel, model2::MultiModel) = MultiModel(combine(model1, model2.models)) +×(model1::MultiModel, model2::MultiModel) = MultiModel(combine(model1.models, model2.models)) + +Base.length(model::MultiModel) = length(model.models) + +# TODO: Make these subtypes of `AbstractVector`? +""" + MultipleTransitions + +A container for multiple transitions. + +See also: [`MultipleStates`](@ref). + +# Fields +$(FIELDS) +""" +struct MultipleTransitions{A} + "The transitions" + transitions::A +end + +function getparams_and_logprob(transitions::MultipleTransitions) + params_and_logprobs = map(getparams_and_logprob, transitions.transitions) + return map(first, params_and_logprobs), map(last, params_and_logprobs) +end +function getparams_and_logprob(model::MultiModel, transitions::MultipleTransitions) + params_and_logprobs = map(getparams_and_logprob, model.models, transitions.transitions) + return map(first, params_and_logprobs), map(last, params_and_logprobs) +end + +""" + MultipleStates + +A container for multiple states. + +See also: [`MultipleTransitions`](@ref). + +# Fields +$(FIELDS) +""" +struct MultipleStates{A} + "The states" + states::A +end + +# NOTE: This is different from most of the other implementations of `getparams_and_logprob` +# as here we need to work with multiple models, transitions, and states. 
+function getparams_and_logprob(state::MultipleStates) + params_and_logprobs = map(getparams_and_logprob, state.states) + return map(first, params_and_logprobs), map(last, params_and_logprobs) +end +function getparams_and_logprob(model::MultiModel, state::MultipleStates) + params_and_logprobs = map(getparams_and_logprob, model.models, state.states) + return map(first, params_and_logprobs), map(last, params_and_logprobs) +end + +function setparams_and_logprob!!(state::MultipleStates, params, logprobs) + @assert length(params) == length(logprobs) == length(state.states) "The number of parameters and log probabilities must match the number of states." + return @set state.states = map(setparams_and_logprob!!, state.states, params, logprobs) +end +function setparams_and_logprob!!(model::MultiModel, state::MultipleStates, params, logprobs) + @assert length(model.models) == length(params) == length(logprobs) == length(state.states) "The number of models, states, parameters, and log probabilities must match." + return @set state.states = map(setparams_and_logprob!!, model.models, state.states, params, logprobs) +end + +# TODO: Clean this up. +initparams(model::MultiModel, init_params) = map(Base.Fix1(get_init_params, init_params), 1:length(model.models)) +initparams(model::MultiModel{<:Tuple}, init_params) = ntuple(length(model.models)) do i + get_init_params(init_params, i) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::MultiSampler; + init_params=nothing, + kwargs... +) + @assert length(model.models) == length(sampler.samplers) "Number of models and samplers must be equal" + + # TODO: Handle `init_params` properly. Make sure that they respect the container-types used in + # `MultiModel` and `MultiSampler`. 
+ init_params_multi = initparams(model, init_params) + transition_and_states = asyncmap(model.models, sampler.samplers, init_params_multi) do model, sampler, init_params + AbstractMCMC.step(rng, model, sampler; init_params, kwargs...) + end + + return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::MultiSampler, + states::MultipleStates; + kwargs... +) + @assert length(model.models) == length(sampler.samplers) == length(states.states) "Number of models, samplers, and states must be equal." + + transition_and_states = asyncmap(model.models, sampler.samplers, states.states) do model, sampler, state + AbstractMCMC.step(rng, model, sampler, state; kwargs...) + end + + return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) +end + +# NOTE: In the case of a `RepeatedSampler{<:MultiSampler}`, it's better to, effectively, re-order +# the samplers so that we make a `MultiSampler` of `RepeatedSampler`s. +# We don't want to mutate the sampler, so instead we just convert the sequence of multi-states into +# a multi-state of sequential states, and then work with this ordering in subsequent calls to `step`. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + repeated_sampler::RepeatedSampler{<:MultiSampler}, + states::SequentialStates; + kwargs... +) + @debug "Working with RepeatedSampler{<:MultiSampler}; converting a sequence of multi-states into a multi-state of sequential states" + + multisampler = repeated_sampler.sampler + multistates = last(states.states) + @assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal." 
+ transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state + # Just re-wrap each of the samplers in a `RepeatedSampler` and call its implementation. + AbstractMCMC.step( + rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), SequentialStates([state]); + kwargs... + ) + end + + return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) +end + +# And then we define how `RepeatedSampler{<:MultiSampler}` should work with a `MultipleStates`. +# NOTE: If `saveall(sampler)` is `false`, this is also the implementation we'll hit. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + repeated_sampler::RepeatedSampler{<:MultiSampler}, + multistates::MultipleStates; + kwargs... +) + multisampler = repeated_sampler.sampler + @assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal." + transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state + # Just re-wrap each of the samplers in a `RepeatedSampler` and call its implementation. + AbstractMCMC.step( + rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), state; + kwargs... + ) + end + + return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) +end diff --git a/src/samplers/repeated.jl b/src/samplers/repeated.jl new file mode 100644 index 0000000..8ffa232 --- /dev/null +++ b/src/samplers/repeated.jl @@ -0,0 +1,81 @@ +""" + RepeatedSampler <: AbstractMCMC.AbstractSampler + +A `RepeatedSampler` is a container for a sampler and a number of times to repeat it.
+ +# Fields +$(FIELDS) + +# Examples +```julia +repeated_sampler = sampler^10 # or `RepeatedSampler(sampler, 10, Val(true))` +AbstractMCMC.step(rng, model, repeated_sampler) # take 10 steps of `sampler` +``` +""" +struct RepeatedSampler{S,SaveAll} <: AbstractMCMC.AbstractSampler + "The sampler to repeat" + sampler::S + "The number of times to repeat the sampler" + num_repeat::Int + "Whether to save all the transitions or just the last one" + saveall::SaveAll +end + +RepeatedSampler(sampler, num_repeat) = RepeatedSampler(sampler, num_repeat, Val(true)) + +Base.@constprop :aggressive Base.:^(s::AbstractMCMC.AbstractSampler, n::Int) = RepeatedSampler(s, n, Val(true)) + +saveall(sampler::RepeatedSampler) = sampler.saveall +saveall(::RepeatedSampler{<:Any,Val{SaveAll}}) where {SaveAll} = SaveAll + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatedSampler; + kwargs... +) + state = last(AbstractMCMC.step(rng, model, sampler.sampler; kwargs...)) + state_repeated = saveall(sampler) ? SequentialStates([state]) : state + + return AbstractMCMC.step(rng, model, sampler, state_repeated; kwargs...) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatedSampler, + state; + kwargs... +) + # Take the first of the `num_repeat` steps. + transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...) + + # Take the remaining `num_repeat - 1` steps, keeping only the latest transition and state. + for _ in 2:sampler.num_repeat + transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...) + end + + return transition, state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatedSampler, + state::SequentialStates; + kwargs... +) + # Take the first of the `num_repeat` steps, starting from the most recent state. + transition, state_inner = AbstractMCMC.step(rng, model, sampler.sampler, state.states[end]; kwargs...) + + # Take the remaining `num_repeat - 1` steps, saving every transition and state.
+ transitions = [transition] + states = [state_inner] + for _ in 2:sampler.num_repeat + transition, state_inner = AbstractMCMC.step(rng, model, sampler.sampler, state_inner; kwargs...) + push!(transitions, transition) + push!(states, state_inner) + end + + return SequentialTransitions(transitions), SequentialStates(states) +end diff --git a/src/state.jl b/src/state.jl index ac5fc36..96e6c98 100644 --- a/src/state.jl +++ b/src/state.jl @@ -1,5 +1,19 @@ """ - TemperedState + ProcessOrder + +Specifies that the `model` should be treated as process-ordered. +""" +struct ProcessOrder end + +""" + ChainOrder + +Specifies that the `model` should be treated as chain-ordered. +""" +struct ChainOrder end + +""" + SwapState A general implementation of a state for a [`TemperedSampler`](@ref). @@ -17,7 +31,7 @@ Moreover, suppose we also have 4 workers/processes for which we run these chains (can also be serial wlog). We can then perform a swap in two different ways: -1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`. +1. Swap the _states_ between each process, i.e. permute `transitions` and `states`. 2. Swap the _temperatures_ between each process, i.e. permute `chain_to_beta`. (1) is possibly the most intuitive approach since it means that the i-th worker/process @@ -52,69 +66,87 @@ Chains: process_to_chain chain_to_process inverse_temperatures[process_t In this case, the chain `X` can be reconstructed as: ```julia -X[1] = states[1].transitions_and_states[1] -X[2] = states[2].transitions_and_states[2] -X[3] = states[3].transitions_and_states[2] -X[4] = states[4].transitions_and_states[3] -X[5] = states[5].transitions_and_states[3] +X[1] = states[1].states[1] +X[2] = states[2].states[2] +X[3] = states[3].states[2] +X[4] = states[4].states[3] +X[5] = states[5].states[3] ``` +and similarly for the states. + The indices here are exactly those represented by `states[k].chain_to_process[1]`.
""" -@concrete struct TemperedState - "collection of `(transition, state)` pairs for each process" - transitions_and_states - "collection of (inverse) temperatures β corresponding to each chain" - chain_to_beta +@concrete struct SwapState + "collection of states for each process" + states "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" chain_to_process "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process" process_to_chain "total number of steps taken" total_steps - "number of burn-in steps taken" - burnin_steps - "contains all necessary information for adaptation of inverse_temperatures" - adaptation_states - "flag which specifies wether this was a swap-step or not" - is_swap "swap acceptance ratios on log-scale" swap_acceptance_ratios end +# TODO: Can we support more? +function SwapState(state::MultipleStates) + process_to_chain = collect(1:length(state.states)) + chain_to_process = copy(process_to_chain) + return SwapState(state.states, chain_to_process, process_to_chain, 1, Dict{Int,Float64}()) +end + +# Defer these to `MultipleStates`. +# TODO: What is the best way to implement these? Should we sort according to the chain indices +# to match the order of the models? +# getparams_and_logprob(state::SwapState) = getparams_and_logprob(MultipleStates(state.states)) +# getparams_and_logprob(model, state::SwapState) = getparams_and_logprob(model, MultipleStates(state.states)) + +function setparams_and_logprob!!(model, state::SwapState, params, logprobs) + # Use the `MultipleStates`'s implementation to update the underlying states. + multistate = setparams_and_logprob!!(model, MultipleStates(state.states), params, logprobs) + # Update the states! + return @set state.states = multistate.states +end + """ - process_to_chain(state, I...) 
+ sort_by_chain(::ChainOrder, state, xs) + sort_by_chain(::ProcessOrder, state, xs) -Return the chain index corresponding to the process index `I`. +Return `xs` sorted according to the chain indices, as specified by `state`. """ -process_to_chain(state::TemperedState, I...) = process_to_chain(state.process_to_chain, I...) -# NOTE: Array impl. is useful for testing. -process_to_chain(proc2chain::AbstractArray, I...) = proc2chain[I...] +sort_by_chain(::ChainOrder, ::Any, xs) = xs +sort_by_chain(::ProcessOrder, state, xs) = [xs[chain_to_process(state, i)] for i = 1:length(xs)] +sort_by_chain(::ProcessOrder, state, xs::Tuple) = ntuple(i -> xs[chain_to_process(state, i)], length(xs)) """ - chain_to_process(state, I...) + sort_by_process(::ProcessOrder, state, xs) + sort_by_process(::ChainOrder, state, xs) -Return the process index corresponding to the chain index `I`. +Return `xs` sorted according to the process indices, as specified by `state`. """ -chain_to_process(state::TemperedState, I...) = chain_to_process(state.chain_to_process, I...) -# NOTE: Array impl. is useful for testing. -chain_to_process(chain2proc::AbstractArray, I...) = chain2proc[I...] +sort_by_process(::ProcessOrder, ::Any, xs) = xs +sort_by_process(::ChainOrder, state, xs) = [xs[process_to_chain(state, i)] for i = 1:length(xs)] +sort_by_process(::ChainOrder, state, xs::Tuple) = ntuple(i -> xs[process_to_chain(state, i)], length(xs)) """ - transition_for_chain(state[, I...]) + process_to_chain(state, I...) -Return the transition corresponding to the chain indexed by `I...`. -If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. +Return the chain index corresponding to the process index `I`. """ -transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) -transition_for_chain(state::TemperedState, I...) = transition_for_process(state, chain_to_process(state, I...)) +process_to_chain(state::SwapState, I...)
= process_to_chain(state.process_to_chain, I...) +# NOTE: Array impl. is useful for testing. +process_to_chain(proc2chain, I...) = proc2chain[I...] """ - transition_for_process(state, I...) + chain_to_process(state, I...) -Return the transition corresponding to the process indexed by `I...`. +Return the process index corresponding to the chain index `I`. """ -transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1] +chain_to_process(state::SwapState, I...) = chain_to_process(state.chain_to_process, I...) +# NOTE: Array impl. is useful for testing. +chain_to_process(chain2proc, I...) = chain2proc[I...] """ state_for_chain(state[, I...]) @@ -122,47 +154,49 @@ transition_for_process(state::TemperedState, I...) = state.transitions_and_state Return the state corresponding to the chain indexed by `I...`. If `I...` is not specified, the state corresponding to `β=1.0` will be returned. """ -state_for_chain(state::TemperedState) = state_for_chain(state, 1) -state_for_chain(state::TemperedState, I...) = state_for_process(state, chain_to_process(state, I...)) +state_for_chain(state) = state_for_chain(state, 1) +state_for_chain(state, I...) = state_for_process(state, chain_to_process(state, I...)) """ state_for_process(state, I...) Return the state corresponding to the process indexed by `I...`. """ -state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2] +state_for_process(state::SwapState, I...) = state_for_process(state.states, I...) +state_for_process(proc2state, I...) = proc2state[I...] """ - beta_for_chain(state[, I...]) + model_for_chain(ordering, sampler, model, state, I...) + +Return the model corresponding to the chain indexed by `I...`. -Return the β corresponding to the chain indexed by `I...`. -If `I...` is not specified, the β corresponding to `β=1.0` will be returned. +`ordering` specifies what sort of order the input models follow. 
""" -beta_for_chain(state::TemperedState) = beta_for_chain(state, 1) -beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...) -# NOTE: Array impl. is useful for testing. -beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...] +function model_for_chain end """ - beta_for_process(state, I...) + model_for_process(ordering, sampler, model, state, I...) + +Return the model corresponding to the process indexed by `I...`. -Return the β corresponding to the process indexed by `I...`. +`ordering` specifies what sort of order the input models follow. """ -beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.process_to_chain, I...) -# NOTE: Array impl. is useful for testing. -function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...) - return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...)) -end +function model_for_process end """ - getparams(transition) - getparams(::Type, transition) + models_by_processes(ordering, models, state) -Return the parameters contained in `transition`. +Return the models in the order of processes, assuming `models` is sorted according to `ordering`. + +See also: [`ProcessOrdering`](@ref), [`ChainOrdering`](@ref). +""" +models_by_processes(ordering, models, state) = sort_by_process(ordering, state, models) + +""" + samplers_by_processes(ordering, samplers, state) -If a type is specified, the parameters are returned in said type. +Return the `samplers` in the order of processes, assuming `samplers` is sorted according to `ordering`. -# Notes -This method is meant to be overloaded for the different transitions types. +See also: [`ProcessOrdering`](@ref), [`ChainOrdering`](@ref). 
""" -function getparams end +samplers_by_processes(ordering, samplers, state) = sort_by_process(ordering, state, samplers) diff --git a/src/stepping.jl b/src/stepping.jl index 83d35fa..3ea2ce1 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,134 +1,88 @@ -""" - should_swap(sampler, state) - -Return `true` if a swap should happen at this iteration, and `false` otherwise. -""" -function should_swap(sampler::TemperedSampler, state::TemperedState) - return state.total_steps % sampler.swap_every == 1 -end - -get_init_params(x, _)= x +get_init_params(x, _) = x get_init_params(init_params::Nothing, _) = nothing get_init_params(init_params::AbstractVector{<:Real}, _) = copy(init_params) get_init_params(init_params::AbstractVector{<:AbstractVector{<:Real}}, i) = init_params[i] +@concrete struct TemperedTransition + swaptransition + transition +end + +function transition_for_chain(transition::TemperedTransition, I...) + chain_idx = transition.swaptransition.chain_to_process[I...] + return transition.transition.transitions[chain_idx] +end + function AbstractMCMC.step( rng::Random.AbstractRNG, - model, + model::AbstractMCMC.AbstractModel, sampler::TemperedSampler; - N_burnin::Integer=0, - burnin_progress::Bool=AbstractMCMC.PROGRESS[], - init_params=nothing, kwargs... ) - - # `TemperedState` has the transitions and states in the order of - # the processes, and performs swaps by moving the (inverse) temperatures - # `β` between the processes, rather than moving states between processes - # and keeping the `β` local to each process. - # - # Therefore we iterate over the processes and then extract the corresponding - # `β`, `sampler` and `state`, and take a initialize. - transitions_and_states = [ - AbstractMCMC.step( - rng, - make_tempered_model(sampler, model, sampler.inverse_temperatures[i]), - getsampler(sampler, i); - init_params=get_init_params(init_params, i), - kwargs... - ) + # Create a `MultiSampler` and `MultiModel`. 
+ multimodel = MultiModel([ + make_tempered_model(sampler, model, sampler.chain_to_beta[i]) for i in 1:numtemps(sampler) - ] + ]) + multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)]) + multistate = last(AbstractMCMC.step(rng, multimodel, multisampler; kwargs...)) # Make sure to collect, because we'll be using `setindex!(!)` later. - process_to_chain = collect(1:length(sampler.inverse_temperatures)) + process_to_chain = collect(1:length(sampler.chain_to_beta)) # Need to `copy` because this might be mutated. chain_to_process = copy(process_to_chain) - state = TemperedState( - transitions_and_states, - sampler.inverse_temperatures, - process_to_chain, + swapstate = SwapState( + multistate.states, chain_to_process, + process_to_chain, 1, - 0, - sampler.adaptation_states, - false, - Dict{Int,Float64}() + Dict{Int,Float64}(), ) - if N_burnin > 0 - AbstractMCMC.@ifwithprogresslogger burnin_progress name = "Burn-in" begin - # Determine threshold values for progress logging - # (one update per 0.5% of progress) - if burnin_progress - threshold = N_burnin ÷ 200 - next_update = threshold - end - - for i in 1:N_burnin - if burnin_progress && i >= next_update - ProgressLogging.@logprogress i / N_burnin - next_update = i + threshold - end - state = no_swap_step(rng, model, sampler, state; kwargs...) - @set! state.burnin_steps += 1 - end - end - end - - return transition_for_chain(state), state + return AbstractMCMC.step(rng, model, sampler, TemperedState(swapstate, multistate, sampler.chain_to_beta)) end function AbstractMCMC.step( rng::Random.AbstractRNG, - model, + model::AbstractMCMC.AbstractModel, sampler::TemperedSampler, state::TemperedState; kwargs... ) - # Reset state - @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) - - if should_swap(sampler, state) - state = swap_step(rng, model, sampler, state) - @set! state.is_swap = true - else - state = no_swap_step(rng, model, sampler, state; kwargs...) - @set! 
state.is_swap = false - end + # Create the tempered `MultiModel`. + multimodel = MultiModel([make_tempered_model(sampler, model, beta) for beta in state.chain_to_beta]) + # Create the tempered `MultiSampler`. + # We're assuming the user has given the samplers in an order according to the initial models. + multisampler = MultiSampler(samplers_by_processes( + ChainOrder(), + [getsampler(sampler, i) for i in 1:numtemps(sampler)], + state.swapstate + )) + # Create the composition which applies `SwapSampler` first. + sampler_composition = multisampler ∘ swapsampler(sampler) + + # Step! + # NOTE: This will internally re-order the models according to processes before taking steps, + # hence the resulting transitions and states will be in the order of processes, as we desire. + transition_composition, state_composition = AbstractMCMC.step( + rng, + multimodel, + sampler_composition, + composition_state(sampler_composition, state.swapstate, state.state); + kwargs... + ) - @set! state.total_steps += 1 + # Construct the `TemperedTransition` and `TemperedState`. + swaptransition = inner_transition(transition_composition) + outertransition = outer_transition(transition_composition) - # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. - return transition_for_chain(state), state -end + swapstate = inner_state(state_composition) + outerstate = outer_state(state_composition) -function no_swap_step( - rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState; - kwargs... -) - # `TemperedState` has the transitions and states in the order of - # the processes, and performs swaps by moving the (inverse) temperatures - # `β` between the processes, rather than moving states between processes - # and keeping the `β` local to each process. - # - # Therefore we iterate over the processes and then extract the corresponding - # `β`, `sampler` and `state`, and take a step. - @set! 
state.transitions_and_states = [ - AbstractMCMC.step( - rng, - make_tempered_model(sampler, model, beta_for_process(state, i)), - sampler_for_process(sampler, state, i), - state_for_process(state, i); - kwargs... - ) - for i in 1:numtemps(sampler) - ] - - return state + return ( + TemperedTransition(swaptransition, outertransition), + TemperedState(swapstate, outerstate, state.chain_to_beta) + ) end """ @@ -142,8 +96,8 @@ is used. function swap_step( rng::Random.AbstractRNG, model, - sampler::TemperedSampler, - state::TemperedState + sampler, + state ) return swap_step(swapstrategy(sampler), rng, model, sampler, state) end @@ -151,31 +105,34 @@ end function swap_step( strategy::ReversibleSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Randomly select whether to attempt swaps between chains # corresponding to odd or even indices of the temperature ladder - odd = rand([true, false]) - for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))] - state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + odd = rand(rng, Bool) + # TODO: Use integer-division. + for k in [Int(2 * i - odd) for i in 1:(floor((length(model) - 1 + odd) / 2))] + state = swap_attempt(rng, model, sampler, state, k, k + 1) end return state end + function swap_step( strategy::NonReversibleSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state::SwapState # we're accessing `total_steps`, so restrict the type here ) # Alternate between attempting to swap chains corresponding # to odd and even indices of the temperature ladder - odd = state.total_steps % (2 * sampler.swap_every) != 0 - for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))] - state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + odd = state.total_steps % 2 != 0 + # TODO: Use integer-division.
+ for k in [Int(2 * i - odd) for i in 1:(floor((length(model) - 1 + odd) / 2))] + state = swap_attempt(rng, model, sampler, state, k, k + 1) end return state end @@ -183,45 +140,45 @@ end function swap_step( strategy::SingleSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Randomly pick one index `k` of the temperature ladder and # attempt a swap between the corresponding chain and its neighbour - k = rand(rng, 1:(numtemps(sampler) - 1)) - return swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + k = rand(rng, 1:(length(model) - 1)) + return swap_attempt(rng, model, sampler, state, k, k + 1) end function swap_step( strategy::SingleRandomSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Randomly pick two temperature ladder indices in order to # attempt a swap between the corresponding chains - chains = Set(1:numtemps(sampler)) + chains = Set(1:length(model)) i = pop!(chains, rand(rng, chains)) j = pop!(chains, rand(rng, chains)) - return swap_attempt(rng, model, sampler, state, i, j, sampler.adapt) + return swap_attempt(rng, model, sampler, state, i, j) end function swap_step( strategy::RandomSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Iterate through all of temperature ladder indices, picking random # pairs and attempting swaps between the corresponding chains - chains = Set(1:numtemps(sampler)) + chains = Set(1:length(model)) while length(chains) >= 2 i = pop!(chains, rand(rng, chains)) j = pop!(chains, rand(rng, chains)) - state = swap_attempt(rng, model, sampler, state, i, j, sampler.adapt) + state = swap_attempt(rng, model, sampler, state, i, j) end return state end @@ -230,8 +187,8 @@ function swap_step( strategy::NoSwap, rng::Random.AbstractRNG, model, - sampler::TemperedSampler, - 
state::TemperedState + sampler, + state ) return state -end \ No newline at end of file +end diff --git a/src/swapping.jl b/src/swapping.jl index 578aa9c..6bd8eae 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -79,11 +79,11 @@ this overrides and disables all swapping functionality. struct NoSwap <: AbstractSwapStrategy end """ - swap_betas!(chain_to_process, process_to_chain, i, j) + swap!(chain_to_process, process_to_chain, i, j) Swaps the `i`th and `j`th temperatures in place. """ -function swap_betas!(chain_to_process, process_to_chain, i, j) +function swap!(chain_to_process, process_to_chain, i, j) # TODO: Use BangBang's `@set!!` to also support tuples? # Extract the process index for each of the chains. process_for_chain_i, process_for_chain_j = chain_to_process[i], chain_to_process[j] @@ -113,8 +113,8 @@ and calls [`logdensity`](@ref) on the model returned from [`make_tempered_model` function compute_tempered_logdensities(model, sampler, transition, transition_other, β) tempered_model = make_tempered_model(sampler, model, β) return ( - logdensity(tempered_model, getparams(transition)), - logdensity(tempered_model, getparams(transition_other)) + logdensity(tempered_model, getparams(tempered_model, transition)), + logdensity(tempered_model, getparams(tempered_model, transition_other)) ) end function compute_tempered_logdensities( @@ -123,6 +123,19 @@ function compute_tempered_logdensities( return compute_tempered_logdensities(model, sampler, transition, transition_other, β) end +function compute_logdensities( + model::AbstractMCMC.AbstractModel, + model_other::AbstractMCMC.AbstractModel, + state, + state_other, +) + # TODO: Make use of `getparams_and_logprob` instead? 
+ return ( + logdensity(model, getparams(model, state)), + logdensity(model, getparams(model_other, state_other)) + ) +end + """ swap_acceptance_pt(logπi, logπj) @@ -134,50 +147,3 @@ function swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) return (logπjθi + logπiθj) - (logπiθi + logπjθj) end - -""" - swap_attempt(rng, model, sampler, state, i, j) - -Attempt to swap the temperatures of two chains by tempering the densities and -calculating the swap acceptance ratio; then swapping if it is accepted. -""" -function swap_attempt(rng, model, sampler, state, i, j, adapt) - # Extract the relevant transitions. - sampler_i = sampler_for_chain(sampler, state, i) - sampler_j = sampler_for_chain(sampler, state, j) - transition_i = transition_for_chain(state, i) - transition_j = transition_for_chain(state, j) - state_i = state_for_chain(state, i) - state_j = state_for_chain(state, j) - β_i = beta_for_chain(state, i) - β_j = beta_for_chain(state, j) - # Evaluate logdensity for both parameters for each tempered density. - logπiθi, logπiθj = compute_tempered_logdensities( - model, sampler_i, sampler_j, transition_i, transition_j, state_i, state_j, β_i, β_j - ) - logπjθj, logπjθi = compute_tempered_logdensities( - model, sampler_j, sampler_i, transition_j, transition_i, state_j, state_i, β_j, β_i - ) - - # If the proposed temperature swap is accepted according `logα`, - # swap the temperatures for future steps. - logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) - should_swap = -Random.randexp(rng) ≤ logα - if should_swap - swap_betas!(state.chain_to_process, state.process_to_chain, i, j) - end - - # Keep track of the (log) acceptance ratios. - state.swap_acceptance_ratios[i] = logα - - # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is - # adapted before a new `inverse_temperatures` is generated and returned. - if adapt - ρs = adapt!!( - state.adaptation_states, state.chain_to_beta, i, min(one(logα), exp(logα)) - ) - @set! 
state.adaptation_states = ρs - @set! state.chain_to_beta = update_inverse_temperatures(ρs, state.chain_to_beta) - end - return state -end diff --git a/src/swapsampler.jl b/src/swapsampler.jl new file mode 100644 index 0000000..6e43ff1 --- /dev/null +++ b/src/swapsampler.jl @@ -0,0 +1,224 @@ +""" + SwapSampler <: AbstractMCMC.AbstractSampler + +# Fields +$(FIELDS) +""" +struct SwapSampler{S} <: AbstractMCMC.AbstractSampler + "swap strategy to use" + strategy::S +end + +SwapSampler() = SwapSampler(ReversibleSwap()) + +swapstrategy(sampler::SwapSampler) = sampler.strategy + +# Interaction with the state. +# NOTE: `SwapSampler` should only ever interact with `ProcessOrder`, so we don't implement `ChainOrder`. +function model_for_chain(ordering::ProcessOrder, sampler::SwapSampler, model::MultiModel, state::SwapState, I...) + # `model` is expected to be ordered according to process index, hence we map chain index to process index + # and extract the model corresponding to said process. + return model_for_process(ordering, sampler, model, state, chain_to_process(state, I...)) +end + +function model_for_process(::ProcessOrder, sampler::SwapSampler, model::MultiModel, state::SwapState, I...) + # `model` is expected to be ordered according to process index, hence we just extract the corresponding index. + return model.models[I...] +end + +""" + SwapTransition + +Transition type for tempered samplers. +""" +@concrete struct SwapTransition + chain_to_process + process_to_chain +end + +function composition_transition( + sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler}, + swaptransition::SwapTransition, + outertransition::MultipleTransitions +) + saveall(sampler) && return CompositionTransition(outertransition, swaptransition) + # Otherwise we have to re-order the transitions, since without the `swaptransition` there's + # no way to recover the true ordering of the transitions. 
+ return MultipleTransitions(sort_by_chain(ProcessOrder(), swaptransition, outertransition.transitions)) +end + +# NOTE: This does not have an initial `step`! This is because we need +# states to work with before we can do anything. Hence it only makes +# sense to use this sampler in composition with other samplers. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model, + sampler::SwapSampler, + state::SwapState; + kwargs... +) + # Reset state + @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) + + # Perform a swap step. + state = swap_step(rng, model, sampler, state) + @set! state.total_steps += 1 + + # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. + # TODO: What should we return here? + return SwapTransition(deepcopy(state.chain_to_process), deepcopy(state.process_to_chain)), state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler}, + state; + kwargs... +) + # Reminder: a `swap` can be implemented in two different ways: + # + # 1. Swap the models and leave ordering of (sampler, state)-pair unchanged. + # 2. Swap (sampler, state)-pairs and leave ordering of models unchanged. + # + # (1) has the properties: + # + Easy to keep `outerstate` and `swapstate` in sync since their ordering is never changed. + # - Ordering of `outerstate` no longer corresponds to ordering of models, i.e. the returned + # `outerstate.states[i]` no longer corresponds to a state targeting `model.models[i]`. + # This will have to be adjusted in the `AbstractMCMC.bundle_samples` before, say, converting + # into a `MCMCChains.Chains`. + # + # (2) has the properties: + # + Returned `outertransition` (and `outerstate`, if we want) has the same ordering as the models, + # i.e. `outerstate.states[i]` now corresponds to `model.models[i]`! 
+ # - Need to keep `outerstate` and `swapstate` in sync since their ordering now changes. + # - Need to also re-order `outersampler.samplers` :/ + # + # Here (as in, below) we go with option (1), i.e. re-order the `models`. + # A full `step` then is as follows: + # 1. Sort models according to index processes using the `swapstate` from previous iteration. + # 2. Take step with `swapsampler`. + # 3. Sort models _again_ according to index processes using the new `swapstate`, since we + # might have made a swap in (2). + # 4. Run multi-sampler. + + outersampler, swapsampler = outer_sampler(sampler), inner_sampler(sampler) + + # Get the states. + outerstate_prev, swapstate_prev = outer_state(state), inner_state(state) + + # Re-order the models. + chain2models = model.models # but keep the original chain → model around because we'll re-order again later + @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate_prev) + + # Step for the swap-sampler. + swaptransition, swapstate = AbstractMCMC.step( + rng, model, swapsampler, state_from(model, swapstate_prev, outerstate_prev); + kwargs... + ) + + # Re-order the models AGAIN, since we might have swapped some. + @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate) + + # Create the current state from `outerstate_prev` and `swapstate`, and `step` for `outersampler`. + outertransition, outerstate = AbstractMCMC.step( + # HACK: We really need the `state_from` here despite the fact that `SwapSampler` does not + # change the `swapstates.states` itself, but we might require a re-computation of certain + # quantities from the `model`, which has now potentially been re-ordered (see above). + # NOTE: We do NOT do `state_from(model, outerstate_prev, swapstate)` because as of now, + # `swapstate` does not implement `getparams_and_logprob`. + rng, model, outersampler, state_from(model, outerstate_prev, outerstate_prev); + kwargs... + ) + + # TODO: Should we re-order the transitions? 
+ # Currently, one has to re-order the `outertransition` according to `swaptransition` + # in the `bundle_samples`. Is this the right approach though? + # TODO: We should at least re-order transitions in the case where `saveall(sampler) == false`! + # In this case, we'll just return the transition without the swap-transition, hence making it + # impossible to reconstruct the actual ordering! + return ( + composition_transition(sampler, swaptransition, outertransition), + composition_state(sampler, swapstate, outerstate) + ) +end + +# NOTE: The default initial `step` for `CompositionSampler` simply calls the two different +# `step` methods, but since `SwapSampler` does not have such an implementation this will fail. +# Instead we overload the initial `step` for `CompositionSampler` involving `SwapSampler` to +# first take a `step` using the non-swapsampler and then construct `SwapState` from the resulting state. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler}; + kwargs... +) + # This should hopefully be a `MultipleStates` or something since we're working with a `MultiModel`. + state_outer_initial = last(AbstractMCMC.step(rng, model, outer_sampler(sampler); kwargs...)) + # NOTE: Since `SwapState` wraps a sequence of states from another sampler, we need `state_outer_initial` + # to initialize the `SwapState`. + state_inner_initial = SwapState(state_outer_initial) + + # Create the composition state, and take a full step. + state = composition_state(sampler, state_inner_initial, state_outer_initial) + return AbstractMCMC.step(rng, model, sampler, state; kwargs...) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:SwapSampler,<:AbstractMCMC.AbstractSampler}; + kwargs... +) + # This should hopefully be a `MultipleStates` or something since we're working with a `MultiModel`. 
+ state_inner_initial = last(AbstractMCMC.step(rng, model, inner_sampler(sampler); kwargs...)) + # NOTE: Since `SwapState` wraps a sequence of states from another sampler, we need `state_inner_initial` + # to initialize the `SwapState`. + state_outer_initial = SwapState(state_inner_initial) + + # Create the composition state, and take a full step. + state = composition_state(sampler, state_inner_initial, state_outer_initial) + return AbstractMCMC.step(rng, model, sampler, state; kwargs...) +end + +@nospecialize function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:SwapSampler,<:SwapSampler}; + kwargs... +) + error("`SwapSampler` requires states from sampler other than `SwapSampler` to be initialized") +end + +""" + swap_attempt(rng, model, sampler, state, i, j) + +Attempt to swap the temperatures of two chains by tempering the densities and +calculating the swap acceptance ratio; then swapping if it is accepted. +""" +function swap_attempt(rng::Random.AbstractRNG, model::MultiModel, sampler::SwapSampler, state, i, j) + # Extract the relevant states. + state_i = state_for_chain(state, i) + state_j = state_for_chain(state, j) + # Evaluate logdensity for both parameters for each tempered density. + # NOTE: `SwapSampler` should only be working with models ordered according to `ProcessOrder`, + # never `ChainOrder`, hence why we have the below. + model_i = model_for_chain(ProcessOrder(), sampler, model, state, i) + model_j = model_for_chain(ProcessOrder(), sampler, model, state, j) + logπiθi, logπiθj = compute_logdensities(model_i, model_j, state_i, state_j) + logπjθj, logπjθi = compute_logdensities(model_j, model_i, state_j, state_i) + + # If the proposed temperature swap is accepted according to `logα`, + # swap the temperatures for future steps. 
+ logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) + should_swap = -Random.randexp(rng) ≤ logα + if should_swap + swap!(state.chain_to_process, state.process_to_chain, i, j) + end + + # Keep track of the (log) acceptance ratios. + state.swap_acceptance_ratios[i] = logα + + # TODO: Handle adaptation. + return state +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..17b9125 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,34 @@ +# TODO: Move. +chain_to_process(transition::SwapTransition, I...) = transition.chain_to_process[I...] + +""" + roundtrips(transitions) + +Return sequence of `(start_index, turnpoint_index, end_index)`-triples representing roundtrips. +""" +function roundtrips(transitions::AbstractVector{<:TemperedTransition}) + return roundtrips(map(Base.Fix2(getproperty, :swaptransition), transitions)) +end +function roundtrips(transitions::AbstractVector{<:SwapTransition}) + result = Tuple{Int,Int,Int}[] + start_index, turn_index = 1, nothing + for (i, t) in enumerate(transitions) + n = length(t.chain_to_process) + if isnothing(turn_index) + # Looking for the turn. + if chain_to_process(t, 1) == n + turn_index = i + end + else + # Looking for the return/end. + if chain_to_process(t, 1) == 1 + push!(result, (start_index, turn_index, i)) + # Reset. 
+ start_index = i + turn_index = nothing + end + end + end + + return result +end diff --git a/test/Project.toml b/test/Project.toml index 310ee7f..492577c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,12 +3,18 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" @@ -20,6 +26,6 @@ Bijectors = "0.10" Distributions = "0.24, 0.25" LogDensityProblems = "2" LogDensityProblemsAD = "1" -MCMCChains = "5.5" +MCMCChains = "6" Turing = "0.24" julia = "1" diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl new file mode 100644 index 0000000..3cc9c8a --- /dev/null +++ b/test/abstractmcmc.jl @@ -0,0 +1,186 @@ +@testset "AbstractMCMC" begin + rng = Random.default_rng() + model = DistributionLogDensity(MvNormal(Ones(2), I)) + logdensity_model = LogDensityModel(model) + + spl = RWMH(MvNormal(Zeros(dimension(model)), I)) + @test spl isa AbstractMCMC.AbstractSampler + + @testset "CompositionSampler(.., saveall=$(saveall))" for saveall in (true, false, Val(true), Val(false)) + spl_composed = MCMCTempering.CompositionSampler(spl, spl, saveall) + + num_iters = 100 + # 
Taking two steps with `spl` should be equivalent to one step with `spl ∘ spl`. + # Use the same initial state. + state_initial = last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl)) + state_composed_initial = MCMCTempering.state_from( + logdensity_model, + last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl_composed)), + state_initial, + ) + + @test state_composed_initial isa MCMCTempering.CompositionState + + # Take two steps with `spl`. + rng = Random.MersenneTwister(42) + state = deepcopy(state_initial) + for _ = 1:num_iters + transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state) + transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state) + end + params, logp = MCMCTempering.getparams_and_logprob(logdensity_model, state) + + # Take one step with `spl ∘ spl`. + rng = Random.MersenneTwister(42) + state_composed = deepcopy(state_composed_initial) + for _ = 1:num_iters + transition, state_composed = AbstractMCMC.step(rng, logdensity_model, spl_composed, state_composed) + + # Make sure the state types stay consistent. + if MCMCTempering.saveall(spl_composed) + @test transition isa MCMCTempering.CompositionTransition + end + @test state_composed isa MCMCTempering.CompositionState + end + params_composed, logp_composed = MCMCTempering.getparams_and_logprob(logdensity_model, state_composed) + + # Check that the parameters and log probability are the same. + @test params == params_composed + @test logp == logp_composed + + # Make sure that `AbstractMCMC.sample` is good. + chain_composed = sample(logdensity_model, spl_composed, 2; progress=false, chain_type=MCMCChains.Chains) + chain = sample( + logdensity_model, spl, MCMCTempering.saveall(spl_composed) ? 4 : 2; + progress=false, chain_type=MCMCChains.Chains + ) + + # Should be the same length because the `SequentialTransitions` will be unflattened. 
+ @test chain_composed isa MCMCChains.Chains + @test length(chain_composed) == length(chain) + end + + @testset "RepeatedSampler(..., saveall=$(saveall))" for saveall in (true, false, Val(true), Val(false)) + spl_repeated = MCMCTempering.RepeatedSampler(spl, 2, saveall) + @test spl_repeated isa MCMCTempering.RepeatedSampler + + # Taking two steps with `spl` should be equivalent to one step with `spl ∘ spl`. + # Use the same initial state. + state_initial = last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl)) + state_repeated_initial = MCMCTempering.state_from( + logdensity_model, + last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl_repeated)), + state_initial, + ) + + num_iters = 100 + + # Take two steps with `spl`. + rng = Random.MersenneTwister(42) + state = deepcopy(state_initial) + for _ = 1:num_iters + transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state) + transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state) + end + params, logp = MCMCTempering.getparams_and_logprob(logdensity_model, state) + + # Take one step with `spl ∘ spl`. + rng = Random.MersenneTwister(42) + state_repeated = deepcopy(state_repeated_initial) + for _ = 1:num_iters + transition, state_repeated = AbstractMCMC.step(rng, logdensity_model, spl_repeated, state_repeated) + + # Make sure the state types stay consistent. + if MCMCTempering.saveall(spl_repeated) + @test transition isa MCMCTempering.SequentialTransitions + @test state_repeated isa MCMCTempering.SequentialStates + else + @test state_repeated isa typeof(state_initial) + end + end + params_repeated, logp_repeated = MCMCTempering.getparams_and_logprob(logdensity_model, state_repeated) + + # Check that the parameters and log probability are the same. + @test params == params_repeated + @test logp == logp_repeated + + # Make sure that `AbstractMCMC.sample` is good. 
+ chain_repeated = sample(logdensity_model, spl_repeated, 2; progress=false, chain_type=MCMCChains.Chains) + chain = sample( + logdensity_model, spl, MCMCTempering.saveall(spl_repeated) ? 4 : 2; + progress=false, chain_type=MCMCChains.Chains + ) + + # Should be the same length because the `SequentialTransitions` will be unflattened. + @test length(chain_repeated) == length(chain) + end + + @testset "MultiSampler" begin + spl_multi = spl × spl + @testset "$model_multi" for model_multi in [ + MCMCTempering.MultiModel((logdensity_model, logdensity_model)), # tuple + MCMCTempering.MultiModel([logdensity_model, logdensity_model]), # vector + MCMCTempering.MultiModel((m for m in [logdensity_model, logdensity_model])) # iterator + ] + + @test spl_multi isa MCMCTempering.MultiSampler + + num_iters = 100 + # Use the same initial state. + states_initial = map(model_multi.models) do model + last(AbstractMCMC.step(Random.default_rng(), model, spl)) + end + states_multi_initial = MCMCTempering.state_from( + model_multi, + last(AbstractMCMC.step(Random.default_rng(), model_multi, spl_multi)), + MCMCTempering.MultipleStates(states_initial), + ) + + # Taking a step with `spl_multi` on `multimodel` should be equivalent + # to stepping with the component samplers on the component models. 
+ rng = Random.MersenneTwister(42) + rng_multi = Random.MersenneTwister(42) + states = deepcopy(states_initial) + state_multi = deepcopy(states_multi_initial) + for _ = 1:num_iters + state_multi = last(AbstractMCMC.step(rng_multi, model_multi, spl_multi, state_multi)) + states = map(model_multi.models, states) do model, state + last(AbstractMCMC.step(rng, model, spl, state)) + end + end + params_and_logp = map(Base.Fix1(MCMCTempering.getparams_and_logprob, logdensity_model), states) + params_multi, logp_multi = MCMCTempering.getparams_and_logprob(model_multi, state_multi) + + @test map(first, params_and_logp) == params_multi + @test map(last, params_and_logp) == logp_multi + end + end + + @testset "SwapSampler" begin + # SwapSampler without tempering (i.e. in a composition and using `MultiModel`, etc.) + init_params = [[5.0], [5.0]] + mdl1 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(4.9999, 1))) + mdl2 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(5.0001, 1))) + spl1 = RWMH(MvNormal(Zeros(dimension(mdl1)), I)) + spl2 = let σ² = 1e-2 + MALA(∇ -> MvNormal(σ² * ∇, 2σ² * I)) + end + swapspl = MCMCTempering.SwapSampler() + spl_full = (spl1 × spl2) ∘ swapspl + product_model = LogDensityModel(mdl1) × LogDensityModel(mdl2) + # Sample! + multisamples = sample(product_model, spl_full, 1000; init_params=init_params, progress=false) + # Extract the transitions corresponding to each of the models. + model_transitions = mapreduce(hcat, multisamples) do t + [MCMCTempering.outer_transition(t).transitions[MCMCTempering.inner_transition(t).process_to_chain]...] + end + # Make sure we actually got some swaps going and we were using different types of states + # for both models. + @test length(unique(typeof, model_transitions[1, :])) ≥ 1 + @test length(unique(typeof, model_transitions[2, :])) ≥ 1 + + # Check that means are roughly okay. 
+ model_params = map(first ∘ MCMCTempering.getparams, model_transitions) + @test vec(mean(model_params; dims=2)) ≈ [5.0, 5.0] atol=0.2 + end +end diff --git a/test/compat.jl b/test/compat.jl index e85bd05..7052bb1 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -1,6 +1,28 @@ # AdvancedMH.jl -MCMCTempering.getparams(transition::AdvancedMH.Transition) = transition.params -MCMCTempering.getparams(transition::AdvancedMH.GradientTransition) = transition.params +MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp +function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp) + Setfield.@set! transition.params = params + Setfield.@set! transition.lp = lp + return transition +end +MCMCTempering.getparams_and_logprob(transition::AdvancedMH.GradientTransition) = transition.params, transition.lp +# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible. +function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.GradientTransition, params, lp) + # NOTE: We have to re-compute the gradient here because this will be used in the subsequent `step` for + # the MALA sampler. + return AdvancedMH.GradientTransition(params, AdvancedMH.logdensity_and_gradient(model, params)...) +end # AdvancedHMC.jl -MCMCTempering.getparams(t::AdvancedHMC.Transition) = t.z.θ +MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value +MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition) + +# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible. +function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp) + # NOTE: Need to recompute the gradient because it might be used in the next integration step. 
+ hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index e13411b..160cabf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,22 +1,4 @@ -using MCMCTempering -using Test -using Distributions -using AdvancedMH -using MCMCChains -using Bijectors -using LinearAlgebra -using AbstractMCMC -using LogDensityProblems: LogDensityProblems, logdensity, logdensity_and_gradient -using LogDensityProblemsAD -using ForwardDiff: ForwardDiff -using AdvancedMH: AdvancedMH -using AdvancedHMC: AdvancedHMC -using Turing: Turing, DynamicPPL - - -include("utils.jl") -include("compat.jl") - +include("setup.jl") """ test_and_sample_model(model, sampler, inverse_temperatures[, swap_strategy]; kwargs...) @@ -35,7 +17,8 @@ Several properties of the tempered sampler are tested before returning: # Keyword arguments - `num_iterations`: The number of iterations to run the sampler for. Defaults to `2_000`. -- `swap_every`: The number of iterations between each swap attempt. Defaults to `2`. +- `steps_per_swap`: The number of iterations between each swap attempt. Defaults to `1`. +- `adapt`: Whether to adapt the sampler. Defaults to `false`. - `adapt_target`: The target acceptance rate for the swaps. Defaults to `0.234`. - `adapt_rtol`: The relative tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.1`. - `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`. @@ -44,59 +27,63 @@ Several properties of the tempered sampler are tested before returning: - `init_params`: The initial parameters to use for the sampler. Defaults to `nothing`. - `param_names`: The names of the parameters in the chain; used to construct the resulting chain. Defaults to `missing`. 
- `progress`: Whether to show a progress bar. Defaults to `false`. -- `kwargs...`: Additional keyword arguments to pass to `MCMCTempering.tempered`. """ function test_and_sample_model( model, sampler, - inverse_temperatures, - swap_strategy=MCMCTempering.SingleSwap(); + inverse_temperatures; + swap_strategy=MCMCTempering.SingleSwap(), mean_swap_rate_bound=0.1, compare_mean_swap_rate=≥, num_iterations=2_000, - swap_every=2, + steps_per_swap=1, + adapt=false, adapt_target=0.234, adapt_rtol=0.1, adapt_atol=0.05, init_params=nothing, param_names=missing, progress=false, - kwargs... + minimum_roundtrips=nothing ) - # TODO: Remove this when no longer necessary. - num_iterations_tempered = Int(ceil(num_iterations * swap_every / (swap_every - 1))) - # Make the tempered sampler. sampler_tempered = tempered( sampler, inverse_temperatures; swap_strategy=swap_strategy, - swap_every=swap_every, + steps_per_swap=steps_per_swap, adapt_target=adapt_target, - kwargs... ) + @test sampler_tempered.swapstrategy == swap_strategy + @test MCMCTempering.swapsampler(sampler_tempered).strategy == swap_strategy + # Store the states. states_tempered = [] callback = StateHistoryCallback(states_tempered) # Sample. samples_tempered = AbstractMCMC.sample( - model, sampler_tempered, num_iterations_tempered; + model, sampler_tempered, num_iterations; callback=callback, progress=progress, init_params=init_params ) + if !isnothing(minimum_roundtrips) + # Make sure we've had at least some roundtrips. + @test length(MCMCTempering.roundtrips(samples_tempered)) ≥ minimum_roundtrips + end + # Let's make sure the process ↔ chain mapping is valid. numtemps = MCMCTempering.numtemps(sampler_tempered) - for state in states_tempered - for i = 1:numtemps + @test all(states_tempered) do state + all(1:numtemps) do i # These two should be inverses of each other. 
- @test MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i + MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i end end # Extract the states that were swapped. - states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered) + states_swapped = map(Base.Fix2(getproperty, :swapstate), states_tempered) # Swap acceptance ratios should be compared against the target acceptance in case of adaptation. swap_acceptance_ratios = mapreduce( collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), @@ -115,25 +102,39 @@ function test_and_sample_model( end # Extract the history of chain indices. - process_to_chain_history_list = map(states_tempered) do state + process_to_chain_history_list = map(states_swapped) do state state.process_to_chain end process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) # Check that the swapping has been done correctly. - process_to_chain_uniqueness = map(states_tempered) do state + process_to_chain_uniqueness = map(states_swapped) do state length(unique(state.process_to_chain)) == length(state.process_to_chain) end @test all(process_to_chain_uniqueness) - # For the currently implemented strategies, the index process should not move by more than 1. - @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + # For every strategy except `RandomSwap`, the index process should not move by more than 1. + if !(swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.RandomSwap}) + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + end - chain_to_process_uniqueness = map(states_tempered) do state + chain_to_process_uniqueness = map(states_swapped) do state length(unique(state.chain_to_process)) == length(state.chain_to_process) end @test all(chain_to_process_uniqueness) + # Compare the tempered sampler to the untempered sampler. 
+ state_tempered = states_tempered[end] + chain_tempered = AbstractMCMC.bundle_samples( + # TODO: Just use the underlying chain? + samples_tempered, + MCMCTempering.maybe_wrap_model(model), + sampler_tempered, + state_tempered, + MCMCChains.Chains; + param_names=param_names + ) + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row # Some of the strategies performs multiple swaps in a swap-iteration, @@ -141,20 +142,12 @@ function test_and_sample_model( # i.e. only count non-zero elements in a row _once_. Hence the `min`. min(1, sum(abs, row)) end + + num_nonswap_steps_taken = length(chain_tempered) + @test num_nonswap_steps_taken == (num_iterations * steps_per_swap) @test compare_mean_swap_rate( sum(swap_success_indicators), - (num_iterations_tempered / swap_every) * mean_swap_rate_bound - ) - - # Compare the tempered sampler to the untempered sampler. - state_tempered = states_tempered[end] - chain_tempered = AbstractMCMC.bundle_samples( - samples_tempered[findall((!).(getproperty.(states_tempered, :is_swap)))], - MCMCTempering.maybe_wrap_model(model), - sampler_tempered.sampler, - MCMCTempering.state_for_chain(state_tempered), - MCMCChains.Chains; - param_names=param_names + (num_nonswap_steps_taken / steps_per_swap) * mean_swap_rate_bound ) return chain_tempered @@ -163,37 +156,29 @@ end function compare_chains( chain::MCMCChains.Chains, chain_tempered::MCMCChains.Chains; atol=1e-6, rtol=1e-6, - compare_std=true, compare_ess=true, + compare_ess_slack=0.5, # HACK: this is very low which is unnecessary in most cases, but it's too random isbroken=false ) - desc = describe(chain)[1].nt - desc_tempered = describe(chain_tempered)[1].nt + mean = to_dict(MCMCChains.mean(chain)) + mean_tempered = to_dict(MCMCChains.mean(chain_tempered)) # Compare the means. 
if isbroken - @test_broken desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol + @test_broken all(isapprox(mean[sym], mean_tempered[sym]; atol, rtol) for sym in keys(mean)) else - @test desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol - end - - # Compare the std. of the chains. - if compare_std - if isbroken - @test_broken desc.std ≈ desc_tempered.std atol = atol rtol = rtol - else - @test desc.std ≈ desc_tempered.std atol = atol rtol = rtol - end + @test all(isapprox(mean[sym], mean_tempered[sym]; atol, rtol) for sym in keys(mean)) end # Compare the ESS. if compare_ess - ess = MCMCChains.ess_rhat(chain).nt.ess - ess_tempered = MCMCChains.ess_rhat(chain_tempered).nt.ess + ess = to_dict(MCMCChains.ess(chain)) + ess_tempered = to_dict(MCMCChains.ess(chain_tempered)) + @info "" ess ess_tempered if isbroken - @test_broken all(ess_tempered .≥ ess) + @test_broken all(ess_tempered[sym] ≥ ess[sym] * compare_ess_slack for sym in keys(ess)) else - @test all(ess_tempered .≥ ess) + @test all(ess_tempered[sym] ≥ ess[sym] * compare_ess_slack for sym in keys(ess)) end end end @@ -213,7 +198,7 @@ end chain_to_beta = [1.0, 0.75, 0.5, 0.25] # Make swap chain 1 (now on process 1) ↔ chain 2 (now on process 2) - MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1, 2) + MCMCTempering.swap!(chain_to_process, process_to_chain, 1, 2) # Expected result: chain 1 is now on process 2, chain 2 is now on process 1. target_process_to_chain = [2, 1, 3, 4] @test process_to_chain[chain_to_process] == 1:length(process_to_chain) @@ -229,7 +214,7 @@ end end # Make swap chain 2 (now on process 1) ↔ chain 3 (now on process 3) - MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2, 3) + MCMCTempering.swap!(chain_to_process, process_to_chain, 2, 3) # Expected result: chain 3 is now on process 1, chain 2 is now on process 3. 
target_process_to_chain = [3, 1, 2, 4] @test process_to_chain[chain_to_process] == 1:length(process_to_chain) @@ -246,7 +231,7 @@ end end @testset "Simple MvNormal with no expected swaps" begin - num_iterations = 10_000 + num_iterations = 5_000 d = 1 model = DistributionLogDensity(MvNormal(ones(d), I)) @@ -259,25 +244,24 @@ end sampler_rwmh, [1.0, 1e-3], # extreme temperatures -> don't exect much swapping to occur num_iterations=num_iterations, - swap_every=2, adapt=false, - init_params = [[0.0], [1000.0]], # initialized far apart - # At most 1% of swaps should be successful. + init_params=[[0.0], [1000.0]], # initialized far apart + # At MOST 1% of swaps should be successful. mean_swap_rate_bound=0.01, compare_mean_swap_rate=≤, ) # `atol` is fairly high because we haven't run this for "too" long. - @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.2 + @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.3 end @testset "GMM 1D" begin - num_iterations = 10_000 + num_iterations = 1_000 model = DistributionLogDensity( MixtureModel(Normal, [(-3, 1.5), (3, 1.5), (15, 1.5), (90, 1.5)], [0.175, 0.25, 0.275, 0.3]) ) # Setup non-tempered. - sampler_rwmh = RWMH(MvNormal(0.1 * ones(1))) + sampler_rwmh = RWMH(MvNormal(0.1 * Diagonal(Ones(1)))) # Simple geometric ladder inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.95 .^ (0:20)) @@ -287,13 +271,15 @@ end model, sampler_rwmh, inverse_temperatures, + swap_strategy=MCMCTempering.NonReversibleSwap(), num_iterations=num_iterations, - swap_every=2, adapt=false, # At least 25% of swaps should be successful. mean_swap_rate_bound=0.25, compare_mean_swap_rate=≥, progress=false, + # Make sure we have _some_ roundtrips. + minimum_roundtrips=10, ) # # Compare the chains. 
@@ -302,8 +288,7 @@ end @testset "MvNormal 2D with different swap strategies" begin d = 2 - num_iterations = 20_000 - swap_every = 2 + num_iterations = 5_000 μ_true = [-5.0, 5.0] σ_true = [1.0, √(10.0)] @@ -331,17 +316,19 @@ end MCMCTempering.RandomSwap() ] - @testset "$(swapstrategy)" for swapstrategy in swapstrategies + @testset "$(swap_strategy)" for swap_strategy in swapstrategies chain_tempered = test_and_sample_model( model, sampler, inverse_temperatures, num_iterations=num_iterations, - swap_every=swap_every, - swapstrategy=swapstrategy, + swap_strategy=swap_strategy, adapt=false, + # Make sure we have _some_ roundtrips. + minimum_roundtrips=10, ) - compare_chains(chain, chain_tempered, rtol=0.1, compare_std=false, compare_ess=true) + + compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true) end end @@ -368,15 +355,14 @@ end end @testset "AdvancedHMC.jl" begin - num_iterations = 2_000 + num_iterations = 5_000 # Set up HMC smpler. initial_ϵ = 0.1 integrator = AdvancedHMC.Leapfrog(initial_ϵ) proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator) metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model)) - adaptor = AdvancedHMC.StanHMCAdaptor(AdvancedHMC.MassMatrixAdaptor(metric), AdvancedHMC.StepSizeAdaptor(0.8, integrator)) - sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, adaptor) + sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric) # Sample using HMC. samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false) @@ -386,28 +372,47 @@ end ) map_parameters!(b, chain_hmc) + # Make sure that we get the "same" result when only using the inverse temperature 1. 
+ sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1]) + chain_tempered = sample( + model, sampler_tempered, num_iterations; + init_params=copy(init_params), + chain_type=MCMCChains.Chains, + param_names=param_names, + progress=false, + ) + map_parameters!(b, chain_tempered) + compare_chains( + chain_hmc, chain_tempered; + atol=0.2, + compare_ess=true, + isbroken=false + ) + # Sample using tempered HMC. chain_tempered = test_and_sample_model( model, sampler_hmc, - [1, 0.25, 0.1, 0.01], + [1, 0.75, 0.5, 0.25, 0.1, 0.01], swap_strategy=MCMCTempering.ReversibleSwap(), num_iterations=num_iterations, - swap_every=10, adapt=false, - mean_swap_rate_bound=0, + mean_swap_rate_bound=0.1, init_params=copy(init_params), param_names=param_names, progress=false ) map_parameters!(b, chain_tempered) - - # TODO: Make it not broken, i.e. produce reasonable results. - compare_chains(chain_hmc, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false) + compare_chains( + chain_hmc, chain_tempered; + atol=0.3, + compare_ess=true, + isbroken=false, + ) end - + @testset "AdvancedMH.jl" begin - num_iterations = 100_000 + num_iterations = 10_000 d = LogDensityProblems.dimension(model) # Set up MALA sampler. @@ -415,33 +420,51 @@ end sampler_mh = MALA(∇ -> MvNormal(σ² * ∇, 2σ² * I)) # Sample using MALA. - samples_mh = AbstractMCMC.sample( + chain_mh = AbstractMCMC.sample( model, sampler_mh, num_iterations; - init_params=copy(init_params), progress=false - ) - chain_mh = AbstractMCMC.bundle_samples( - samples_mh, MCMCTempering.maybe_wrap_model(model), sampler_mh, samples_mh[1], MCMCChains.Chains; - param_names=param_names + init_params=copy(init_params), + progress=false, + chain_type=MCMCChains.Chains, + param_names=param_names, ) map_parameters!(b, chain_mh) - # Sample using tempered MALA. + # Make sure that we get the "same" result when only using the inverse temperature 1. 
+ sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1]) + chain_tempered = sample( + model, sampler_tempered, num_iterations; + init_params=copy(init_params), + chain_type=MCMCChains.Chains, + param_names=param_names, + progress=false, + ) + map_parameters!(b, chain_tempered) + compare_chains( + chain_mh, chain_tempered; + atol=0.2, + compare_ess=true, + isbroken=false, + ) + + # Sample using actual tempering. chain_tempered = test_and_sample_model( model, sampler_mh, [1, 0.9, 0.75, 0.5, 0.25, 0.1], swap_strategy=MCMCTempering.ReversibleSwap(), num_iterations=num_iterations, - swap_every=2, adapt=false, - mean_swap_rate_bound=0, + mean_swap_rate_bound=0.1, init_params=copy(init_params), param_names=param_names ) map_parameters!(b, chain_tempered) # Need a large atol as MH is not great on its own - compare_chains(chain_mh, chain_tempered, atol=0.4, compare_std=false, compare_ess=true, isbroken=false) + compare_chains(chain_mh, chain_tempered, atol=0.2, compare_ess=true, isbroken=false) end end + + include("abstractmcmc.jl") + include("simple_gaussian.jl") end diff --git a/test/setup.jl b/test/setup.jl new file mode 100644 index 0000000..90195ca --- /dev/null +++ b/test/setup.jl @@ -0,0 +1,22 @@ +using MCMCTempering +using Test +using Distributions +using AdvancedMH +using MCMCChains +using Bijectors +using LinearAlgebra +using FillArrays +using Setfield: Setfield +using AbstractMCMC: AbstractMCMC, LogDensityModel +using LogDensityProblems: LogDensityProblems, logdensity, logdensity_and_gradient, dimension +using LogDensityProblemsAD +using Random: Random +using ForwardDiff: ForwardDiff +using AdvancedMH: AdvancedMH +using AdvancedHMC: AdvancedHMC +using Turing: Turing, DynamicPPL + + +include("utils.jl") +include("test_utils.jl") +include("compat.jl") diff --git a/test/simple_gaussian.jl b/test/simple_gaussian.jl new file mode 100644 index 0000000..1ff19b8 --- /dev/null +++ b/test/simple_gaussian.jl @@ -0,0 +1,74 @@ +@testset "Simple tempered Gaussian 
(closed form)" begin + μ = Zeros(1) + inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.8 .^ (0:10)) + variances_true = inv.(inverse_temperatures) + std_true_dict = map(variances_true) do v + Dict(:param_1 => √v) + end + tempered_dists = [MvNormal(Zeros(1), I / β) for β in inverse_temperatures] + tempered_multimodel = MCMCTempering.MultiModel(map(LogDensityModel ∘ DistributionLogDensity, tempered_dists)) + + init_params = zeros(length(μ)) + + num_samples = 1_000 + num_burnin = num_samples ÷ 2 + thin = 10 + + # Samplers. + rwmh = RWMH(MvNormal(Zeros(1), I)) + rwmh_tempered = TemperedSampler(rwmh, inverse_temperatures) + rwmh_product = MCMCTempering.MultiSampler(Fill(rwmh, length(tempered_dists))) + rwmh_product_with_swap = rwmh_product ∘ MCMCTempering.SwapSampler() + + # Sample. + @testset "TemperedSampler" begin + chains_product = sample( + DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples; + init_params, + bundle_resolve_swaps=true, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end + + @testset "MultiSampler without swapping" begin + chains_product = sample( + tempered_multimodel, rwmh_product, num_samples; + init_params, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end + + @testset "MultiSampler with swapping (saveall=true)" begin + chains_product = sample( + tempered_multimodel, rwmh_product_with_swap, num_samples; + init_params, + bundle_resolve_swaps=true, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end + + @testset "MultiSampler with 
swapping (saveall=false)" begin + chains_product = sample( + tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples; + init_params, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end +end + diff --git a/test/test_utils.jl b/test/test_utils.jl new file mode 100644 index 0000000..d343a8e --- /dev/null +++ b/test/test_utils.jl @@ -0,0 +1,126 @@ +using MCMCDiagnosticTools, Statistics, DataFrames + +""" + to_dict(c::MCMCChains.ChainDataFrame[, col::Symbol]) + +Return a dictionary mapping parameter names to the values in column `col` of `c`. + +# Arguments +- `c`: A `MCMCChains.ChainDataFrame` object. +- `col`: The column to extract values from. Defaults to the first column that is not `:parameters`. +""" +to_dict(c::MCMCChains.ChainDataFrame) = to_dict(c, first(filter(!=(:parameters), keys(c.nt)))) +function to_dict(c::MCMCChains.ChainDataFrame, col::Symbol) + df = DataFrame(c) + return Dict(sym => df[findfirst(==(sym), df[:, :parameters]), col] for sym in df.parameters) +end + +""" + atol_for_chain(chain; significance=1e-3, kind=Statistics.mean) + +Return a dictionary of absolute tolerances for each parameter in `chain`, computed as the +half-width of the two-sided confidence interval for `kind` (the mean by default) at level `significance`. +""" +function atol_for_chain(chain; significance=1e-3, kind=Statistics.mean) + param_names = names(chain, :parameters) + # Can reject H0 if, say, `abs(mean(chain2) - mean(chain1)) > confidence_width`. + # Or alternatively, compare means but with `atol` set to the `confidence_width`. + # NOTE: Failure to reject, i.e. passing the tests, does not imply that the means are equal. 
+ mcse = to_dict(MCMCChains.mcse(chain; kind), :mcse) + return Dict(sym => quantile(Normal(0, mcse[sym]), 1 - significance/2) for sym in param_names) +end + +thin_to(chain, n) = chain[1:length(chain) ÷ n:end] + +""" + test_means(chain, mean_true; kwargs...) + +Test that the mean of each parameter in `chain` is approximately `mean_true`. + +# Arguments +- `chain`: A `MCMCChains.Chains` object. +- `mean_true`: A `Real` or `AbstractDict` mapping parameter names to their true mean. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_means(chain::MCMCChains.Chains, mean_true::Real; kwargs...) + return test_means(chain, Dict(sym => mean_true for sym in names(chain, :parameters)); kwargs...) +end +function test_means(chain::MCMCChains.Chains, mean_true::AbstractDict; n=length(chain), kwargs...) + chain = thin_to(chain, n) + atol = atol_for_chain(chain; kwargs...) + @test all(isapprox(mean(chain[sym]), mean_true[sym], atol=atol[sym]) for sym in names(chain, :parameters)) +end + +""" + test_std(chain, std_true; kwargs...) + +Test that the standard deviation of each parameter in `chain` is approximately `std_true`. + +# Arguments +- `chain`: A `MCMCChains.Chains` object. +- `std_true`: A `Real` or `AbstractDict` mapping parameter names to their true standard deviation. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_std(chain::MCMCChains.Chains, std_true::Real; kwargs...) + return test_std(chain, Dict(sym => std_true for sym in names(chain, :parameters)); kwargs...) +end +function test_std(chain::MCMCChains.Chains, std_true::AbstractDict; n=length(chain), kwargs...) + chain = thin_to(chain, n) + atol = atol_for_chain(chain; kind=Statistics.std, kwargs...) + @info "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)] + @test all(isapprox(std(chain[sym]), std_true[sym], atol=atol[sym]) for sym in names(chain, :parameters)) +end + +""" + test_std_monotonicity(chains; isbroken=false, kwargs...) 
+ +Test that the standard deviation of each parameter in `chains` is monotonically increasing. + +# Arguments +- `chains`: A vector of `MCMCChains.Chains` objects. +- `isbroken`: If `true`, then the test will be marked as broken. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_std_monotonicity(chains::AbstractVector{<:MCMCChains.Chains}; isbroken::Bool=false, kwargs...) + param_names = names(first(chains), :parameters) + # We should technically use a Bonferroni correction here, but whatever. + atols = [atol_for_chain(chain; kind=Statistics.std, kwargs...) for chain in chains] + stds = [Dict(sym => std(chain[sym]) for sym in param_names) for chain in chains] + + num_chains = length(chains) + lbs = [Dict(sym => stds[i][sym] - atols[i][sym] for sym in param_names) for i in 1:num_chains] + ubs = [Dict(sym => stds[i][sym] + atols[i][sym] for sym in param_names) for i in 1:num_chains] + + for i = 2:num_chains + for sym in param_names + # If the upper bound of the current chain is smaller than the lower bound of the previous one, + # then we can reject the null hypothesis that they are ordered. + if isbroken + @test_broken ubs[i][sym] ≥ lbs[i - 1][sym] + else + @test ubs[i][sym] ≥ lbs[i - 1][sym] + end + end + end +end + +""" + test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...) + +Test that the mean and standard deviation of each parameter in `chains` are approximately `mean_true` +and `std_true`, respectively. Also test that the standard deviation is monotonically increasing. + +# Arguments +- `chains`: A vector of `MCMCChains.Chains` objects. +- `mean_true`: A vector of `Real` or `AbstractDict` mapping parameter names to their true mean. +- `std_true`: A vector of `Real` or `AbstractDict` mapping parameter names to their true standard deviation. +- `significance`: The significance level of the test. +- `kwargs...`: Passed to `atol_for_chain`. 
+""" +function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...) + @testset "chain $i" for i = 1:length(chains) + test_means(chains[i], mean_true[i]; kwargs...) + test_std(chains[i], std_true[i]; kwargs...) + end + test_std_monotonicity(chains; significance=0.05) +end