Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduction of SwapSampler + make TemperedSampler a fancy version of CompositionSampler #152

Merged
merged 90 commits into from
Mar 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
6d8e752
split the transitions and states field in TemperedState
torfjelde Mar 3, 2023
9dcc810
improved internals of CompositionSampler
torfjelde Mar 3, 2023
767d559
ongoing work
torfjelde Mar 4, 2023
58c0376
added swap sampler
torfjelde Mar 4, 2023
0487135
added ordering specification and a TemperedComposition
torfjelde Mar 6, 2023
3e9dbe4
integrated work on TemperedComposition into TemperedSampler and
torfjelde Mar 6, 2023
d7b8096
reorederd stuff so it actually works
torfjelde Mar 6, 2023
86866ac
fixed bug in swapping computation
torfjelde Mar 6, 2023
1006fd8
added length implementation for MultiModel
torfjelde Mar 6, 2023
4f2e20c
improved construct for TemperedSampler and added some convenience met…
torfjelde Mar 6, 2023
8c7a9fd
fixed bundle_samples for Chains and TemperedTransition
torfjelde Mar 6, 2023
53b7df8
fixed breaking bug in setparams_and_logprob!! for SwapState
torfjelde Mar 6, 2023
4ca60ed
remove usage of adapted HMC in tests
torfjelde Mar 6, 2023
92e54de
remove doubling of iterations when testing tempering
torfjelde Mar 6, 2023
8bc5872
fixed bugs with MALA and tempering
torfjelde Mar 7, 2023
940332d
relax atol a bit for HMC
torfjelde Mar 7, 2023
da47bbc
relax another atol
torfjelde Mar 7, 2023
08cf069
TemperedComposition is now truly just a wrapper around a CompositionS…
torfjelde Mar 8, 2023
56e709a
added method for computing roundtrips
torfjelde Mar 8, 2023
e4a15ec
fixed testing + added test for roundtrips
torfjelde Mar 8, 2023
3371002
added docs for roundtrips method
torfjelde Mar 8, 2023
a1e4b7d
added some tests for SwapSampler without tempering
torfjelde Mar 8, 2023
4956fd7
remove ordering from SwapSampler since it should only interact with P…
torfjelde Mar 8, 2023
70f5d8c
simplified the sorting according to chains and processes
torfjelde Mar 8, 2023
a11f1ee
added some comments
torfjelde Mar 8, 2023
ee38580
some minor refactoring
torfjelde Mar 8, 2023
18f8600
some refactoring + TemperedSampler now orders the samplers correctly
torfjelde Mar 9, 2023
8d49045
remove expected_ordering and make ordering assumptions more explicit
torfjelde Mar 9, 2023
fd70d0e
relax type-constraints in state_for_chain so it also works with Tempe…
torfjelde Mar 9, 2023
dac7b06
removed redundant implementations of swap_attempt
torfjelde Mar 9, 2023
2097bb1
rename swap_betas! to swap!
torfjelde Mar 9, 2023
6ceacff
moved swap_attempt as it now requires definition of SwapSampler
torfjelde Mar 9, 2023
b06ddcf
removed unnecessary setparams_and_logprob!! that should never be hit
torfjelde Mar 9, 2023
c7c8f63
removed expected_order
torfjelde Mar 9, 2023
1715eea
Apply suggestions from code review
torfjelde Mar 9, 2023
8c42b82
removed unnecessary variable in tests
torfjelde Mar 9, 2023
a411362
Merge branch 'torfjelde/tempered-sampler-rewamp' of github.com:Turing…
torfjelde Mar 9, 2023
ef97a94
Update src/sampler.jl
torfjelde Mar 9, 2023
ed804c6
Apply suggestions from code review
torfjelde Mar 10, 2023
762fb3b
removed burn-in from step in prep for AbstractMCMC improvements
torfjelde Mar 10, 2023
7883f2a
remove getparams_and_logprob implementation for SwapState as it's
torfjelde Mar 10, 2023
a9bb0de
Merge branch 'torfjelde/tempered-sampler-rewamp' of github.com:Turing…
torfjelde Mar 10, 2023
cf0b27e
split the transitions and states field in TemperedState
torfjelde Mar 3, 2023
96f76b6
improved internals of CompositionSampler
torfjelde Mar 3, 2023
8e9af89
ongoing work
torfjelde Mar 4, 2023
d9424cc
added swap sampler
torfjelde Mar 4, 2023
25d3518
added ordering specification and a TemperedComposition
torfjelde Mar 6, 2023
61d29b2
integrated work on TemperedComposition into TemperedSampler and
torfjelde Mar 6, 2023
a4c2815
reorederd stuff so it actually works
torfjelde Mar 6, 2023
cf166d0
fixed bug in swapping computation
torfjelde Mar 6, 2023
d444975
added length implementation for MultiModel
torfjelde Mar 6, 2023
746f3cf
improved construct for TemperedSampler and added some convenience met…
torfjelde Mar 6, 2023
6df54a2
fixed bundle_samples for Chains and TemperedTransition
torfjelde Mar 6, 2023
2b627bd
fixed breaking bug in setparams_and_logprob!! for SwapState
torfjelde Mar 6, 2023
1b89157
remove usage of adapted HMC in tests
torfjelde Mar 6, 2023
8d9e466
remove doubling of iterations when testing tempering
torfjelde Mar 6, 2023
a8e317a
fixed bugs with MALA and tempering
torfjelde Mar 7, 2023
71b39c7
relax atol a bit for HMC
torfjelde Mar 7, 2023
d3d044c
relax another atol
torfjelde Mar 7, 2023
1830c79
TemperedComposition is now truly just a wrapper around a CompositionS…
torfjelde Mar 8, 2023
4af0e67
added method for computing roundtrips
torfjelde Mar 8, 2023
63c2724
fixed testing + added test for roundtrips
torfjelde Mar 8, 2023
929573f
added docs for roundtrips method
torfjelde Mar 8, 2023
54f043a
added some tests for SwapSampler without tempering
torfjelde Mar 8, 2023
1458e64
remove ordering from SwapSampler since it should only interact with P…
torfjelde Mar 8, 2023
ebb402b
simplified the sorting according to chains and processes
torfjelde Mar 8, 2023
b27d8cf
added some comments
torfjelde Mar 8, 2023
6621ce7
some minor refactoring
torfjelde Mar 8, 2023
972c2b3
some refactoring + TemperedSampler now orders the samplers correctly
torfjelde Mar 9, 2023
4d9def9
remove expected_ordering and make ordering assumptions more explicit
torfjelde Mar 9, 2023
7115dad
relax type-constraints in state_for_chain so it also works with Tempe…
torfjelde Mar 9, 2023
05e8521
removed redundant implementations of swap_attempt
torfjelde Mar 9, 2023
e062ae3
rename swap_betas! to swap!
torfjelde Mar 9, 2023
b816459
moved swap_attempt as it now requires definition of SwapSampler
torfjelde Mar 9, 2023
9466fe0
removed unnecessary setparams_and_logprob!! that should never be hit
torfjelde Mar 9, 2023
442f1d1
removed expected_order
torfjelde Mar 9, 2023
8e25d63
removed unnecessary variable in tests
torfjelde Mar 9, 2023
c344ad6
Apply suggestions from code review
torfjelde Mar 9, 2023
6359e31
removed burn-in from step in prep for AbstractMCMC improvements
torfjelde Mar 10, 2023
1a437f4
remove getparams_and_logprob implementation for SwapState as it's
torfjelde Mar 10, 2023
262995a
Apply suggestions from code review
torfjelde Mar 10, 2023
49ea8db
Merge branch 'torfjelde/tempered-sampler-rewamp' of github.com:Turing…
torfjelde Mar 10, 2023
f99809e
added CompositionTransition + quite a few bundle_samples with a
torfjelde Mar 10, 2023
42294a1
more samples
torfjelde Mar 10, 2023
a0e2732
reduce requirement for ess comparison for AHMC a bit
torfjelde Mar 11, 2023
15ce587
significant improvements to the simple Gaussian example, now testing
torfjelde Mar 11, 2023
66ee3a7
trying to debug these tests
torfjelde Mar 11, 2023
7030fce
more debug
torfjelde Mar 11, 2023
3e88f9e
fixed typy
torfjelde Mar 11, 2023
6aa3d09
reduce significance even further
torfjelde Mar 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 137 additions & 16 deletions src/MCMCTempering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ include("abstractmcmc.jl")
include("adaptation.jl")
include("swapping.jl")
include("state.jl")
include("swapsampler.jl")
include("sampler.jl")
include("sampling.jl")
include("ladders.jl")
include("stepping.jl")
include("model.jl")
include("utils.jl")

export tempered,
tempered_sample,
Expand All @@ -43,21 +45,82 @@ maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensity
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model

# Bundling.
# TODO: Improve this, somehow.
# TODO: Move this to an extension.
# Bundling of non-tempered samples.
function bundle_nontempered_samples(
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{T};
kwargs...
) where {T}
# Create the same model and sampler as we do in the initial step for `TemperedSampler`.
multimodel = MultiModel([
make_tempered_model(sampler, model, sampler.chain_to_beta[i])
for i in 1:numtemps(sampler)
])
multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
multitransitions = [
MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
for t in ts
]

return AbstractMCMC.bundle_samples(
multitransitions,
multimodel,
multisampler,
MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
T
)
end

function AbstractMCMC.bundle_samples(
ts::Vector{<:MultipleTransitions},
model::MultiModel,
sampler::MultiSampler,
state::MultipleStates,
# TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
::Type{Vector{MCMCChains.Chains}};
kwargs...
)
return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
end
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
if bundle_resolve_swaps
return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
end

# TODO: Do better?
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:TemperedTransition},
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{MCMCChains.Chains};
kwargs...
)
# Extract the transitions ordered, which are ordered according to processes, according to the chains.
ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
return AbstractMCMC.bundle_samples(
map(Base.Fix2(getproperty, :transition), filter(!Base.Fix2(getproperty, :is_swap), ts)), # Remove the swaps.
ts_actual,
model,
sampler_for_chain(sampler, state),
state_for_chain(state),
sampler_for_chain(sampler, state, 1),
state_for_chain(state, 1),
HarrisonWilde marked this conversation as resolved.
Show resolved Hide resolved
MCMCChains.Chains;
kwargs...
)
Expand All @@ -68,27 +131,85 @@ function AbstractMCMC.bundle_samples(
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
::Type{MCMCChains.Chains};
::Type{T};
kwargs...
)
) where {T}
# In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
# the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
# TODO: Should we really always default to outer sampler?
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, MCMCChains.Chains;
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# Unflatten in the case of `SequentialTransitions`
# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:SequentialTransitions},
ts::Vector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::SequentialStates,
::Type{MCMCChains.Chains};
state::CompositionState,
::Type{Vector{T}};
kwargs...
)
ts_actual = [t for tseq in ts for t in tseq.transitions]
) where {T}
if !saveall(sampler)
# In this case, we just use the `outer` for everything since this is the only
# transitions we're keeping around.
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end

# Otherwise, we don't know what to do.
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{T};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
!bundle_resolve_swaps && return ts

# Resolve the swaps.
sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
!bundle_resolve_swaps && return ts

# Resolve the swaps (using the already implemented resolution in `composition_transition`
# for this particular sampler but without `saveall`).
sampler_without_saveall = @set sampler.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.states[end], MCMCChains.Chains;
ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end
Expand Down
2 changes: 1 addition & 1 deletion src/adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and
"""
struct Geometric end

defaultscale(::Geometric, Δ) = eltype(Δ)(0.9)
defaultscale(::Geometric, Δ) = float(eltype(Δ))(0.9)
HarrisonWilde marked this conversation as resolved.
Show resolved Hide resolved

"""
InverselyAdditive
Expand Down
4 changes: 1 addition & 3 deletions src/ladders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ end
Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
"""
function check_inverse_temperatures(Δ)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
if length(Δ) <= 1
error("More than one inverse temperatures must be provided.")
end
!isempty(Δ) || error("Inverse temperatures array is empty.")
if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
end
Expand Down
107 changes: 80 additions & 27 deletions src/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
TemperedState

A state for a tempered sampler.

# Fields
$(FIELDS)
"""
@concrete struct TemperedState
"state for swap-sampler"
swapstate
"state for the main sampler"
state
"inverse temperature for each of the chains"
chain_to_beta
end

"""
TemperedSampler <: AbstractMCMC.AbstractSampler

Expand All @@ -7,44 +24,39 @@ A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Temp

$(FIELDS)
"""
@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler
Base.@kwdef struct TemperedSampler{SplT,A,SwapT,Adapt} <: AbstractMCMC.AbstractSampler
HarrisonWilde marked this conversation as resolved.
Show resolved Hide resolved
"sampler(s) used to target the tempered distributions"
sampler
sampler::SplT
"collection of inverse temperatures β; β[i] correponds i-th tempered model"
inverse_temperatures
"number of steps of `sampler` to take before proposing swaps"
swap_every
"the swap strategy that will be used when proposing swaps"
swap_strategy
# TODO: This should be replaced with `P` just being some `NoAdapt` type.
chain_to_beta::A
"strategy to use for swapping"
swapstrategy::SwapT=ReversibleSwap()
# TODO: Remove `adapt` and just consider `adaptation_states=nothing` as no adaptation.
"boolean flag specifying whether or not to adapt"
adapt
adapt=false
"adaptation parameters"
adaptation_states
adaptation_states::Adapt=nothing
HarrisonWilde marked this conversation as resolved.
Show resolved Hide resolved
end

swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy
TemperedSampler(sampler, chain_to_beta; kwargs...) = TemperedSampler(; sampler, chain_to_beta, kwargs...)

swapsampler(sampler::TemperedSampler) = SwapSampler(sampler.swapstrategy)

# TODO: Do we need this now?
getsampler(samplers, I...) = getindex(samplers, I...)
getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler
getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...)

"""
numsteps(sampler::TemperedSampler)

Return number of inverse temperatures used by `sampler`.
"""
numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures)
chain_to_process(state::TemperedState, I...) = chain_to_process(state.swapstate, I...)
process_to_chain(state::TemperedState, I...) = process_to_chain(state.swapstate, I...)

"""
sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...])
sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)

Return the sampler corresponding to the chain indexed by `I...`.
If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned.
"""
sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1)
function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
return getsampler(sampler.sampler, chain_to_process(state, I...))
return sampler_for_process(sampler, state, chain_to_process(state, I...))
end

"""
Expand All @@ -53,9 +65,51 @@ end
Return the sampler corresponding to the process indexed by `I...`.
"""
function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...)
return getsampler(sampler.sampler, I...)
return _sampler_for_process_temper(sampler.sampler, state, I...)
end

# If `sampler` is a `MultiSampler`, we assume it's ordered according to chains.
_sampler_for_process_temper(sampler::MultiSampler, state, I...) = sampler.samplers[process_to_chain(state, I...)]
# Otherwise, we just use the same sampler for everything.
_sampler_for_process_temper(sampler, state, I...) = sampler

# Defer extracting the corresponding state to the `swapstate`.
state_for_process(state::TemperedState, I...) = state_for_process(state.swapstate, I...)

# Here we make the model(s) using the temperatures.
function model_for_process(sampler::TemperedSampler, model, state::TemperedState, I...)
return make_tempered_model(sampler, model, beta_for_process(state, I...))
end

"""
beta_for_chain(state[, I...])

Return the β corresponding to the chain indexed by `I...`.
If `I...` is not specified, the β corresponding to `β=1.0` will be returned.
"""
beta_for_chain(state::TemperedState) = beta_for_chain(state, 1)
beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...)
# NOTE: Array impl. is useful for testing.
beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...]

"""
beta_for_process(state, I...)

Return the β corresponding to the process indexed by `I...`.
"""
beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.swapstate.process_to_chain, I...)
# NOTE: Array impl. is useful for testing.
function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...)
return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...))
end

"""
numsteps(sampler::TemperedSampler)

Return number of inverse temperatures used by `sampler`.
"""
numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta)

"""
tempered(sampler, inverse_temperatures; kwargs...)
OR
Expand Down Expand Up @@ -99,7 +153,7 @@ function tempered(
inverse_temperatures::Vector{<:Real};
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
# TODO: Change `swap_every` to something like `number_of_iterations_per_swap`.
swap_every::Integer=1,
steps_per_swap::Integer=1,
HarrisonWilde marked this conversation as resolved.
Show resolved Hide resolved
adapt::Bool=false,
adapt_target::Real=0.234,
adapt_stepsize::Real=1,
Expand All @@ -109,14 +163,13 @@ function tempered(
kwargs...
)
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
swap_every ≥ 1 || error("`swap_every` must take a positive integer value greater ≥1.")
steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.")
inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
adaptation_states = init_adaptation(
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
)
# NOTE: We just make a repeated sampler for `sampler_inner`.
# TODO: Generalize. Allow passing in a `MultiSampler`, etc.
sampler_inner = sampler^swap_every
# FIXME: Remove the hard-coded `2` for swap-every, and change `should_swap` acoordingly.
return TemperedSampler(sampler_inner, inverse_temperatures, 2, swap_strategy, adapt, adaptation_states)
sampler_inner = sampler^steps_per_swap
return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states)
end
Loading