Skip to content

Commit

Permalink
Introduction of compositions and products of samplers and models (#151)
Browse files Browse the repository at this point in the history
* added CompositionSampler, RepeatedSampler, MultiSampler together with
additional methods for meta-type samplers

* added LinearAlgebra as dep

* big update but now everything finally works

* added additional pass-on-methods for meta-samplers and moved the
bundle_samples to a more appropriate place

* renamed state_from_state to state_from and changed the ordering of the
args to be more reasonable

* added some missing methods and fixed a typo

* added model_for_chain and model_for_process similar to other utility
methods for interacting with the tempered state, etc.

* added todo

* moved bundling back to ordering of defintions

* added missing test dep

* increase number of steps for one of the tests

* specialize step for combination of RepeatedSampler and MultiSampler

* Update src/sampler.jl

Co-authored-by: Harrison Wilde <[email protected]>

* Introduction of `SwapSampler` + make `TemperedSampler`  a fancy version of `CompositionSampler` (#152)

* split the transitions and states field in TemperedState

* improved internals of CompositionSampler

* ongoing work

* added swap sampler

* added ordering specification and a TemperedComposition

* integrated work on TemperedComposition into TemperedSampler and
removed the former

* reorederd stuff so it actually works

* fixed bug in swapping computation

* added length implementation for MultiModel

* improved construct for TemperedSampler and added some convenience methods

* fixed bundle_samples for Chains and TemperedTransition

* fixed breaking bug in setparams_and_logprob!! for SwapState

* remove usage of adapted HMC in tests

* remove doubling of iterations when testing tempering

* fixed bugs with MALA and tempering

* relax atol a bit for HMC

* relax another atol

* TemperedComposition is now truly just a wrapper around a CompositionSampler

* added method for computing roundtrips

* fixed testing + added test for roundtrips

* added docs for roundtrips method

* added some tests for SwapSampler without tempering

* remove ordering from SwapSampler since it should only interact with ProcessOrdering

* simplified the sorting according to chains and processes

* added some comments

* some minor refactoring

* some refactoring + TemperedSampler now orders the samplers correctly

* remove expected_ordering and make ordering assumptions more explicit

* relax type-constraints in state_for_chain so it also works with TemperedState

* removed redundant implementations of swap_attempt

* rename swap_betas! to swap!

* moved swap_attempt as it now requires definition of SwapSampler

* removed unnecessary setparams_and_logprob!! that should never be hit
with the current codebase

* removed expected_order

* Apply suggestions from code review

Co-authored-by: Harrison Wilde <[email protected]>

* removed unnecessary variable in tests

* Update src/sampler.jl

Co-authored-by: Harrison Wilde <[email protected]>

* Apply suggestions from code review

Co-authored-by: Harrison Wilde <[email protected]>

* removed burn-in from step in prep for AbstractMCMC improvements

* remove getparams_and_logprob implementation for SwapState as it's
unclear what is the right approach

* split the transitions and states field in TemperedState

* improved internals of CompositionSampler

* ongoing work

* added swap sampler

* added ordering specification and a TemperedComposition

* integrated work on TemperedComposition into TemperedSampler and
removed the former

* reorederd stuff so it actually works

* fixed bug in swapping computation

* added length implementation for MultiModel

* improved construct for TemperedSampler and added some convenience methods

* fixed bundle_samples for Chains and TemperedTransition

* fixed breaking bug in setparams_and_logprob!! for SwapState

* remove usage of adapted HMC in tests

* remove doubling of iterations when testing tempering

* fixed bugs with MALA and tempering

* relax atol a bit for HMC

* relax another atol

* TemperedComposition is now truly just a wrapper around a CompositionSampler

* added method for computing roundtrips

* fixed testing + added test for roundtrips

* added docs for roundtrips method

* added some tests for SwapSampler without tempering

* remove ordering from SwapSampler since it should only interact with ProcessOrdering

* simplified the sorting according to chains and processes

* added some comments

* some minor refactoring

* some refactoring + TemperedSampler now orders the samplers correctly

* remove expected_ordering and make ordering assumptions more explicit

* relax type-constraints in state_for_chain so it also works with TemperedState

* removed redundant implementations of swap_attempt

* rename swap_betas! to swap!

* moved swap_attempt as it now requires definition of SwapSampler

* removed unnecessary setparams_and_logprob!! that should never be hit
with the current codebase

* removed expected_order

* removed unnecessary variable in tests

* Apply suggestions from code review

Co-authored-by: Harrison Wilde <[email protected]>

* removed burn-in from step in prep for AbstractMCMC improvements

* remove getparams_and_logprob implementation for SwapState as it's
unclear what is the right approach

* Apply suggestions from code review

Co-authored-by: Harrison Wilde <[email protected]>

* added CompositionTransition + quite a few bundle_samples with a
`bundle_resolve_swaps` kwarg to allow converting into chains more easily

* more samples

* reduce requirement for ess comparison for AHMC a bit

* significant improvements to the simple Gaussian example, now testing
using MCSE to get tolerances, etc. and small improvements to the rest
of the tests

* trying to debug these tests

* more debug

* fixed typy

* reduce significance even further

---------

Co-authored-by: Harrison Wilde <[email protected]>

---------

Co-authored-by: Harrison Wilde <[email protected]>
  • Loading branch information
torfjelde and HarrisonWilde authored Mar 11, 2023
1 parent 7cf05a4 commit c4c1641
Show file tree
Hide file tree
Showing 21 changed files with 1,835 additions and 388 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand Down
194 changes: 191 additions & 3 deletions src/MCMCTempering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,24 @@ using ProgressLogging: ProgressLogging
using ConcreteStructs: @concrete
using Setfield: @set, @set!

using MCMCChains: MCMCChains

using InverseFunctions

using DocStringExtensions

include("logdensityproblems.jl")
include("abstractmcmc.jl")
include("adaptation.jl")
include("swapping.jl")
include("state.jl")
include("swapsampler.jl")
include("sampler.jl")
include("sampling.jl")
include("ladders.jl")
include("stepping.jl")
include("model.jl")
include("utils.jl")

export tempered,
tempered_sample,
Expand All @@ -39,16 +44,199 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing
maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model

# Bundling.
# Bundling of non-tempered samples.
function bundle_nontempered_samples(
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{T};
kwargs...
) where {T}
# Create the same model and sampler as we do in the initial step for `TemperedSampler`.
multimodel = MultiModel([
make_tempered_model(sampler, model, sampler.chain_to_beta[i])
for i in 1:numtemps(sampler)
])
multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
multitransitions = [
MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
for t in ts
]

return AbstractMCMC.bundle_samples(
multitransitions,
multimodel,
multisampler,
MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
T
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
ts::Vector{<:MultipleTransitions},
model::MultiModel,
sampler::MultiSampler,
state::MultipleStates,
# TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
::Type{Vector{MCMCChains.Chains}};
kwargs...
)
return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
end
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
chain_type::Type;
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
if bundle_resolve_swaps
return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
end

# TODO: Do better?
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{MCMCChains.Chains};
kwargs...
)
# Extract the transitions ordered, which are ordered according to processes, according to the chains.
ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
return AbstractMCMC.bundle_samples(
ts_actual,
model,
sampler_for_chain(sampler, state, 1),
state_for_chain(state, 1),
MCMCChains.Chains;
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
::Type{T};
kwargs...
) where {T}
# In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
# the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
# TODO: Should we really always default to outer sampler?
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
::Type{Vector{T}};
kwargs...
) where {T}
if !saveall(sampler)
# In this case, we just use the `outer` for everything since this is the only
# transitions we're keeping around.
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end

# Otherwise, we don't know what to do.
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{T};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
!bundle_resolve_swaps && return ts

# Resolve the swaps.
sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

AbstractMCMC.bundle_samples(
ts, maybe_wrap_model(model), sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type;
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
!bundle_resolve_swaps && return ts

# Resolve the swaps (using the already implemented resolution in `composition_transition`
# for this particular sampler but without `saveall`).
sampler_without_saveall = @set sampler.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
model::AbstractMCMC.AbstractModel,
sampler::RepeatedSampler,
state,
::Type{MCMCChains.Chains};
kwargs...
)
return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...)
end

# Unflatten in the case of `SequentialTransitions`.
function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:SequentialTransitions},
model::AbstractMCMC.AbstractModel,
sampler::RepeatedSampler,
state::SequentialStates,
::Type{MCMCChains.Chains};
kwargs...
)
ts_actual = [t for tseq in ts for t in tseq.transitions]
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains;
kwargs...
)
end
Expand Down
106 changes: 106 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using Setfield
using AbstractMCMC: AbstractMCMC

import LinearAlgebra: ×

"""
getparams([model, ]state)
Get the parameters from the `state`.
Default implementation uses [`getparams_and_logprob`](@ref).
"""
getparams(state) = first(getparams_and_logprob(state))
getparams(model, state) = first(getparams_and_logprob(model, state))

"""
getlogprob([model, ]state)
Get the log probability of the `state`.
Default implementation uses [`getparams_and_logprob`](@ref).
"""
getlogprob(state) = last(getparams_and_logprob(state))
getlogprob(model, state) = last(getparams_and_logprob(model, state))

"""
getparams_and_logprob([model, ]state)
Return a vector of parameters from the `state`.
See also: [`setparams_and_logprob!!`](@ref).
"""
getparams_and_logprob(model, state) = getparams_and_logprob(state)

"""
setparams_and_logprob!!([model, ]state, params)
Set the parameters in the state to `params`, possibly mutating if it makes sense.
See also: [`getparams_and_logprob`](@ref).
"""
setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob)

"""
state_from(model, state_target, state_source[, transition_source, transition_target])
Return a new state similar to `state_target` but updated from `state_source`, which could be
a different type of state.
"""
function state_from(model, state_target, state_source, transition_target, transition_source)
return state_from(model, state_target, state_source)
end
function state_from(model, state_target, state_source)
params, logp = getparams_and_logprob(model, state_source)
return setparams_and_logprob!!(model, state_target, params, logp)
end

"""
SequentialTransitions
A `SequentialTransitions` object is a container for a sequence of transitions.
"""
struct SequentialTransitions{A}
transitions::A
end

# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
# last transition/state.
getparams_and_logprob(transitions::SequentialTransitions) = getparams_and_logprob(transitions.transitions[end])
function getparams_and_logprob(model, transitions::SequentialTransitions)
return getparams_and_logprob(model, transitions.transitions[end])
end

function setparams_and_logprob!!(transitions::SequentialTransitions, params, logprob)
return @set transitions.transitions[end] = setparams_and_logprob!!(transitions.transitions[end], params, logprob)
end
function setparams_and_logprob!!(model, transitions::SequentialTransitions, params, logprob)
return @set transitions.transitions[end] = setparams_and_logprob!!(model, transitions.transitions[end], params, logprob)
end

"""
SequentialStates
A `SequentialStates` object is a container for a sequence of states.
"""
struct SequentialStates{A}
states::A
end

# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
# last transition/state.
getparams_and_logprob(state::SequentialStates) = getparams_and_logprob(state.states[end])
getparams_and_logprob(model, state::SequentialStates) = getparams_and_logprob(model, state.states[end])

function setparams_and_logprob!!(state::SequentialStates, params, logprob)
return @set state.states[end] = setparams_and_logprob!!(state.states[end], params, logprob)
end
function setparams_and_logprob!!(model, state::SequentialStates, params, logprob)
return @set state.states[end] = setparams_and_logprob!!(model, state.states[end], params, logprob)
end

# Includes.
include("samplers/composition.jl")
include("samplers/repeated.jl")
include("samplers/multi.jl")

2 changes: 1 addition & 1 deletion src/adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and
"""
struct Geometric end

defaultscale(::Geometric, Δ) = eltype(Δ)(0.9)
defaultscale(::Geometric, Δ) = float(eltype))(0.9)

"""
InverselyAdditive
Expand Down
4 changes: 1 addition & 3 deletions src/ladders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ end
Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
"""
function check_inverse_temperatures(Δ)
if length(Δ) <= 1
error("More than one inverse temperatures must be provided.")
end
!isempty(Δ) || error("Inverse temperatures array is empty.")
if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
end
Expand Down
Loading

0 comments on commit c4c1641

Please sign in to comment.