From fe2157fc63671f67b5ad763f189967f2e190bb30 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Mar 2023 15:14:30 +0000 Subject: [PATCH] Documentation (#153) * added CompositionSampler, RepeatedSampler, MultiSampler together with additional methods for meta-type samplers * added LinearAlgebra as dep * big update but now everything finally works * added additional pass-on-methods for meta-samplers and moved the bundle_samples to a more appropriate place * renamed state_from_state to state_from and changed the ordering of the args to be more reasonable * added some missing methods and fixed a typo * added model_for_chain and model_for_process similar to other utility methods for interacting with the tempered state, etc. * added todo * moved bundling back to ordering of defintions * added missing test dep * increase number of steps for one of the tests * specialize step for combination of RepeatedSampler and MultiSampler * Update src/sampler.jl Co-authored-by: Harrison Wilde * Introduction of `SwapSampler` + make `TemperedSampler` a fancy version of `CompositionSampler` (#152) * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! 
for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! * moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! 
that should never be hit with the current codebase * removed expected_order * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed unnecessary variable in tests * Update src/sampler.jl Co-authored-by: Harrison Wilde * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! 
* moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase * removed expected_order * removed unnecessary variable in tests * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * Apply suggestions from code review Co-authored-by: Harrison Wilde * added CompositionTransition + quite a few bundle_samples with a `bundle_resolve_swaps` kwarg to allow converting into chains more easily * more samples * reduce requirement for ess comparison for AHMC a bit * significant improvements to the simple Gaussian example, now testing using MCSE to get tolerances, etc. and small improvements to the rest of the tests * trying to debug these tests * more debug * fixed typy * reduce significance even further --------- Co-authored-by: Harrison Wilde * added docs * added a getting started example with a simple GMM + added another natural impl of bundle_samples * removed now redundant compute_tempered_logdensities + added some docs on different meta-samplers * minor improvements to docstrings + removed reference to non-existent method * fixed typo * added deployment and actions for doc deployment * fixed issue with GR plotting and headless * fixed missing renamings * defer design docs * added TODO for later on --------- Co-authored-by: Harrison Wilde --- .github/workflows/Docs.yml | 33 +++++ .github/workflows/DocsPreviewCleanup.yml | 26 ++++ docs/.gitignore | 2 + docs/Project.toml | 10 ++ docs/make.jl | 14 +++ docs/src/api.md | 118 ++++++++++++++++++ docs/src/getting-started.md | 151 +++++++++++++++++++++++ docs/src/index.md | 5 + src/MCMCTempering.jl | 15 ++- src/model.jl | 2 +- src/sampler.jl | 17 +-- src/samplers/composition.jl | 10 +- src/sampling.jl | 4 +- src/swapping.jl | 42 +++---- 
src/swapsampler.jl | 7 ++ 15 files changed, 420 insertions(+), 36 deletions(-) create mode 100644 .github/workflows/Docs.yml create mode 100644 .github/workflows/DocsPreviewCleanup.yml create mode 100644 docs/.gitignore create mode 100644 docs/Project.toml create mode 100644 docs/make.jl create mode 100644 docs/src/api.md create mode 100644 docs/src/getting-started.md create mode 100644 docs/src/index.md diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml new file mode 100644 index 0000000..7306eb0 --- /dev/null +++ b/.github/workflows/Docs.yml @@ -0,0 +1,33 @@ +name: Documentation + +on: + push: + branches: + # Build the master branch. + - master + tags: '*' + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: '1' + - name: Install dependencies + run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + JULIA_DEBUG: Documenter # Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) + GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 + run: julia --project=docs/ docs/make.jl diff --git a/.github/workflows/DocsPreviewCleanup.yml b/.github/workflows/DocsPreviewCleanup.yml new file mode 100644 index 0000000..4f57bc4 --- /dev/null +++ b/.github/workflows/DocsPreviewCleanup.yml @@ -0,0 +1,26 @@ +name: DocsPreviewCleanup + +on: + pull_request: + types: 
[closed] + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v2 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "Documenter.jl" + git config user.email "documenter@juliadocs.github.io" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi + env: + PRNUM: ${{ github.event.number }} diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..a303fff --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,2 @@ +build/ +site/ diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 0000000..919f5bf --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,10 @@ +[deps] +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MCMCTempering = "ce233488-44ea-4441-b732-192676ce2298" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 0000000..8875bad --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,14 @@ +using Documenter +using MCMCTempering + +DocMeta.setdocmeta!(MCMCTempering, :DocTestSetup, :(using MCMCTempering); recursive=true) + +makedocs( + sitename = "MCMCTempering", + format = Documenter.HTML(), + modules = [MCMCTempering], + pages=["Home" => "index.md", "getting-started.md", "api.md"], +) + +# Deploy! 
+deploydocs(; repo="github.com/TuringLang/MCMCTempering.jl.git", push_preview=true) diff --git a/docs/src/api.md b/docs/src/api.md new file mode 100644 index 0000000..161eeb3 --- /dev/null +++ b/docs/src/api.md @@ -0,0 +1,118 @@ +# API + +## Temper samplers + +```@docs +MCMCTempering.tempered +MCMCTempering.TemperedSampler +``` + +Under the hood, [`MCMCTempering.TemperedSampler`](@ref) is actually just a "fancy" representation of a composition (represented using a [`MCMCTempering.CompositionSampler`](@ref)) of a [`MultiSampler`](@ref) and a [`SwapSampler`](@ref). + +Roughly speaking, the implementation of `AbstractMCMC.step` for [`MCMCTempering.TemperedSampler`](@ref) is basically + +```julia +# 1. Construct the tempered models. +multimodel = MultiModel([make_tempered_model(model, β) for β in tempered_sampler.chain_to_beta]) +# 2. Construct the samplers (can be the same one repeated multiple times or different ones) +multisampler = MultiSampler([getsampler(tempered_sampler, i) for i = 1:numtemps]) +# 3. Step targeting `multimodel` using a composition of `multisampler` and `swapsampler`. +AbstractMCMC.step(rng, multimodel, multisampler ∘ swapsampler, state; kwargs...) +``` + +which in this case is provided by repeated calls to [`MCMCTempering.make_tempered_model`](@ref). + +```@docs +MCMCTempering.make_tempered_model +``` + +This should be overloaded if you have some custom model-type that does not support the LogDensityProblems.jl-interface. + +## Swapping + +Swapping is implemented using the somewhat special [`MCMCTempering.SwapSampler`](@ref) + +```@docs +MCMCTempering.SwapSampler +MCMCTempering.swapstrategy +``` + +!!! warning + This is a rather special sampler because, unlike most other implementations of `AbstractMCMC.AbstractSampler`, this is not a valid sampler _on its own_; for this to be sensible it needs to be part of a composition (see [`MCMCTempering.CompositionSampler`](@ref)) with _at least_ one other type of (an actually valid) sampler. 
+ +### Different swap-strategies + +A [`MCMCTempering.SwapSampler`](@ref) can be defined with different swapping strategies: + +```@docs +MCMCTempering.AbstractSwapStrategy +MCMCTempering.ReversibleSwap +MCMCTempering.NonReversibleSwap +MCMCTempering.SingleSwap +MCMCTempering.SingleRandomSwap +MCMCTempering.RandomSwap +MCMCTempering.NoSwap +``` + +```@docs +MCMCTempering.swap_step +``` + +## Other samplers + +```@docs +MCMCTempering.saveall +``` + +### Compositions of samplers +```@docs +MCMCTempering.CompositionSampler +``` + +This sampler also has its own transition- and state-type + +```@docs +MCMCTempering.CompositionTransition +MCMCTempering.CompositionState +``` + +#### Repeated sampler / composition with itself + +Large compositions can have unfortunate effects on the compilation times in Julia. + +To alleviate this issue we also have the [`RepeatedSampler`](@ref): + +```@docs +MCMCTempering.RepeatedSampler +``` + +In the case where [`saveall`](@ref) returns `false`, `step` for a [`MCMCTempering.RepeatedSampler`](@ref) simply returns the last transition and state; if it returns `true`, then the transition is of type [`MCMCTempering.SequentialTransitions`](@ref) and the state is of type [`MCMCTempering.SequentialStates`](@ref). + +```@docs +MCMCTempering.SequentialTransitions +MCMCTempering.SequentialStates +``` + +This effectively allows you to specify whether the "intermediate" states should be kept. + +!!! note + You will rarely see [`MCMCTempering.SequentialTransitions`](@ref) and [`MCMCTempering.SequentialStates`](@ref) as a user because `AbstractMCMC.bundle_samples` has been overloaded for these to return the flattened representation, i.e. we "un-roll" the transitions in every [`MCMCTempering.SequentialTransitions`](@ref). 
+ +### Multiple or product of samplers + +```@docs +MCMCTempering.MultiSampler +``` + +where the tempered models are represented using a [`MCMCTempering.MultiModel`](@ref) + +```@docs +MCMCTempering.MultiModel +``` + +The `step` for a [`MCMCTempering.MultiSampler`](@ref) and a [`MCMCTempering.MultiModel`](@ref) is a transition of type [`MCMCTempering.MultipleTransitions`](@ref) and a state of type [`MCMCTempering.MultipleStates`](@ref) + +```@docs +MCMCTempering.MultipleTransitions +MCMCTempering.MultipleStates +``` diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md new file mode 100644 index 0000000..82b3c2e --- /dev/null +++ b/docs/src/getting-started.md @@ -0,0 +1,151 @@ +# Getting started + +## Mixture of Gaussians + +Suppose we have a mixture of Gaussians, e.g. something like + +```@example gmm +using Distributions +target_distribution = MixtureModel( + Normal, + [(-3, 1.5), (3, 1.5), (20, 1.5)], # parameters + [0.5, 0.3, 0.2] # weights +) +``` + +This is a simple 1-dimensional distribution, so let's visualize it: + +```@example gmm +using StatsPlots +figsize = (800, 400) +plot(target_distribution; components=false, label=nothing, size=figsize) +``` + +We can convert a `Distribution` from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) into something we can pass to `sample` for many different samplers by implementing the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface: + +```@example gmm +using LogDensityProblems: LogDensityProblems + +struct DistributionLogDensity{D} + d::D +end + +LogDensityProblems.logdensity(d::DistributionLogDensity, x) = loglikelihood(d.d, x) +LogDensityProblems.dimension(d::DistributionLogDensity) = length(d.d) +LogDensityProblems.capabilities(::Type{<:DistributionLogDensity}) = LogDensityProblems.LogDensityOrder{0}() + +# Wrap our target distribution. 
+target_model = DistributionLogDensity(target_distribution) +``` + +Immediately one might reach for a standard sampler, e.g. a random-walk Metropolis-Hastings (RWMH) from [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl) and start sampling using `sample`: + +```@example gmm +using AdvancedMH, MCMCChains, LinearAlgebra + +using StableRNGs +rng = StableRNG(42) # To ensure reproducibility across devices. + +sampler = RWMH(MvNormal(zeros(1), I)) +num_iterations = 10_000 +chain = sample( + rng, + target_model, sampler, num_iterations; + chain_type=MCMCChains.Chains, + param_names=["x"] +) +``` + +```@example gmm +plot(chain; size=figsize) +``` + +This doesn't look quite like what we're expecting. + +```@example gmm +plot(target_distribution; components=false, linewidth=2) +density!(chain) +plot!(size=figsize) +``` + +Notice how `chain` has zero probability mass in the left-most component of the mixture! + +Let's instead try to use a _tempered_ version of `RWMH`. _But_ before we do that, we need to make sure that AdvancedMH.jl is compatible with MCMCTempering.jl. + +To do that we need to implement two methods. First we need to tell MCMCTempering how to extract the parameters, and potentially the log-probabilities, from an `AdvancedMH.Transition`: + +```@docs +MCMCTempering.getparams_and_logprob +``` + +And similarly, we need a way to _update_ the parameters and the log-probabilities of an `AdvancedMH.Transition`: + +```@docs +MCMCTempering.setparams_and_logprob!! 
+``` + +Luckily, implementing these is quite easy: + +```@example gmm +using MCMCTempering + +MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp +function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp) + return AdvancedMH.Transition(params, lp) +end +``` + +Now that this is done, we can wrap `sampler` in a [`MCMCTempering.TemperedSampler`](@ref) + +```@example gmm +inverse_temperatures = 0.90 .^ (0:20) +sampler_tempered = TemperedSampler(sampler, inverse_temperatures) +``` + +aaaaand `sample`! + +```@example gmm +chain_tempered = sample( + rng, target_model, sampler_tempered, num_iterations; + chain_type=MCMCChains.Chains, + param_names=["x"] +) +``` + +Let's see how this looks + +```@example gmm +plot(chain_tempered) +plot!(size=figsize) +``` + +```@example gmm +plot(target_distribution; components=false, linewidth=2) +density!(chain) +density!(chain_tempered) +plot!(size=figsize) +``` + +Neato; we've indeed captured the target distribution much better! + +We can even inspect _all_ of the tempered chains if we so desire + +```@example gmm +chain_tempered_all = sample( + rng, + target_model, sampler_tempered, num_iterations; + chain_type=Vector{MCMCChains.Chains}, # Different! + param_names=["x"] +); +``` + +```@example gmm +plot(target_distribution; components=false, linewidth=2) +density!(chain) +# Tempered ones. 
+for chain_tempered in chain_tempered_all[2:end] + density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all)))) +end +density!(chain_tempered_all[1], color="green", size=figsize) +plot!(size=figsize) +``` diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 0000000..cd6fd19 --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,5 @@ +# MCMCTempering.jl + +*Tempering methods and more for Markov chain Monte Carlo methods.* + +MCMCTempering provides implementations of different ways to define tempered samplers and models, in addition to other ways of composing and mixing samplers. diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index cb658d5..f9e07cb 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -70,7 +70,8 @@ function bundle_nontempered_samples( multimodel, multisampler, MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)), - T + T; + kwargs... ) end @@ -98,6 +99,7 @@ function AbstractMCMC.bundle_samples( bundle_resolve_swaps::Bool=false, kwargs... ) where {T} + # TODO: Implement special one for `Vector{MCMCChains.Chains}`. if bundle_resolve_swaps return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...) end @@ -106,6 +108,17 @@ function AbstractMCMC.bundle_samples( return ts end +function AbstractMCMC.bundle_samples( + ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, + model::AbstractMCMC.AbstractModel, + sampler::TemperedSampler, + state::TemperedState, + ::Type{Vector{MCMCChains.Chains}}; + kwargs... +) + return bundle_nontempered_samples(ts, model, sampler, state, Vector{MCMCChains.Chains}; kwargs...) 
+end + function AbstractMCMC.bundle_samples( ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, model::AbstractMCMC.AbstractModel, diff --git a/src/model.jl b/src/model.jl index 5ac8ec9..7e3642b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -3,7 +3,7 @@ Return an instance representing a `model` tempered with `beta`. -The return-type depends on its usage in [`compute_tempered_logdensities`](@ref). +The return-type depends on its usage in [`compute_logdensities`](@ref). """ make_tempered_model(sampler, model, beta) = make_tempered_model(model, beta) function make_tempered_model(model, beta) diff --git a/src/sampler.jl b/src/sampler.jl index 7ecd097..3fdc253 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -113,21 +113,21 @@ numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta) """ tempered(sampler, inverse_temperatures; kwargs...) OR - tempered(sampler, N_it; swap_strategy=ReversibleSwap(), kwargs...) + tempered(sampler, num_temps; swap_strategy=ReversibleSwap(), kwargs...) Return a tempered version of `sampler` using the provided `inverse_temperatures` or -inverse temperatures generated from `N_it` and the `swap_strategy`. +inverse temperatures generated from `num_temps` and the `swap_strategy`. # Arguments - `sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to - The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: - `inverse_temperatures` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... 
< β₁ < β₀ = 1 OR - - `N_it`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures` + - `num_temps`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures` # Keyword arguments - `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains -- `swap_every::Integer` steps are carried out between each attempt at a swap +- `steps_per_swap::Integer` steps are carried out between each attempt at a swap # See also - [`TemperedSampler`](@ref) @@ -142,17 +142,20 @@ inverse temperatures generated from `N_it` and the `swap_strategy`. """ function tempered( sampler::AbstractMCMC.AbstractSampler, - N_it::Integer; + num_temps::Integer; swap_strategy::AbstractSwapStrategy=ReversibleSwap(), kwargs... ) - return tempered(sampler, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy = swap_strategy, kwargs...) + return tempered( + sampler, generate_inverse_temperatures(num_temps, swap_strategy); + swap_strategy = swap_strategy, + kwargs... + ) end function tempered( sampler::AbstractMCMC.AbstractSampler, inverse_temperatures::Vector{<:Real}; swap_strategy::AbstractSwapStrategy=ReversibleSwap(), - # TODO: Change `swap_every` to something like `number_of_iterations_per_swap`. steps_per_swap::Integer=1, adapt::Bool=false, adapt_target::Real=0.234, diff --git a/src/samplers/composition.jl b/src/samplers/composition.jl index b8fb153..5e20789 100644 --- a/src/samplers/composition.jl +++ b/src/samplers/composition.jl @@ -36,7 +36,7 @@ saveall(::CompositionSampler{<:Any,<:Any,Val{SaveAll}}) where {SaveAll} = SaveAl """ CompositionState -A `CompositionState` is a container for a sequence of states. +Wrapper around the inner and outer states obtained from a [`CompositionSampler`](@ref). 
# Fields $(FIELDS) @@ -58,6 +58,14 @@ function setparams_and_logprob!!(model, state::CompositionState, params, logprob return @set state.state_outer = setparams_and_logprob!!(model, state.state_outer, params, logprob) end +""" + CompositionTransition + +Wrapper around the inner and outer transitions obtained from a [`CompositionSampler`](@ref). + +# Fields +$(FIELDS) +""" struct CompositionTransition{S1,S2} "The outer transition" transition_outer::S1 diff --git a/src/sampling.jl b/src/sampling.jl index 9f51ec0..3a98e85 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -9,7 +9,7 @@ Provide either `inverse_temperatures` or `N_it` (and a `swap_strategy`) to gener # Keyword arguments - `N_burnin::Integer` burn-in steps will be carried out before any swapping between chains is attempted - `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains -- `swap_every::Integer` steps are carried out between each attempt at a swap +- `steps_per_swap::Integer` steps are carried out between each attempt at a swap # See also - [`tempered`](@ref) @@ -55,4 +55,4 @@ function tempered_sample( ) tempered_sampler = tempered(sampler, inverse_temperatures; kwargs...) return AbstractMCMC.sample(rng, model, tempered_sampler, N; kwargs...) -end \ No newline at end of file +end diff --git a/src/swapping.jl b/src/swapping.jl index 6bd8eae..330756d 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -14,7 +14,6 @@ Stochastically attempt either even- or odd-indexed swap moves between chains. See [^SYED19] for more on this approach, referred to as SEO in their paper. -# References [^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). """ struct ReversibleSwap <: AbstractSwapStrategy end @@ -29,7 +28,6 @@ swaps between neighbors. See [^SYED19] for more on this approach, referred to as DEO in their paper. 
-# References [^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). """ struct NonReversibleSwap <: AbstractSwapStrategy end @@ -43,7 +41,6 @@ At every swap step taken, this strategy samples a single chain index This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05] -# References [^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). """ struct SingleSwap <: AbstractSwapStrategy end @@ -56,7 +53,6 @@ At every swap step taken, this strategy samples two chain indices This approach is shown to be effective for certain models in [^1]. -# References [^1]: Malcolm Sambridge, A Parallel Tempering algorithm for probabilistic sampling and multimodal optimization, Geophysical Journal International, Volume 196, Issue 1, January 2014, Pages 357–374, https://doi.org/10.1093/gji/ggt342 """ struct SingleRandomSwap <: AbstractSwapStrategy end @@ -101,27 +97,29 @@ end """ - compute_tempered_logdensities(model, sampler, transition, transition_other, β) - compute_tempered_logdensities(model, sampler, sampler_other, transition, transition_other, state, state_other, β, β_other) + compute_logdensities(model[, model_other], state, state_other) -Return `(logπ(transition, β), logπ(transition_other, β))` where `logπ(x, β)` denotes the -log-density for `model` with inverse-temperature `β`. +Return `(logdensity(model, state), logdensity(model, state_other))`. -The default implementation extracts the parameters from the transitions using [`getparams`](@ref) -and calls [`logdensity`](@ref) on the model returned from [`make_tempered_model`](@ref). +The default implementation extracts the parameters from the transitions using [`getparams`](@ref). 
+ +`model_other` can be provided to allow specializations that might be more efficient +if we know that `state_other` is from `model_other`, e.g. in the case where the log-probability +field is already present in `state` and `state_other`, and the only difference between +`logdensity(model, state_other)` and `logdensity(model_other, state_other)` is an easily computable +factor, then this can be exploited instead of re-computing the log-densities for both. """ -function compute_tempered_logdensities(model, sampler, transition, transition_other, β) - tempered_model = make_tempered_model(sampler, model, β) +function compute_logdensities( + model::AbstractMCMC.AbstractModel, + state, + state_other, +) + # TODO: Make use of `getparams_and_logprob` instead? At least for the `(model, state)` pair? return ( - logdensity(tempered_model, getparams(tempered_model, transition)), - logdensity(tempered_model, getparams(tempered_model, transition_other)) + logdensity(model, getparams(model, state)), + logdensity(model, getparams(model, state_other)) ) end -function compute_tempered_logdensities( - model, sampler, sampler_other, transition, transition_other, state, state_other, β, β_other -) - return compute_tempered_logdensities(model, sampler, transition, transition_other, β) -end function compute_logdensities( model::AbstractMCMC.AbstractModel, @@ -129,11 +127,7 @@ function compute_logdensities( state, state_other, ) - # TODO: Make use of `getparams_and_logprob` instead? - return ( - logdensity(model, getparams(model, state)), - logdensity(model, getparams(model_other, state_other)) - ) + return compute_logdensities(model, state, state_other) end """ diff --git a/src/swapsampler.jl b/src/swapsampler.jl index 6e43ff1..4e173ad 100644 --- a/src/swapsampler.jl +++ b/src/swapsampler.jl @@ -11,6 +11,11 @@ end SwapSampler() = SwapSampler(ReversibleSwap()) +""" + swapstrategy(sampler::SwapSampler) + +Return the swap-strategy used by `sampler`. 
+""" swapstrategy(sampler::SwapSampler) = sampler.strategy # Interaction with the state. @@ -112,6 +117,8 @@ function AbstractMCMC.step( @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate_prev) # Step for the swap-sampler. + # TODO: We should probably call `state_from(model, model_other, state, state_other)` so we + # can avoid additional log-joint computations, gradient computations, etc. swaptransition, swapstate = AbstractMCMC.step( rng, model, swapsampler, state_from(model, swapstate_prev, outerstate_prev); kwargs...