diff --git a/Project.toml b/Project.toml
index b369ba1..c6be239 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,7 +9,9 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl
index a8a084d..cb658d5 100644
--- a/src/MCMCTempering.jl
+++ b/src/MCMCTempering.jl
@@ -9,19 +9,24 @@ using ProgressLogging: ProgressLogging
 using ConcreteStructs: @concrete
 using Setfield: @set, @set!
 
+using MCMCChains: MCMCChains
+
 using InverseFunctions
 
 using DocStringExtensions
 
 include("logdensityproblems.jl")
+include("abstractmcmc.jl")
 include("adaptation.jl")
 include("swapping.jl")
 include("state.jl")
+include("swapsampler.jl")
 include("sampler.jl")
 include("sampling.jl")
 include("ladders.jl")
 include("stepping.jl")
 include("model.jl")
+include("utils.jl")
 
 export tempered,
     tempered_sample,
@@ -39,16 +44,199 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing
 maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model
 maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model
 
+# Bundling.
+# Bundling of non-tempered samples.
+function bundle_nontempered_samples(
+    ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::TemperedSampler,
+    state::TemperedState,
+    ::Type{T};
+    kwargs...
+) where {T}
+    # Create the same model and sampler as we do in the initial step for `TemperedSampler`.
+    multimodel = MultiModel([
+        make_tempered_model(sampler, model, sampler.chain_to_beta[i])
+        for i in 1:numtemps(sampler)
+    ])
+    multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
+    multitransitions = [
+        MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
+        for t in ts
+    ]
+
+    return AbstractMCMC.bundle_samples(
+        multitransitions,
+        multimodel,
+        multisampler,
+        MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
+        T
+    )
+end
+
 function AbstractMCMC.bundle_samples(
-    ts::AbstractVector,
+    ts::Vector{<:MultipleTransitions},
+    model::MultiModel,
+    sampler::MultiSampler,
+    state::MultipleStates,
+    # TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
+    ::Type{Vector{MCMCChains.Chains}};
+    kwargs...
+)
+    return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
+        AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
+    end
+end
+
+# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
+function AbstractMCMC.bundle_samples(
+    ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
     model::AbstractMCMC.AbstractModel,
     sampler::TemperedSampler,
     state::TemperedState,
-    chain_type::Type;
+    ::Type{Vector{T}};
+    bundle_resolve_swaps::Bool=false,
+    kwargs...
+) where {T}
+    if bundle_resolve_swaps
+        return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
+    end
+
+    # TODO: Do better?
+    return ts
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::TemperedSampler,
+    state::TemperedState,
+    ::Type{MCMCChains.Chains};
     kwargs...
 )
+    # Extract the transitions ordered, which are ordered according to processes, according to the chains.
+    ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
+    return AbstractMCMC.bundle_samples(
+        ts_actual,
+        model,
+        sampler_for_chain(sampler, state, 1),
+        state_for_chain(state, 1),
+        MCMCChains.Chains;
+        kwargs...
+    )
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector,
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler,
+    state::CompositionState,
+    ::Type{T};
+    kwargs...
+) where {T}
+    # In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
+    # the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
+    ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
+    # TODO: Should we really always default to outer sampler?
+    return AbstractMCMC.bundle_samples(
+        ts_actual, model, sampler.sampler_outer, state.state_outer, T;
+        kwargs...
+    )
+end
+
+# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
+function AbstractMCMC.bundle_samples(
+    ts::Vector,
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler,
+    state::CompositionState,
+    ::Type{Vector{T}};
+    kwargs...
+) where {T}
+    if !saveall(sampler)
+        # In this case, we just use the `outer` for everything since this is the only
+        # transitions we're keeping around.
+        return AbstractMCMC.bundle_samples(
+            ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
+            kwargs...
+        )
+    end
+
+    # Otherwise, we don't know what to do.
+    return ts
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
+    state::CompositionState{<:MultipleStates,<:SwapState},
+    ::Type{T};
+    bundle_resolve_swaps::Bool=false,
+    kwargs...
+) where {T}
+    !bundle_resolve_swaps && return ts
+
+    # Resolve the swaps.
+    sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
+    ts_actual = map(ts) do t
+        composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
+    end
+
     AbstractMCMC.bundle_samples(
-        ts, maybe_wrap_model(model), sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type;
+        ts_actual, model, sampler.sampler_outer, state.state_outer, T;
+        kwargs...
+    )
+end
+
+# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
+function AbstractMCMC.bundle_samples(
+    ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
+    state::CompositionState{<:MultipleStates,<:SwapState},
+    ::Type{Vector{T}};
+    bundle_resolve_swaps::Bool=false,
+    kwargs...
+) where {T}
+    !bundle_resolve_swaps && return ts
+
+    # Resolve the swaps (using the already implemented resolution in `composition_transition`
+    # for this particular sampler but without `saveall`).
+    sampler_without_saveall = @set sampler.saveall = Val(false)
+    ts_actual = map(ts) do t
+        composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
+    end
+
+    return AbstractMCMC.bundle_samples(
+        ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
+        kwargs...
+    )
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector,
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler,
+    state,
+    ::Type{MCMCChains.Chains};
+    kwargs...
+)
+    return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...)
+end
+
+# Unflatten in the case of `SequentialTransitions`.
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector{<:SequentialTransitions},
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler,
+    state::SequentialStates,
+    ::Type{MCMCChains.Chains};
+    kwargs...
+)
+    ts_actual = [t for tseq in ts for t in tseq.transitions]
+    return AbstractMCMC.bundle_samples(
+        ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains;
         kwargs...
     )
 end
diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl
new file mode 100644
index 0000000..716bb2c
--- /dev/null
+++ b/src/abstractmcmc.jl
@@ -0,0 +1,106 @@
+using Setfield
+using AbstractMCMC: AbstractMCMC
+
+import LinearAlgebra: ×
+
+"""
+    getparams([model, ]state)
+
+Get the parameters from the `state`.
+
+Default implementation uses [`getparams_and_logprob`](@ref).
+"""
+getparams(state) = first(getparams_and_logprob(state))
+getparams(model, state) = first(getparams_and_logprob(model, state))
+
+"""
+    getlogprob([model, ]state)
+
+Get the log probability of the `state`.
+
+Default implementation uses [`getparams_and_logprob`](@ref).
+"""
+getlogprob(state) = last(getparams_and_logprob(state))
+getlogprob(model, state) = last(getparams_and_logprob(model, state))
+
+"""
+    getparams_and_logprob([model, ]state)
+
+Return a vector of parameters from the `state`.
+
+See also: [`setparams_and_logprob!!`](@ref).
+"""
+getparams_and_logprob(model, state) = getparams_and_logprob(state)
+
+"""
+    setparams_and_logprob!!([model, ]state, params)
+
+Set the parameters in the state to `params`, possibly mutating if it makes sense.
+
+See also: [`getparams_and_logprob`](@ref).
+"""
+setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob)
+
+"""
+    state_from(model, state_target, state_source[, transition_source, transition_target])
+
+Return a new state similar to `state_target` but updated from `state_source`, which could be
+a different type of state.
+"""
+function state_from(model, state_target, state_source, transition_target, transition_source)
+    return state_from(model, state_target, state_source)
+end
+function state_from(model, state_target, state_source)
+    params, logp = getparams_and_logprob(model, state_source)
+    return setparams_and_logprob!!(model, state_target, params, logp)
+end
+
+"""
+    SequentialTransitions
+
+A `SequentialTransitions` object is a container for a sequence of transitions.
+"""
+struct SequentialTransitions{A}
+    transitions::A
+end
+
+# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
+# last transition/state.
+getparams_and_logprob(transitions::SequentialTransitions) = getparams_and_logprob(transitions.transitions[end])
+function getparams_and_logprob(model, transitions::SequentialTransitions)
+    return getparams_and_logprob(model, transitions.transitions[end])
+end
+
+function setparams_and_logprob!!(transitions::SequentialTransitions, params, logprob)
+    return @set transitions.transitions[end] = setparams_and_logprob!!(transitions.transitions[end], params, logprob)
+end
+function setparams_and_logprob!!(model, transitions::SequentialTransitions, params, logprob)
+    return @set transitions.transitions[end] = setparams_and_logprob!!(model, transitions.transitions[end], params, logprob)
+end
+
+"""
+    SequentialStates
+
+A `SequentialStates` object is a container for a sequence of states.
+"""
+struct SequentialStates{A}
+    states::A
+end
+
+# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
+# last transition/state.
+getparams_and_logprob(state::SequentialStates) = getparams_and_logprob(state.states[end])
+getparams_and_logprob(model, state::SequentialStates) = getparams_and_logprob(model, state.states[end])
+
+function setparams_and_logprob!!(state::SequentialStates, params, logprob)
+    return @set state.states[end] = setparams_and_logprob!!(state.states[end], params, logprob)
+end
+function setparams_and_logprob!!(model, state::SequentialStates, params, logprob)
+    return @set state.states[end] = setparams_and_logprob!!(model, state.states[end], params, logprob)
+end
+
+# Includes.
+include("samplers/composition.jl")
+include("samplers/repeated.jl")
+include("samplers/multi.jl")
+
diff --git a/src/adaptation.jl b/src/adaptation.jl
index 096ca13..25d640b 100644
--- a/src/adaptation.jl
+++ b/src/adaptation.jl
@@ -19,7 +19,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and
 """
 struct Geometric end
 
-defaultscale(::Geometric, Δ) = eltype(Δ)(0.9)
+defaultscale(::Geometric, Δ) = float(eltype(Δ))(0.9)
 
 """
     InverselyAdditive
diff --git a/src/ladders.jl b/src/ladders.jl
index 0ebf615..28305d1 100644
--- a/src/ladders.jl
+++ b/src/ladders.jl
@@ -37,9 +37,7 @@ end
 Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
 """
 function check_inverse_temperatures(Δ)
-    if length(Δ) <= 1
-        error("More than one inverse temperatures must be provided.")
-    end
+    !isempty(Δ) || error("Inverse temperatures array is empty.")
     if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
         error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
     end
diff --git a/src/sampler.jl b/src/sampler.jl
index 08f157c..7ecd097 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -1,3 +1,20 @@
+"""
+    TemperedState
+
+A state for a tempered sampler.
+
+# Fields
+$(FIELDS)
+"""
+@concrete struct TemperedState
+    "state for swap-sampler"
+    swapstate
+    "state for the main sampler"
+    state
+    "inverse temperature for each of the chains"
+    chain_to_beta
+end
+
 """
     TemperedSampler <: AbstractMCMC.AbstractSampler
 
@@ -7,44 +24,39 @@ A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Temp
 
 $(FIELDS)
 """
-@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler
+Base.@kwdef struct TemperedSampler{SplT,A,SwapT,Adapt} <: AbstractMCMC.AbstractSampler
     "sampler(s) used to target the tempered distributions"
-    sampler
+    sampler::SplT
     "collection of inverse temperatures β; β[i] correponds i-th tempered model"
-    inverse_temperatures
-    "number of steps of `sampler` to take before proposing swaps"
-    swap_every
-    "the swap strategy that will be used when proposing swaps"
-    swap_strategy
-    # TODO: This should be replaced with `P` just being some `NoAdapt` type.
+    chain_to_beta::A
+    "strategy to use for swapping"
+    swapstrategy::SwapT=ReversibleSwap()
+    # TODO: Remove `adapt` and just consider `adaptation_states=nothing` as no adaptation.
     "boolean flag specifying whether or not to adapt"
-    adapt
+    adapt=false
     "adaptation parameters"
-    adaptation_states
+    adaptation_states::Adapt=nothing
 end
 
-swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy
+TemperedSampler(sampler, chain_to_beta; kwargs...) = TemperedSampler(; sampler, chain_to_beta, kwargs...)
+
+swapsampler(sampler::TemperedSampler) = SwapSampler(sampler.swapstrategy)
 
+# TODO: Do we need this now?
 getsampler(samplers, I...) = getindex(samplers, I...)
 getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler
 getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...)
 
-"""
-    numsteps(sampler::TemperedSampler)
-
-Return number of inverse temperatures used by `sampler`.
-"""
-numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures)
+chain_to_process(state::TemperedState, I...) = chain_to_process(state.swapstate, I...)
+process_to_chain(state::TemperedState, I...) = process_to_chain(state.swapstate, I...)
 
 """
-    sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...])
+    sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
 
 Return the sampler corresponding to the chain indexed by `I...`.
-If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned.
 """
-sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1)
 function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
-    return getsampler(sampler.sampler, chain_to_process(state, I...))
+    return sampler_for_process(sampler, state, chain_to_process(state, I...))
 end
 
 """
@@ -53,9 +65,51 @@ end
 Return the sampler corresponding to the process indexed by `I...`.
 """
 function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...)
-    return getsampler(sampler.sampler, I...)
+    return _sampler_for_process_temper(sampler.sampler, state, I...)
 end
 
+# If `sampler` is a `MultiSampler`, we assume it's ordered according to chains.
+_sampler_for_process_temper(sampler::MultiSampler, state, I...) = sampler.samplers[process_to_chain(state, I...)]
+# Otherwise, we just use the same sampler for everything.
+_sampler_for_process_temper(sampler, state, I...) = sampler
+
+# Defer extracting the corresponding state to the `swapstate`.
+state_for_process(state::TemperedState, I...) = state_for_process(state.swapstate, I...)
+
+# Here we make the model(s) using the temperatures.
+function model_for_process(sampler::TemperedSampler, model, state::TemperedState, I...)
+    return make_tempered_model(sampler, model, beta_for_process(state, I...))
+end
+
+"""
+    beta_for_chain(state[, I...])
+
+Return the β corresponding to the chain indexed by `I...`.
+If `I...` is not specified, the β corresponding to `β=1.0` will be returned.
+"""
+beta_for_chain(state::TemperedState) = beta_for_chain(state, 1)
+beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...)
+# NOTE: Array impl. is useful for testing.
+beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...] 
+
+"""
+    beta_for_process(state, I...)
+
+Return the β corresponding to the process indexed by `I...`.
+"""
+beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.swapstate.process_to_chain, I...)
+# NOTE: Array impl. is useful for testing.
+function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...)
+    return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...))
+end
+
+"""
+    numsteps(sampler::TemperedSampler)
+
+Return number of inverse temperatures used by `sampler`.
+"""
+numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta)
+
 """
     tempered(sampler, inverse_temperatures; kwargs...)
     OR
@@ -98,7 +152,8 @@ function tempered(
     sampler::AbstractMCMC.AbstractSampler,
     inverse_temperatures::Vector{<:Real};
     swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
-    swap_every::Integer=10,
+    # TODO: Change `swap_every` to something like `number_of_iterations_per_swap`.
+    steps_per_swap::Integer=1,
     adapt::Bool=false,
     adapt_target::Real=0.234,
     adapt_stepsize::Real=1,
@@ -108,10 +163,13 @@ function tempered(
     kwargs...
 )
     !(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
-    swap_every > 1 || error("`swap_every` must take a positive integer value greater than 1.")
+    steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.")
     inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
     adaptation_states = init_adaptation(
         adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
     )
-    return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states)
+    # NOTE: We just make a repeated sampler for `sampler_inner`.
+    # TODO: Generalize. Allow passing in a `MultiSampler`, etc.
+    sampler_inner = sampler^steps_per_swap
+    return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states)
 end
diff --git a/src/samplers/composition.jl b/src/samplers/composition.jl
new file mode 100644
index 0000000..b8fb153
--- /dev/null
+++ b/src/samplers/composition.jl
@@ -0,0 +1,130 @@
+"""
+    CompositionSampler <: AbstractMCMC.AbstractSampler
+
+A `CompositionSampler` is a container for a sequence of samplers.
+
+# Fields
+$(FIELDS)
+
+# Examples
+```julia
+composed_sampler = sampler_inner ∘ sampler_outer # or `CompositionSampler(sampler_inner, sampler_outer, Val(true))`
+AbstractMCMC.step(rng, model, composed_sampler) # one step of `sampler_inner`, and one step of `sampler_outer`
+```
+"""
+struct CompositionSampler{S1,S2,SaveAll} <: AbstractMCMC.AbstractSampler
+    "The outer sampler"
+    sampler_outer::S1
+    "The inner sampler"
+    sampler_inner::S2
+    "Whether to save all the transitions or just the last one"
+    saveall::SaveAll
+end
+
+CompositionSampler(sampler_outer, sampler_inner) = CompositionSampler(sampler_outer, sampler_inner, Val(true))
+
+Base.:∘(s_outer::AbstractMCMC.AbstractSampler, s_inner::AbstractMCMC.AbstractSampler) = CompositionSampler(s_outer, s_inner)
+
+"""
+    saveall(sampler)
+
+Return whether the sampler saves all the transitions or just the last one.
+"""
+saveall(sampler::CompositionSampler) = sampler.saveall
+saveall(::CompositionSampler{<:Any,<:Any,Val{SaveAll}}) where {SaveAll} = SaveAll
+
+"""
+    CompositionState
+
+A `CompositionState` is a container for a sequence of states.
+
+# Fields
+$(FIELDS)
+"""
+struct CompositionState{S1,S2}
+    "The outer state"
+    state_outer::S1
+    "The inner state"
+    state_inner::S2
+end
+
+getparams_and_logprob(state::CompositionState) = getparams_and_logprob(state.state_outer)
+getparams_and_logprob(model, state::CompositionState) = getparams_and_logprob(model, state.state_outer)
+
+function setparams_and_logprob!!(state::CompositionState, params, logprob)
+    return @set state.state_outer = setparams_and_logprob!!(state.state_outer, params, logprob)
+end
+function setparams_and_logprob!!(model, state::CompositionState, params, logprob)
+    return @set state.state_outer = setparams_and_logprob!!(model, state.state_outer, params, logprob)
+end
+
+struct CompositionTransition{S1,S2}
+    "The outer transition"
+    transition_outer::S1
+    "The inner transition"
+    transition_inner::S2
+end
+
+# Useful functions for interacting with composition sampler and states.
+inner_sampler(sampler::CompositionSampler) = sampler.sampler_inner
+outer_sampler(sampler::CompositionSampler) = sampler.sampler_outer
+
+inner_state(state::CompositionState) = state.state_inner
+outer_state(state::CompositionState) = state.state_outer
+
+inner_transition(transition::CompositionTransition) = transition.transition_inner
+outer_transition(transition::CompositionTransition) = transition.transition_outer
+outer_transition(transition) = transition  # in case we don't have `saveall`
+
+# TODO: We really don't need to use `SequentialStates` here, do we?
+composition_state(sampler, state_inner, state_outer) = CompositionState(state_outer, state_inner)
+function composition_transition(sampler, transition_inner, transition_outer)
+    return if saveall(sampler)
+        CompositionTransition(transition_outer, transition_inner)
+    else
+        transition_outer
+    end
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler;
+    kwargs...
+)
+    state_inner_initial = last(AbstractMCMC.step(rng, model, inner_sampler(sampler); kwargs...))
+    state_outer_initial = last(AbstractMCMC.step(rng, model, outer_sampler(sampler); kwargs...))
+
+    # Create the composition state, and take a full step.
+    state = composition_state(sampler, state_inner_initial, state_outer_initial)
+    return AbstractMCMC.step(rng, model, sampler, state; kwargs...)
+end
+
+# TODO: Do we even need two versions? We could technically use `SequentialStates`
+# in place of `CompositionState` and just have one version.
+# The annoying part here is that we'll have to check `saveall` on every `step`
+# rather than just for the initial step.
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler,
+    state;
+    kwargs...
+)
+    state_inner_prev, state_outer_prev = inner_state(state), outer_state(state)
+
+    # Update the inner state.
+    current_state_inner = state_from(model, state_inner_prev, state_outer_prev)
+
+    # Take a step in the inner sampler.
+    transition_inner, state_inner = AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...)
+
+    # Take a step in the outer sampler.
+    current_state_outer = state_from(model, state_outer_prev, state_inner)
+    transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...)
+
+    return (
+        composition_transition(sampler, transition_inner, transition_outer),
+        composition_state(sampler, state_inner, state_outer)
+    )
+end
diff --git a/src/samplers/multi.jl b/src/samplers/multi.jl
new file mode 100644
index 0000000..f007e00
--- /dev/null
+++ b/src/samplers/multi.jl
@@ -0,0 +1,210 @@
+# Multiple independent samplers.
+combine(x::Tuple, y::Tuple) = (x..., y...)
+combine(x::Tuple, y) = (x..., y)
+combine(x, y::Tuple) = (x, y...)
+combine(x::AbstractArray, y::AbstractArray) = vcat(x, y)
+combine(x::AbstractArray, y) = vcat(x, y)
+combine(x, y::AbstractArray) = vcat(x, y)
+combine(x, y) = Iterators.flatten((x, y))
+
+
+"""
+    MultiSampler <: AbstractMCMC.AbstractSampler
+
+A `MultiSampler` is a container for multiple samplers.
+
+See also: [`MultiModel`](@ref).
+
+# Fields
+$(FIELDS)
+
+# Examples
+```julia
+# `sampler1` targets `model1`, `sampler2` targets `model2`, etc.
+multi_model = model1 × model2 × model3 # or `MultiModel((model1, model2, model3))`
+multi_sampler = sampler1 × sampler2 × sampler3 # or `MultiSampler((sampler1, sampler2, sampler3))`
+# Target the joint model.
+AbstractMCMC.step(rng, multi_model, multi_sampler)
+```
+"""
+struct MultiSampler{S} <: AbstractMCMC.AbstractSampler
+    "The samplers"
+    samplers::S
+end
+
+×(sampler1::AbstractMCMC.AbstractSampler, sampler2::AbstractMCMC.AbstractSampler) = MultiSampler((sampler1, sampler2))
+×(sampler1::MultiSampler, sampler2::AbstractMCMC.AbstractSampler) = MultiSampler(combine(sampler1.samplers, sampler2))
+×(sampler1::AbstractMCMC.AbstractSampler, sampler2::MultiSampler) = MultiSampler(combine(sampler1, sampler2.samplers))
+×(sampler1::MultiSampler, sampler2::MultiSampler) = MultiSampler(combine(sampler1.samplers, sampler2.samplers))
+
+"""
+    MultiModel <: AbstractMCMC.AbstractModel
+
+A `MultiModel` is a container for multiple models.
+
+See also: [`MultiSampler`](@ref).
+
+# Fields
+$(FIELDS)
+"""
+struct MultiModel{M} <: AbstractMCMC.AbstractModel
+    "The models"
+    models::M
+end
+
+×(model1::AbstractMCMC.AbstractModel, model2::AbstractMCMC.AbstractModel) = MultiModel((model1, model2))
+×(model1::MultiModel, model2::AbstractMCMC.AbstractModel) = MultiModel(combine(model1.models, model2))
+×(model1::AbstractMCMC.AbstractModel, model2::MultiModel) = MultiModel(combine(model1, model2.models))
+×(model1::MultiModel, model2::MultiModel) = MultiModel(combine(model1.models, model2.models))
+
+Base.length(model::MultiModel) = length(model.models)
+
+# TODO: Make these subtypes of `AbstractVector`?
+"""
+    MultipleTransitions
+
+A container for multiple transitions.
+
+See also: [`MultipleStates`](@ref).
+
+# Fields
+$(FIELDS)
+"""
+struct MultipleTransitions{A}
+    "The transitions"
+    transitions::A
+end
+
+function getparams_and_logprob(transitions::MultipleTransitions)
+    params_and_logprobs = map(getparams_and_logprob, transitions.transitions)
+    return map(first, params_and_logprobs), map(last, params_and_logprobs)
+end
+function getparams_and_logprob(model::MultiModel, transitions::MultipleTransitions)
+    params_and_logprobs = map(getparams_and_logprob, model.models, transitions.transitions)
+    return map(first, params_and_logprobs), map(last, params_and_logprobs)
+end
+
+"""
+    MultipleStates
+
+A container for multiple states.
+
+See also: [`MultipleTransitions`](@ref).
+
+# Fields
+$(FIELDS)
+"""
+struct MultipleStates{A}
+    "The states"
+    states::A
+end
+
+# NOTE: This is different from most of the other implementations of `getparams_and_logprob`
+# as here we need to work with multiple models, transitions, and states.
+function getparams_and_logprob(state::MultipleStates)
+    params_and_logprobs = map(getparams_and_logprob, state.states)
+    return map(first, params_and_logprobs), map(last, params_and_logprobs)
+end
+function getparams_and_logprob(model::MultiModel, state::MultipleStates)
+    params_and_logprobs = map(getparams_and_logprob, model.models, state.states)
+    return map(first, params_and_logprobs), map(last, params_and_logprobs)
+end
+
+function setparams_and_logprob!!(state::MultipleStates, params, logprobs)
+    @assert length(params) == length(logprobs) == length(state.states) "The number of parameters and log probabilities must match the number of states."
+    return @set state.states = map(setparams_and_logprob!!, state.states, params, logprobs)
+end
+function setparams_and_logprob!!(model::MultiModel, state::MultipleStates, params, logprobs)
+    @assert length(model.models) == length(params) == length(logprobs) == length(state.states) "The number of models, states, parameters, and log probabilities must match."
+    return @set state.states = map(setparams_and_logprob!!, model.models, state.states, params, logprobs)
+end
+
+# TODO: Clean this up.
+initparams(model::MultiModel, init_params) = map(Base.Fix1(get_init_params, init_params), 1:length(model.models))
+initparams(model::MultiModel{<:Tuple}, init_params) =  ntuple(length(model.models)) do i
+    get_init_params(init_params, i)
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    sampler::MultiSampler;
+    init_params=nothing,
+    kwargs...
+)
+    @assert length(model.models) == length(sampler.samplers) "Number of models and samplers must be equal"
+
+    # TODO: Handle `init_params` properly. Make sure that they respect the container-types used in
+    # `MultiModel` and `MultiSampler`.
+    init_params_multi = initparams(model, init_params)
+    transition_and_states = asyncmap(model.models, sampler.samplers, init_params_multi) do model, sampler, init_params
+        AbstractMCMC.step(rng, model, sampler; init_params, kwargs...)
+    end
+
+    return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    sampler::MultiSampler,
+    states::MultipleStates;
+    kwargs...
+)
+    @assert length(model.models) == length(sampler.samplers) == length(states.states) "Number of models, samplers, and states must be equal."
+
+    transition_and_states = asyncmap(model.models, sampler.samplers, states.states) do model, sampler, state
+        AbstractMCMC.step(rng, model, sampler, state; kwargs...)
+    end
+
+    return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
+end
+
+# NOTE: In the case of a `RepeatedSampler{<:MultiSampler}`, it's better to, effectively, re-order
+# the samplers so that we make a `MultiSampler` of `RepeatedSampler`s.
+# We don't want to mutate the sampler, so instead we just convert the sequence of multi-states into
+# a multi-state of sequential states, and then work with this ordering in subsequent calls to `step`.
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    repeated_sampler::RepeatedSampler{<:MultiSampler},
+    states::SequentialStates;
+    kwargs...
+)
+    @debug "Working with RepeatedSampler{<:MultiSampler}; converting a sequence of multi-states into a multi-state of sequential states"
+
+    multisampler = repeated_sampler.sampler
+    multistates = last(states.states)
+    @assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal."
+    transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state
+        # Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation.
+        AbstractMCMC.step(
+            rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), SequentialStates([state]);
+            kwargs...
+        )
+    end
+
+    return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
+end
+
+# And then we define how `RepeatedSampler{<:MultiSampler}` should work with a `MultipleStates`.
+# NOTE: If `saveall(sampler)` is `false`, this is also the implementation we'll hit.
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    repeated_sampler::RepeatedSampler{<:MultiSampler},
+    multistates::MultipleStates;
+    kwargs...
+)
+    multisampler = repeated_sampler.sampler
+    @assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal."
+    transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state
+        # Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation.
+        AbstractMCMC.step(
+            rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), state;
+            kwargs...
+        )
+    end
+
+    return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
+end
diff --git a/src/samplers/repeated.jl b/src/samplers/repeated.jl
new file mode 100644
index 0000000..8ffa232
--- /dev/null
+++ b/src/samplers/repeated.jl
@@ -0,0 +1,81 @@
+"""
+    RepeatedSampler <: AbstractMCMC.AbstractSampler
+
+A `RepeatedSampler` is a container for a sampler and a number of times to repeat it.
+
+# Fields
+$(FIELDS)
+
+# Examples
+```julia
+repeated_sampler = sampler^10 # or `RepeatedSampler(sampler, 10, Val(true))`
+AbstractMCMC.step(rng, model, repeated_sampler) # take 10 steps of `sampler`
+```
+"""
+struct RepeatedSampler{S,SaveAll} <: AbstractMCMC.AbstractSampler
+    "The sampler to repeat"
+    sampler::S
+    "The number of times to repeat the sampler"
+    num_repeat::Int
+    "Whether to save all the transitions or just the last one"
+    saveall::SaveAll
+end
+
+RepeatedSampler(sampler, num_repeat) = RepeatedSampler(sampler, num_repeat, Val(true))
+
+Base.@constprop :aggressive Base.:^(s::AbstractMCMC.AbstractSampler, n::Int) = RepeatedSampler(s, n, Val(true))
+
+saveall(sampler::RepeatedSampler) = sampler.saveall
+saveall(::RepeatedSampler{<:Any,Val{SaveAll}}) where {SaveAll} = SaveAll
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler;
+    kwargs...
+)
+    state = last(AbstractMCMC.step(rng, model, sampler.sampler; kwargs...))
+    state_repeated = saveall(sampler) ? SequentialStates([state]) : state
+
+    return AbstractMCMC.step(rng, model, sampler, state_repeated; kwargs...)
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler,
+    state;
+    kwargs...
+)
+    # Take a step in the inner sampler.
+    transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...)
+
+    # Take a step in the outer sampler.
+    for _ in 2:sampler.num_repeat
+        transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...)
+    end
+
+    return transition, state
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler,
+    state::SequentialStates;
+    kwargs...
+)
+    # Take a step in the inner sampler.
+    transition, state_inner = AbstractMCMC.step(rng, model, sampler.sampler, state.states[end]; kwargs...)
+
+    # Take a step in the outer sampler.
+    transitions = [transition]
+    states = [state_inner]
+    for _ in 2:sampler.num_repeat
+        transition, state_inner = AbstractMCMC.step(rng, model, sampler.sampler, state_inner; kwargs...)
+        push!(transitions, transition)
+        push!(states, state_inner)
+    end
+
+    return SequentialTransitions(transitions), SequentialStates(states)
+end
diff --git a/src/state.jl b/src/state.jl
index ac5fc36..96e6c98 100644
--- a/src/state.jl
+++ b/src/state.jl
@@ -1,5 +1,19 @@
 """
-    TemperedState
+    ProcessOrder
+
+Specifies that the `model` should be treated as process-ordered.
+"""
+struct ProcessOrder end
+
+"""
+    ChainOrder
+
+Specifies that the `model` should be treated as chain-ordered.
+"""
+struct ChainOrder end
+
+"""
+    SwapState
 
 A general implementation of a state for a [`TemperedSampler`](@ref).
 
@@ -17,7 +31,7 @@ Moreover, suppose we also have 4 workers/processes for which we run these chains
 (can also be serial wlog).
 
 We can then perform a swap in two different ways:
-1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`.
+1. Swap the the _states_ between each process, i.e. permute `transitions` and `states`.
 2. Swap the _temperatures_ between each process, i.e. permute `chain_to_beta`.
 
 (1) is possibly the most intuitive approach since it means that the i-th worker/process
@@ -52,69 +66,87 @@ Chains:    process_to_chain    chain_to_process   inverse_temperatures[process_t
 In this case, the chain `X` can be reconstructed as:
 
 ```julia
-X[1] = states[1].transitions_and_states[1]
-X[2] = states[2].transitions_and_states[2]
-X[3] = states[3].transitions_and_states[2]
-X[4] = states[4].transitions_and_states[3]
-X[5] = states[5].transitions_and_states[3]
+X[1] = states[1].states[1]
+X[2] = states[2].states[2]
+X[3] = states[3].states[2]
+X[4] = states[4].states[3]
+X[5] = states[5].states[3]
 ```
 
+and similarly for the states.
+
 The indices here are exactly those represented by `states[k].chain_to_process[1]`.
 """
-@concrete struct TemperedState
-    "collection of `(transition, state)` pairs for each process"
-    transitions_and_states
-    "collection of (inverse) temperatures β corresponding to each chain"
-    chain_to_beta
+@concrete struct SwapState
+    "collection of states for each process"
+    states
     "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain"
     chain_to_process
     "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process"
     process_to_chain
     "total number of steps taken"
     total_steps
-    "number of burn-in steps taken"
-    burnin_steps
-    "contains all necessary information for adaptation of inverse_temperatures"
-    adaptation_states
-    "flag which specifies wether this was a swap-step or not"
-    is_swap
     "swap acceptance ratios on log-scale"
     swap_acceptance_ratios
 end
 
+# TODO: Can we support more?
+function SwapState(state::MultipleStates)
+    process_to_chain = collect(1:length(state.states))
+    chain_to_process = copy(process_to_chain)
+    return SwapState(state.states, chain_to_process, process_to_chain, 1, Dict{Int,Float64}())
+end
+
+# Defer these to `MultipleStates`.
+# TODO: What is the best way to implement these? Should we sort according to the chain indices
+# to match the order of the models?
+# getparams_and_logprob(state::SwapState) = getparams_and_logprob(MultipleStates(state.states))
+# getparams_and_logprob(model, state::SwapState) = getparams_and_logprob(model, MultipleStates(state.states))
+
+function setparams_and_logprob!!(model, state::SwapState, params, logprobs)
+    # Use the `MultipleStates`'s implementation to update the underlying states.
+    multistate = setparams_and_logprob!!(model, MultipleStates(state.states), params, logprobs)
+    # Update the states!
+    return @set state.states = multistate.states
+end
+
 """
-    process_to_chain(state, I...)
+    sort_by_chain(::ChainOrdering, state, xs)
+    sort_by_chain(::ProcessOrdering, state, xs)
 
-Return the chain index corresponding to the process index `I`.
+Return `xs` sorted according to the chain indices, as specified by `state`.
 """
-process_to_chain(state::TemperedState, I...) = process_to_chain(state.process_to_chain, I...)
-# NOTE: Array impl. is useful for testing.
-process_to_chain(proc2chain::AbstractArray, I...) = proc2chain[I...]
+sort_by_chain(::ChainOrder, ::Any, xs) = xs
+sort_by_chain(::ProcessOrder, state, xs) = [xs[chain_to_process(state, i)] for i = 1:length(xs)]
+sort_by_chain(::ProcessOrder, state, xs::Tuple) = ntuple(i -> xs[chain_to_process(state, i)], length(xs))
 
 """
-    chain_to_process(state, I...)
+    sort_by_process(::ProcessOrdering, state, xs)
+    sort_by_process(::ChainOrdering, state, xs)
 
-Return the process index corresponding to the chain index `I`.
+Return `xs` sorted according to the process indices, as specified by `state`.
 """
-chain_to_process(state::TemperedState, I...) = chain_to_process(state.chain_to_process, I...)
-# NOTE: Array impl. is useful for testing.
-chain_to_process(chain2proc::AbstractArray, I...) = chain2proc[I...]
+sort_by_process(::ProcessOrder, ::Any, xs) = xs
+sort_by_process(::ChainOrder, state, xs) = [xs[process_to_chain(state, i)] for i = 1:length(xs)]
+sort_by_process(::ChainOrder, state, xs::Tuple) = ntuple(i -> xs[process_to_chain(state, i)], length(xs))
 
 """
-    transition_for_chain(state[, I...])
+    process_to_chain(state, I...)
 
-Return the transition corresponding to the chain indexed by `I...`.
-If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`.
+Return the chain index corresponding to the process index `I`.
 """
-transition_for_chain(state::TemperedState) = transition_for_chain(state, 1)
-transition_for_chain(state::TemperedState, I...) = transition_for_process(state, chain_to_process(state, I...))
+process_to_chain(state::SwapState, I...) = process_to_chain(state.process_to_chain, I...)
+# NOTE: Array impl. is useful for testing.
+process_to_chain(proc2chain, I...) = proc2chain[I...]
 
 """
-    transition_for_process(state, I...)
+    chain_to_process(state, I...)
 
-Return the transition corresponding to the process indexed by `I...`.
+Return the process index corresponding to the chain index `I`.
 """
-transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1]
+chain_to_process(state::SwapState, I...) = chain_to_process(state.chain_to_process, I...)
+# NOTE: Array impl. is useful for testing.
+chain_to_process(chain2proc, I...) = chain2proc[I...]
 
 """
     state_for_chain(state[, I...])
@@ -122,47 +154,49 @@ transition_for_process(state::TemperedState, I...) = state.transitions_and_state
 Return the state corresponding to the chain indexed by `I...`.
 If `I...` is not specified, the state corresponding to `β=1.0` will be returned.
 """
-state_for_chain(state::TemperedState) = state_for_chain(state, 1)
-state_for_chain(state::TemperedState, I...) = state_for_process(state, chain_to_process(state, I...))
+state_for_chain(state) = state_for_chain(state, 1)
+state_for_chain(state, I...) = state_for_process(state, chain_to_process(state, I...))
 
 """
     state_for_process(state, I...)
 
 Return the state corresponding to the process indexed by `I...`.
 """
-state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2]
+state_for_process(state::SwapState, I...) = state_for_process(state.states, I...)
+state_for_process(proc2state, I...) = proc2state[I...]
 
 """
-    beta_for_chain(state[, I...])
+    model_for_chain(ordering, sampler, model, state, I...)
+
+Return the model corresponding to the chain indexed by `I...`.
 
-Return the β corresponding to the chain indexed by `I...`.
-If `I...` is not specified, the β corresponding to `β=1.0` will be returned.
+`ordering` specifies what sort of order the input models follow.
 """
-beta_for_chain(state::TemperedState) = beta_for_chain(state, 1)
-beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...)
-# NOTE: Array impl. is useful for testing.
-beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...]
+function model_for_chain end
 
 """
-    beta_for_process(state, I...)
+    model_for_process(ordering, sampler, model, state, I...)
+
+Return the model corresponding to the process indexed by `I...`.
 
-Return the β corresponding to the process indexed by `I...`.
+`ordering` specifies what sort of order the input models follow.
 """
-beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.process_to_chain, I...)
-# NOTE: Array impl. is useful for testing.
-function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...)
-    return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...))
-end
+function model_for_process end
 
 """
-    getparams(transition)
-    getparams(::Type, transition)
+    models_by_processes(ordering, models, state)
 
-Return the parameters contained in `transition`.
+Return the models in the order of processes, assuming `models` is sorted according to `ordering`.
+
+See also: [`ProcessOrdering`](@ref), [`ChainOrdering`](@ref).
+"""
+models_by_processes(ordering, models, state) = sort_by_process(ordering, state, models)
+
+"""
+    samplers_by_processes(ordering, samplers, state)
 
-If a type is specified, the parameters are returned in said type.
+Return the `samplers` in the order of processes, assuming `samplers` is sorted according to `ordering`.
 
-# Notes
-This method is meant to be overloaded for the different transitions types.
+See also: [`ProcessOrdering`](@ref), [`ChainOrdering`](@ref).
 """
-function getparams end
+samplers_by_processes(ordering, samplers, state) = sort_by_process(ordering, state, samplers)
diff --git a/src/stepping.jl b/src/stepping.jl
index 83d35fa..3ea2ce1 100644
--- a/src/stepping.jl
+++ b/src/stepping.jl
@@ -1,134 +1,88 @@
-"""
-    should_swap(sampler, state)
-
-Return `true` if a swap should happen at this iteration, and `false` otherwise.
-"""
-function should_swap(sampler::TemperedSampler, state::TemperedState)
-    return state.total_steps % sampler.swap_every == 1
-end
-
-get_init_params(x, _)= x
+get_init_params(x, _) = x
 get_init_params(init_params::Nothing, _) = nothing
 get_init_params(init_params::AbstractVector{<:Real}, _) = copy(init_params)
 get_init_params(init_params::AbstractVector{<:AbstractVector{<:Real}}, i) = init_params[i]
 
+@concrete struct TemperedTransition
+    swaptransition
+    transition
+end
+
+function transition_for_chain(transition::TemperedTransition, I...)
+    chain_idx = transition.swaptransition.chain_to_process[I...]
+    return transition.transition.transitions[chain_idx]
+end
+
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model,
+    model::AbstractMCMC.AbstractModel,
     sampler::TemperedSampler;
-    N_burnin::Integer=0,
-    burnin_progress::Bool=AbstractMCMC.PROGRESS[],
-    init_params=nothing,
     kwargs...
 )
-
-    # `TemperedState` has the transitions and states in the order of
-    # the processes, and performs swaps by moving the (inverse) temperatures
-    # `β` between the processes, rather than moving states between processes
-    # and keeping the `β` local to each process.
-    # 
-    # Therefore we iterate over the processes and then extract the corresponding
-    # `β`, `sampler` and `state`, and take a initialize.
-    transitions_and_states = [
-        AbstractMCMC.step(
-            rng,
-            make_tempered_model(sampler, model, sampler.inverse_temperatures[i]),
-            getsampler(sampler, i);
-            init_params=get_init_params(init_params, i),
-            kwargs...
-        )
+    # Create a `MultiSampler` and `MultiModel`.
+    multimodel = MultiModel([
+        make_tempered_model(sampler, model, sampler.chain_to_beta[i])
         for i in 1:numtemps(sampler)
-    ]
+    ])
+    multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
+    multistate = last(AbstractMCMC.step(rng, multimodel, multisampler; kwargs...))
 
     # Make sure to collect, because we'll be using `setindex!(!)` later.
-    process_to_chain = collect(1:length(sampler.inverse_temperatures))
+    process_to_chain = collect(1:length(sampler.chain_to_beta))
     # Need to `copy` because this might be mutated.
     chain_to_process = copy(process_to_chain)
-    state = TemperedState(
-        transitions_and_states,
-        sampler.inverse_temperatures,
-        process_to_chain,
+    swapstate = SwapState(
+        multistate.states,
         chain_to_process,
+        process_to_chain,
         1,
-        0,
-        sampler.adaptation_states,
-        false,
-        Dict{Int,Float64}()
+        Dict{Int,Float64}(),
     )
 
-    if N_burnin > 0
-        AbstractMCMC.@ifwithprogresslogger burnin_progress name = "Burn-in" begin
-            # Determine threshold values for progress logging
-            # (one update per 0.5% of progress)
-            if burnin_progress
-                threshold = N_burnin ÷ 200
-                next_update = threshold
-            end
-
-            for i in 1:N_burnin
-                if burnin_progress && i >= next_update
-                    ProgressLogging.@logprogress i / N_burnin
-                    next_update = i + threshold
-                end
-                state = no_swap_step(rng, model, sampler, state; kwargs...)
-                @set! state.burnin_steps += 1
-            end
-        end
-    end
-
-    return transition_for_chain(state), state
+    return AbstractMCMC.step(rng, model, sampler, TemperedState(swapstate, multistate, sampler.chain_to_beta))
 end
 
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model,
+    model::AbstractMCMC.AbstractModel,
     sampler::TemperedSampler,
     state::TemperedState;
     kwargs...
 )
-    # Reset state
-    @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios)
-
-    if should_swap(sampler, state)
-        state = swap_step(rng, model, sampler, state)
-        @set! state.is_swap = true
-    else
-        state = no_swap_step(rng, model, sampler, state; kwargs...)
-        @set! state.is_swap = false
-    end
+    # Create the tempered `MultiModel`.
+    multimodel = MultiModel([make_tempered_model(sampler, model, beta) for beta in state.chain_to_beta])
+    # Create the tempered `MultiSampler`.
+    # We're assuming the user has given the samplers in an order according to the initial models.
+    multisampler = MultiSampler(samplers_by_processes(
+        ChainOrder(),
+        [getsampler(sampler, i) for i in 1:numtemps(sampler)],
+        state.swapstate
+    ))
+    # Create the composition which applies `SwapSampler` first.
+    sampler_composition = multisampler ∘ swapsampler(sampler)
+
+    # Step!
+    # NOTE: This will internally re-order the models according to processes before taking steps,
+    # hence the resulting transitions and states will be in the order of processes, as we desire.
+    transition_composition, state_composition = AbstractMCMC.step(
+        rng,
+        multimodel,
+        sampler_composition,
+        composition_state(sampler_composition, state.swapstate, state.state);
+        kwargs...
+    )
 
-    @set! state.total_steps += 1
+    # Construct the `TemperedTransition` and `TemperedState`.
+    swaptransition = inner_transition(transition_composition)
+    outertransition = outer_transition(transition_composition)
 
-    # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`.
-    return transition_for_chain(state), state
-end
+    swapstate = inner_state(state_composition)
+    outerstate = outer_state(state_composition)
 
-function no_swap_step(
-    rng::Random.AbstractRNG,
-    model,
-    sampler::TemperedSampler,
-    state::TemperedState;
-    kwargs...
-)
-    # `TemperedState` has the transitions and states in the order of
-    # the processes, and performs swaps by moving the (inverse) temperatures
-    # `β` between the processes, rather than moving states between processes
-    # and keeping the `β` local to each process.
-    # 
-    # Therefore we iterate over the processes and then extract the corresponding
-    # `β`, `sampler` and `state`, and take a step.
-    @set! state.transitions_and_states = [
-        AbstractMCMC.step(
-            rng,
-            make_tempered_model(sampler, model, beta_for_process(state, i)),
-            sampler_for_process(sampler, state, i),
-            state_for_process(state, i);
-            kwargs...
-        )
-        for i in 1:numtemps(sampler)
-    ]
-
-    return state
+    return (
+        TemperedTransition(swaptransition, outertransition),
+        TemperedState(swapstate, outerstate, state.chain_to_beta)
+    )
 end
 
 """
@@ -142,8 +96,8 @@ is used.
 function swap_step(
     rng::Random.AbstractRNG,
     model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    sampler,
+    state
 )
     return swap_step(swapstrategy(sampler), rng, model, sampler, state)
 end
@@ -151,31 +105,34 @@ end
 function swap_step(
     strategy::ReversibleSwap,
     rng::Random.AbstractRNG,
-    model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    model::MultiModel,
+    sampler,
+    state
 )
     # Randomly select whether to attempt swaps between chains
     # corresponding to odd or even indices of the temperature ladder
-    odd = rand([true, false])
-    for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))]
-        state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt)
+    odd = rand(rng, Bool)
+    # TODO: Use integer-division.
+    for k in [Int(2 * i - odd) for i in 1:(floor((length(model) - 1 + odd) / 2))]
+        state = swap_attempt(rng, model, sampler, state, k, k + 1)
     end
     return state
 end
 
+
 function swap_step(
     strategy::NonReversibleSwap,
     rng::Random.AbstractRNG,
-    model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    model::MultiModel,
+    sampler,
+    state::SwapState  # we're accessing `total_steps` restrict the type here
 )
     # Alternate between attempting to swap chains corresponding
     # to odd and even indices of the temperature ladder
-    odd = state.total_steps % (2 * sampler.swap_every) != 0
-    for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))]
-        state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt)
+    odd = state.total_steps % 2 != 0
+    # TODO: Use integer-division.
+    for k in [Int(2 * i - odd) for i in 1:(floor((length(model) - 1 + odd) / 2))]
+        state = swap_attempt(rng, model, sampler, state, k, k + 1)
     end
     return state
 end
@@ -183,45 +140,45 @@ end
 function swap_step(
     strategy::SingleSwap,
     rng::Random.AbstractRNG,
-    model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    model::MultiModel,
+    sampler,
+    state
 )
     # Randomly pick one index `k` of the temperature ladder and
     # attempt a swap between the corresponding chain and its neighbour
-    k = rand(rng, 1:(numtemps(sampler) - 1))
-    return swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt)
+    k = rand(rng, 1:(length(model) - 1))
+    return swap_attempt(rng, model, sampler, state, k, k + 1)
 end
 
 function swap_step(
     strategy::SingleRandomSwap,
     rng::Random.AbstractRNG,
-    model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    model::MultiModel,
+    sampler,
+    state
 )
     # Randomly pick two temperature ladder indices in order to
     # attempt a swap between the corresponding chains
-    chains = Set(1:numtemps(sampler))
+    chains = Set(1:length(model))
     i = pop!(chains, rand(rng, chains))
     j = pop!(chains, rand(rng, chains))
-    return swap_attempt(rng, model, sampler, state, i, j, sampler.adapt)
+    return swap_attempt(rng, model, sampler, state, i, j)
 end
 
 function swap_step(
     strategy::RandomSwap,
     rng::Random.AbstractRNG,
-    model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    model::MultiModel,
+    sampler,
+    state
 )
     # Iterate through all of temperature ladder indices, picking random
     # pairs and attempting swaps between the corresponding chains
-    chains = Set(1:numtemps(sampler))
+    chains = Set(1:length(model))
     while length(chains) >= 2
         i = pop!(chains, rand(rng, chains))
         j = pop!(chains, rand(rng, chains))
-        state = swap_attempt(rng, model, sampler, state, i, j, sampler.adapt)
+        state = swap_attempt(rng, model, sampler, state, i, j)
     end
     return state
 end
@@ -230,8 +187,8 @@ function swap_step(
     strategy::NoSwap,
     rng::Random.AbstractRNG,
     model,
-    sampler::TemperedSampler,
-    state::TemperedState
+    sampler,
+    state
 )
     return state
-end
\ No newline at end of file
+end
diff --git a/src/swapping.jl b/src/swapping.jl
index 578aa9c..6bd8eae 100644
--- a/src/swapping.jl
+++ b/src/swapping.jl
@@ -79,11 +79,11 @@ this overrides and disables all swapping functionality.
 struct NoSwap <: AbstractSwapStrategy end
 
 """
-    swap_betas!(chain_to_process, process_to_chain, i, j)
+    swap!(chain_to_process, process_to_chain, i, j)
 
 Swaps the `i`th and `j`th temperatures in place.
 """
-function swap_betas!(chain_to_process, process_to_chain, i, j)
+function swap!(chain_to_process, process_to_chain, i, j)
     # TODO: Use BangBang's `@set!!` to also support tuples?
     # Extract the process index for each of the chains.
     process_for_chain_i, process_for_chain_j = chain_to_process[i], chain_to_process[j]
@@ -113,8 +113,8 @@ and calls [`logdensity`](@ref) on the model returned from [`make_tempered_model`
 function compute_tempered_logdensities(model, sampler, transition, transition_other, β)
     tempered_model = make_tempered_model(sampler, model, β)
     return (
-        logdensity(tempered_model, getparams(transition)),
-        logdensity(tempered_model, getparams(transition_other))
+        logdensity(tempered_model, getparams(tempered_model, transition)),
+        logdensity(tempered_model, getparams(tempered_model, transition_other))
     )
 end
 function compute_tempered_logdensities(
@@ -123,6 +123,19 @@ function compute_tempered_logdensities(
     return compute_tempered_logdensities(model, sampler, transition, transition_other, β)
 end
 
+function compute_logdensities(
+    model::AbstractMCMC.AbstractModel,
+    model_other::AbstractMCMC.AbstractModel,
+    state,
+    state_other,
+)
+    # TODO: Make use of `getparams_and_logprob` instead?
+    return (
+        logdensity(model, getparams(model, state)),
+        logdensity(model, getparams(model_other, state_other))
+    )
+end
+
 """
     swap_acceptance_pt(logπi, logπj)
 
@@ -134,50 +147,3 @@ function swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj)
     return (logπjθi + logπiθj) - (logπiθi + logπjθj)
 end
 
-
-"""
-    swap_attempt(rng, model, sampler, state, i, j)
-
-Attempt to swap the temperatures of two chains by tempering the densities and
-calculating the swap acceptance ratio; then swapping if it is accepted.
-"""
-function swap_attempt(rng, model, sampler, state, i, j, adapt)
-    # Extract the relevant transitions.
-    sampler_i = sampler_for_chain(sampler, state, i)
-    sampler_j = sampler_for_chain(sampler, state, j)
-    transition_i = transition_for_chain(state, i)
-    transition_j = transition_for_chain(state, j)
-    state_i = state_for_chain(state, i)
-    state_j = state_for_chain(state, j)
-    β_i = beta_for_chain(state, i)
-    β_j = beta_for_chain(state, j)
-    # Evaluate logdensity for both parameters for each tempered density.
-    logπiθi, logπiθj = compute_tempered_logdensities(
-        model, sampler_i, sampler_j, transition_i, transition_j, state_i, state_j, β_i, β_j
-    )
-    logπjθj, logπjθi = compute_tempered_logdensities(
-        model, sampler_j, sampler_i, transition_j, transition_i, state_j, state_i, β_j, β_i
-    )
-    
-    # If the proposed temperature swap is accepted according `logα`,
-    # swap the temperatures for future steps.
-    logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj)
-    should_swap = -Random.randexp(rng) ≤ logα
-    if should_swap
-        swap_betas!(state.chain_to_process, state.process_to_chain, i, j)
-    end
-
-    # Keep track of the (log) acceptance ratios.
-    state.swap_acceptance_ratios[i] = logα
-
-    # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is
-    # adapted before a new `inverse_temperatures` is generated and returned.
-    if adapt
-        ρs = adapt!!(
-            state.adaptation_states, state.chain_to_beta, i, min(one(logα), exp(logα))
-        )
-        @set! state.adaptation_states = ρs
-        @set! state.chain_to_beta = update_inverse_temperatures(ρs, state.chain_to_beta)
-    end
-    return state
-end
diff --git a/src/swapsampler.jl b/src/swapsampler.jl
new file mode 100644
index 0000000..6e43ff1
--- /dev/null
+++ b/src/swapsampler.jl
@@ -0,0 +1,224 @@
+"""
+    SwapSampler <: AbstractMCMC.AbstractSampler
+
+# Fields
+$(FIELDS)
+"""
+struct SwapSampler{S} <: AbstractMCMC.AbstractSampler
+    "swap strategy to use"
+    strategy::S
+end
+
+SwapSampler() = SwapSampler(ReversibleSwap())
+
+swapstrategy(sampler::SwapSampler) = sampler.strategy
+
+# Interaction with the state.
+# NOTE: `SwapSampler` should only every interact with `ProcessOrdering`, so we don't implement `ChainOrdering`.
+function model_for_chain(ordering::ProcessOrder, sampler::SwapSampler, model::MultiModel, state::SwapState, I...)
+    # `model` is expected to be ordered according to process index, hence we map chain index to process index
+    # and extract the model corresponding to said process.
+    return model_for_process(ordering, sampler, model, state, chain_to_process(state, I...))
+end
+
+function model_for_process(::ProcessOrder, sampler::SwapSampler, model::MultiModel, state::SwapState, I...)
+    # `model` is expected to be ordered according to process index, hence we just extract the corresponding index.
+    return model.models[I...]
+end
+
+"""
+    SwapTransition
+
+Transition type for tempered samplers.
+"""
+@concrete struct SwapTransition
+    chain_to_process
+    process_to_chain
+end
+
+function composition_transition(
+    sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler},
+    swaptransition::SwapTransition,
+    outertransition::MultipleTransitions
+)
+    saveall(sampler) && return CompositionTransition(outertransition, swaptransition)
+    # Otherwise we have to re-order the transitions, since without the `swaptransition` there's
+    # no way to recover the true ordering of the transitions.
+    return MultipleTransitions(sort_by_chain(ProcessOrder(), swaptransition, outertransition.transitions))
+end
+
+# NOTE: This does not have an initial `step`! This is because we need
+# states to work with before we can do anything. Hence it only makes
+# sense to use this sampler in composition with other samplers.
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model,
+    sampler::SwapSampler,
+    state::SwapState;
+    kwargs...
+)
+    # Reset state
+    @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios)
+
+    # Perform a swap step.
+    state = swap_step(rng, model, sampler, state)
+    @set! state.total_steps += 1
+
+    # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`.
+    # TODO: What should we return here?
+    return SwapTransition(deepcopy(state.chain_to_process), deepcopy(state.process_to_chain)), state
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler},
+    state;
+    kwargs...
+)
+    # Reminder: a `swap` can be implemented in two different ways:
+    #
+    # 1. Swap the models and leave ordering of (sampler, state)-pair unchanged.
+    # 2. Swap (sampler, state)-pairs and leave ordering of models unchanged.
+    #
+    # (1) has the properties:
+    # + Easy to keep `outerstate` and `swapstate` in sync since their ordering is never changed.
+    # - Ordering of `outerstate` no longer corresponds to ordering of models, i.e. the returned
+    #   `outerstate.states[i]` does no longer correspond to a state targeting `model.models[i]`.
+    #   This will have to be adjusted in the `AbstractMCMC.bundle_samples` before, say, converting
+    #   into a `MCMCChains.Chains`.
+    #
+    # (2) has the properties:
+    # + Returned `outertransition` (and `outerstate`, if we want) has the same ordering as the models,
+    #   i.e. `outerstate.states[i]` now corresponds to `model.models[i]`!
+    # - Need to keep `outerstate` and `swapstate` in sync since their ordering now changes.
+    # - Need to also re-order `outersampler.samplers` :/
+    #
+    # Here (as in, below) we go with option (1), i.e. re-order the `models`.
+    # A full `step` then is as follows:
+    # 1. Sort models according to index processes using the `swapstate` from previous iteration.
+    # 2. Take step with `swapsampler`.
+    # 3. Sort models _again_ according to index processes using the new `swapstate`, since we
+    #    might have made a swap in (2).
+    # 4. Run multi-sampler.
+
+    outersampler, swapsampler = outer_sampler(sampler), inner_sampler(sampler)
+
+    # Get the states.
+    outerstate_prev, swapstate_prev = outer_state(state), inner_state(state)
+
+    # Re-order the models.
+    chain2models = model.models  # but keep the original chain → model around because we'll re-order again later
+    @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate_prev)
+
+    # Step for the swap-sampler.
+    swaptransition, swapstate = AbstractMCMC.step(
+        rng, model, swapsampler, state_from(model, swapstate_prev, outerstate_prev);
+        kwargs...
+    )
+
+    # Re-order the models AGAIN, since we might have swapped some.
+    @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate)
+
+    # Create the current state from `outerstate_prev` and `swapstate`, and `step` for `outersampler`.`
+    outertransition, outerstate = AbstractMCMC.step(
+        # HACK: We really need the `state_from` here despite the fact that `SwapSampler` does note
+        # change the `swapstates.states` itself, but we might require a re-computation of certain
+        # quantities from the `model`, which has now potentially been re-ordered (see above).
+        # NOTE: We do NOT do `state_from(model, outerstate_prev, swapstate)` because as of now,
+        # `swapstate` does not implement `getparams_and_logprob`.
+        rng, model, outersampler, state_from(model, outerstate_prev, outerstate_prev);
+        kwargs...
+    )
+
+    # TODO: Should we re-order the transitions?
+    # Currently, one has to re-order the `outertransition` according to `swaptransition`
+    # in the `bundle_samples`. Is this the right approach though?
+    # TODO: We should at least re-order transitions in the case where `saveall(sampler) == false`!
+    # In this case, we'll just return the transition without the swap-transition, hence making it
+    # impossible to reconstruct the actual ordering!
+    return (
+        composition_transition(sampler, swaptransition, outertransition),
+        composition_state(sampler, swapstate, outerstate)
+    )
+end
+
+# NOTE: The default initial `step` for `CompositionSampler` simply calls the two different
+# `step` methods, but since `SwapSampler` does not have such an implementation this will fail.
+# Instead we overload the initial `step` for `CompositionSampler` involving `SwapSampler` to
+# first take a `step` using the non-swapsampler and then construct `SwapState` from the resulting state.
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler};
+    kwargs...
+)
+    # This should hopefully be a `MultipleStates` or something since we're working with a `MultiModel`.
+    state_outer_initial = last(AbstractMCMC.step(rng, model, outer_sampler(sampler); kwargs...))
+    # NOTE: Since `SwapState` wraps a sequence of states from another sampler, we need `state_outer_initial`
+    # to initialize the `SwapState`.
+    state_inner_initial = SwapState(state_outer_initial)
+
+    # Create the composition state, and take a full step.
+    state = composition_state(sampler, state_inner_initial, state_outer_initial)
+    return AbstractMCMC.step(rng, model, sampler, state; kwargs...)
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    sampler::CompositionSampler{<:SwapSampler,<:AbstractMCMC.AbstractSampler};
+    kwargs...
+)
+    # This should hopefully be a `MultipleStates` or something since we're working with a `MultiModel`.
+    state_inner_initial = last(AbstractMCMC.step(rng, model, inner_sampler(sampler); kwargs...))
+    # NOTE: Since `SwapState` wraps a sequence of states from another sampler, we need `state_outer_initial`
+    # to initialize the `SwapState`.
+    state_outer_initial = SwapState(state_inner_initial)
+
+    # Create the composition state, and take a full step.
+    state = composition_state(sampler, state_inner_initial, state_outer_initial)
+    return AbstractMCMC.step(rng, model, sampler, state; kwargs...)
+end
+
+@nospecialize function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::MultiModel,
+    sampler::CompositionSampler{<:SwapSampler,<:SwapSampler};
+    kwargs...
+)
+    error("`SwapSampler` requires states from sampler other than `SwapSampler` to be initialized")
+end
+
+"""
+    swap_attempt(rng, model, sampler, state, i, j)
+
+Attempt to swap the temperatures of two chains by tempering the densities and
+calculating the swap acceptance ratio; then swapping if it is accepted.
+"""
+function swap_attempt(rng::Random.AbstractRNG, model::MultiModel, sampler::SwapSampler, state, i, j)
+    # Extract the relevant transitions.
+    state_i = state_for_chain(state, i)
+    state_j = state_for_chain(state, j)
+    # Evaluate logdensity for both parameters for each tempered density.
+    # NOTE: `SwapSampler` should only be working with models ordered according to `ProcessOrder`,
+    # never `ChainOrder`, hence why we have the below.
+    model_i = model_for_chain(ProcessOrder(), sampler, model, state, i)
+    model_j = model_for_chain(ProcessOrder(), sampler, model, state, j)
+    logπiθi, logπiθj = compute_logdensities(model_i, model_j, state_i, state_j)
+    logπjθj, logπjθi = compute_logdensities(model_j, model_i, state_j, state_i)
+
+    # If the proposed temperature swap is accepted according `logα`,
+    # swap the temperatures for future steps.
+    logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj)
+    should_swap = -Random.randexp(rng) ≤ logα
+    if should_swap
+        swap!(state.chain_to_process, state.process_to_chain, i, j)
+    end
+
+    # Keep track of the (log) acceptance ratios.
+    state.swap_acceptance_ratios[i] = logα
+
+    # TODO: Handle adaptation.
+    return state
+end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..17b9125
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,34 @@
+# TODO: Move.
+chain_to_process(transition::SwapTransition, I...) = transition.chain_to_process[I...]
+
+"""
+    roundtrips(transitions)
+
+Return sequence of `(start_index, turnpoint_index, end_index)`-triples representing roundtrips.
+"""
+function roundtrips(transitions::AbstractVector{<:TemperedTransition})
+    return roundtrips(map(Base.Fix2(getproperty, :swaptransition), transitions))
+end
+function roundtrips(transitions::AbstractVector{<:SwapTransition})
+    result = Tuple{Int,Int,Int}[]
+    start_index, turn_index = 1, nothing
+    for (i, t) in enumerate(transitions)
+        n = length(t.chain_to_process)
+        if isnothing(turn_index)
+            # Looking for the turn.
+            if chain_to_process(t, 1) == n
+                turn_index = i
+            end
+        else
+            # Looking for the return/end.
+            if chain_to_process(t, 1) == 1
+                push!(result, (start_index, turn_index, i))
+                # Reset.
+                start_index = i
+                turn_index = nothing
+            end
+        end
+    end
+
+    return result
+end
diff --git a/test/Project.toml b/test/Project.toml
index 310ee7f..492577c 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -3,12 +3,18 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
 
@@ -20,6 +26,6 @@ Bijectors = "0.10"
 Distributions = "0.24, 0.25"
 LogDensityProblems = "2"
 LogDensityProblemsAD = "1"
-MCMCChains = "5.5"
+MCMCChains = "6"
 Turing = "0.24"
 julia = "1"
diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl
new file mode 100644
index 0000000..3cc9c8a
--- /dev/null
+++ b/test/abstractmcmc.jl
@@ -0,0 +1,186 @@
+@testset "AbstractMCMC" begin
+    rng = Random.default_rng()
+    model = DistributionLogDensity(MvNormal(Ones(2), I))
+    logdensity_model = LogDensityModel(model)
+
+    spl = RWMH(MvNormal(Zeros(dimension(model)), I))
+    @test spl isa AbstractMCMC.AbstractSampler
+
+    @testset "CompositionSampler(.., saveall=$(saveall))" for saveall in (true, false, Val(true), Val(false))
+        spl_composed = MCMCTempering.CompositionSampler(spl, spl, saveall)
+
+        num_iters = 100
+        # Taking two steps with `spl` should be equivalent to one step with `spl ∘ spl`.
+        # Use the same initial state.
+        state_initial = last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl))
+        state_composed_initial = MCMCTempering.state_from(
+            logdensity_model,
+            last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl_composed)),
+            state_initial,
+        )
+
+        @test state_composed_initial isa MCMCTempering.CompositionState
+
+        # Take two steps with `spl`.
+        rng = Random.MersenneTwister(42)
+        state = deepcopy(state_initial)
+        for _ = 1:num_iters
+            transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state)
+            transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state)
+        end
+        params, logp = MCMCTempering.getparams_and_logprob(logdensity_model, state)
+
+        # Take one step with `spl ∘ spl`.
+        rng = Random.MersenneTwister(42)
+        state_composed = deepcopy(state_composed_initial)
+        for _ = 1:num_iters
+            transition, state_composed = AbstractMCMC.step(rng, logdensity_model, spl_composed, state_composed)
+
+            # Make sure the state types stay consistent.
+            if MCMCTempering.saveall(spl_composed)
+                @test transition isa MCMCTempering.CompositionTransition
+            end
+            @test state_composed isa MCMCTempering.CompositionState
+        end
+        params_composed, logp_composed = MCMCTempering.getparams_and_logprob(logdensity_model, state_composed)
+
+        # Check that the parameters and log probability are the same.
+        @test params == params_composed
+        @test logp == logp_composed
+
+        # Make sure that `AbstractMCMC.sample` is good.
+        chain_composed = sample(logdensity_model, spl_composed, 2; progress=false, chain_type=MCMCChains.Chains)
+        chain = sample(
+            logdensity_model, spl, MCMCTempering.saveall(spl_composed) ? 4 : 2;
+            progress=false, chain_type=MCMCChains.Chains
+        )
+
+        # Should be the same length because the `SequentialTransitions` will be unflattened.
+        @test chain_composed isa MCMCChains.Chains
+        @test length(chain_composed) == length(chain)
+    end
+
+    @testset "RepeatedSampler(..., saveall=$(saveall))" for saveall in (true, false, Val(true), Val(false))
+        spl_repeated = MCMCTempering.RepeatedSampler(spl, 2, saveall)
+        @test spl_repeated isa MCMCTempering.RepeatedSampler
+
+        # Taking two steps with `spl` should be equivalent to one step with `spl ∘ spl`.
+        # Use the same initial state.
+        state_initial = last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl))
+        state_repeated_initial = MCMCTempering.state_from(
+            logdensity_model,
+            last(AbstractMCMC.step(Random.default_rng(), logdensity_model, spl_repeated)),
+            state_initial,
+        )
+
+        num_iters = 100
+
+        # Take two steps with `spl`.
+        rng = Random.MersenneTwister(42)
+        state = deepcopy(state_initial)
+        for _ = 1:num_iters
+            transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state)
+            transition, state = AbstractMCMC.step(rng, logdensity_model, spl, state)
+        end
+        params, logp = MCMCTempering.getparams_and_logprob(logdensity_model, state)
+
+        # Take one step with `spl ∘ spl`.
+        rng = Random.MersenneTwister(42)
+        state_repeated = deepcopy(state_repeated_initial)
+        for _ = 1:num_iters
+            transition, state_repeated = AbstractMCMC.step(rng, logdensity_model, spl_repeated, state_repeated)
+
+            # Make sure the state types stay consistent.
+            if MCMCTempering.saveall(spl_repeated)
+                @test transition isa MCMCTempering.SequentialTransitions
+                @test state_repeated isa MCMCTempering.SequentialStates
+            else
+                @test state_repeated isa typeof(state_initial)
+            end
+        end
+        params_repeated, logp_repeated = MCMCTempering.getparams_and_logprob(logdensity_model, state_repeated)
+
+        # Check that the parameters and log probability are the same.
+        @test params == params_repeated
+        @test logp == logp_repeated
+
+        # Make sure that `AbstractMCMC.sample` is good.
+        chain_repeated = sample(logdensity_model, spl_repeated, 2; progress=false, chain_type=MCMCChains.Chains)
+        chain = sample(
+            logdensity_model, spl, MCMCTempering.saveall(spl_repeated) ? 4 : 2;
+            progress=false, chain_type=MCMCChains.Chains
+        )
+
+        # Should be the same length because the `SequentialTransitions` will be unflattened.
+        @test length(chain_repeated) == length(chain)
+    end
+
+    @testset "MultiSampler" begin
+        spl_multi = spl × spl
+        @testset "$model_multi" for model_multi in [
+            MCMCTempering.MultiModel((logdensity_model, logdensity_model)),  # tuple
+            MCMCTempering.MultiModel([logdensity_model, logdensity_model]),  # vector
+            MCMCTempering.MultiModel((m for m in [logdensity_model, logdensity_model]))  # iterator
+            ]
+
+            @test spl_multi isa MCMCTempering.MultiSampler
+
+            num_iters = 100
+            # Use the same initial state.
+            states_initial = map(model_multi.models) do model
+                last(AbstractMCMC.step(Random.default_rng(), model, spl))
+            end
+            states_multi_initial = MCMCTempering.state_from(
+                model_multi,
+                last(AbstractMCMC.step(Random.default_rng(), model_multi, spl_multi)),
+                MCMCTempering.MultipleStates(states_initial),
+            )
+
+            # Taking a step with `spl_multi` on `multimodel` should be equivalent
+            # to stepping with the component samplers on the component models.
+            rng = Random.MersenneTwister(42)
+            rng_multi = Random.MersenneTwister(42)
+            states = deepcopy(states_initial)
+            state_multi = deepcopy(states_multi_initial)
+            for _ = 1:num_iters
+                state_multi = last(AbstractMCMC.step(rng_multi, model_multi, spl_multi, state_multi))
+                states = map(model_multi.models, states) do model, state
+                    last(AbstractMCMC.step(rng, model, spl, state))
+                end
+            end
+            params_and_logp = map(Base.Fix1(MCMCTempering.getparams_and_logprob, logdensity_model), states)
+            params_multi, logp_multi = MCMCTempering.getparams_and_logprob(model_multi, state_multi)
+
+            @test map(first, params_and_logp) == params_multi
+            @test map(last, params_and_logp) == logp_multi
+        end
+    end
+
+    @testset "SwapSampler" begin
+        # SwapSampler without tempering (i.e. in a composition and using `MultiModel`, etc.)
+        init_params = [[5.0], [5.0]]
+        mdl1 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(4.9999, 1)))
+        mdl2 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(5.0001, 1)))
+        spl1 = RWMH(MvNormal(Zeros(dimension(mdl1)), I))
+        spl2 = let σ² = 1e-2
+            MALA(∇ -> MvNormal(σ² * ∇, 2σ² * I))
+        end
+        swapspl = MCMCTempering.SwapSampler()
+        spl_full = (spl1 × spl2) ∘ swapspl
+        product_model = LogDensityModel(mdl1) × LogDensityModel(mdl2)
+        # Sample!
+        multisamples = sample(product_model, spl_full, 1000; init_params=init_params, progress=false)
+        # Extract the transitions corresponding to each of the models.
+        model_transitions = mapreduce(hcat, multisamples) do t
+            [MCMCTempering.outer_transition(t).transitions[MCMCTempering.inner_transition(t).process_to_chain]...]
+        end
+        # Make sure we actually got some swaps going and we were using different types of states
+        # for both models.
+        @test length(unique(typeof, model_transitions[1, :])) ≥ 1
+        @test length(unique(typeof, model_transitions[2, :])) ≥ 1
+
+        # Check that means are roughly okay.
+        model_params = map(first ∘ MCMCTempering.getparams, model_transitions)
+        @test vec(mean(model_params; dims=2)) ≈ [5.0, 5.0] atol=0.2
+    end
+end
diff --git a/test/compat.jl b/test/compat.jl
index e85bd05..7052bb1 100644
--- a/test/compat.jl
+++ b/test/compat.jl
@@ -1,6 +1,28 @@
 # AdvancedMH.jl
-MCMCTempering.getparams(transition::AdvancedMH.Transition) = transition.params
-MCMCTempering.getparams(transition::AdvancedMH.GradientTransition) = transition.params
+MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp
+function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp)
+    Setfield.@set! transition.params = params
+    Setfield.@set! transition.lp = lp
+    return transition
+end
+MCMCTempering.getparams_and_logprob(transition::AdvancedMH.GradientTransition) = transition.params, transition.lp
+# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible.
+function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.GradientTransition, params, lp)
+    # NOTE: We have to re-compute the gradient here because this will be used in the subsequent `step` for
+    # the MALA sampler.
+    return AdvancedMH.GradientTransition(params, AdvancedMH.logdensity_and_gradient(model, params)...)
+end
 
 # AdvancedHMC.jl
-MCMCTempering.getparams(t::AdvancedHMC.Transition) = t.z.θ
+MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value
+MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition)
+
+# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible.
+function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp)
+    # NOTE: Need to recompute the gradient because it might be used in the next integration step.
+    hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
+    return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
+        hamiltonian, params, state.transition.z.r;
+        ℓκ=state.transition.z.ℓκ
+    )
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index e13411b..160cabf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,22 +1,4 @@
-using MCMCTempering
-using Test
-using Distributions
-using AdvancedMH
-using MCMCChains
-using Bijectors
-using LinearAlgebra
-using AbstractMCMC
-using LogDensityProblems: LogDensityProblems, logdensity, logdensity_and_gradient
-using LogDensityProblemsAD
-using ForwardDiff: ForwardDiff
-using AdvancedMH: AdvancedMH
-using AdvancedHMC: AdvancedHMC
-using Turing: Turing, DynamicPPL
-
-
-include("utils.jl")
-include("compat.jl")
-
+include("setup.jl")
 
 """
     test_and_sample_model(model, sampler, inverse_temperatures[, swap_strategy]; kwargs...)
@@ -35,7 +17,8 @@ Several properties of the tempered sampler are tested before returning:
 
 # Keyword arguments
 - `num_iterations`: The number of iterations to run the sampler for. Defaults to `2_000`.
-- `swap_every`: The number of iterations between each swap attempt. Defaults to `2`.
+- `steps_per_swap`: The number of iterations between each swap attempt. Defaults to `1`.
+- `adapt`: Whether to adapt the sampler. Defaults to `false`.
 - `adapt_target`: The target acceptance rate for the swaps. Defaults to `0.234`.
 - `adapt_rtol`: The relative tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.1`.
 - `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`.
@@ -44,59 +27,63 @@ Several properties of the tempered sampler are tested before returning:
 - `init_params`: The initial parameters to use for the sampler. Defaults to `nothing`.
 - `param_names`: The names of the parameters in the chain; used to construct the resulting chain. Defaults to `missing`.
 - `progress`: Whether to show a progress bar. Defaults to `false`.
-- `kwargs...`: Additional keyword arguments to pass to `MCMCTempering.tempered`.
 """
 function test_and_sample_model(
     model,
     sampler,
-    inverse_temperatures,
-    swap_strategy=MCMCTempering.SingleSwap();
+    inverse_temperatures;
+    swap_strategy=MCMCTempering.SingleSwap(),
     mean_swap_rate_bound=0.1,
     compare_mean_swap_rate=≥,
     num_iterations=2_000,
-    swap_every=2,
+    steps_per_swap=1,
+    adapt=false,
     adapt_target=0.234,
     adapt_rtol=0.1,
     adapt_atol=0.05,
     init_params=nothing,
     param_names=missing,
     progress=false,
-    kwargs...
+    minimum_roundtrips=nothing
 )
-    # TODO: Remove this when no longer necessary.
-    num_iterations_tempered = Int(ceil(num_iterations * swap_every / (swap_every - 1)))
-
     # Make the tempered sampler.
     sampler_tempered = tempered(
         sampler,
         inverse_temperatures;
         swap_strategy=swap_strategy,
-        swap_every=swap_every,
+        steps_per_swap=steps_per_swap,
         adapt_target=adapt_target,
-        kwargs...
     )
 
+    @test sampler_tempered.swapstrategy == swap_strategy
+    @test MCMCTempering.swapsampler(sampler_tempered).strategy == swap_strategy
+
     # Store the states.
     states_tempered = []
     callback = StateHistoryCallback(states_tempered)
 
     # Sample.
     samples_tempered = AbstractMCMC.sample(
-        model, sampler_tempered, num_iterations_tempered;
+        model, sampler_tempered, num_iterations;
         callback=callback, progress=progress, init_params=init_params
     )
 
+    if !isnothing(minimum_roundtrips)
+        # Make sure we've had at least some roundtrips.
+        @test length(MCMCTempering.roundtrips(samples_tempered)) ≥ minimum_roundtrips
+    end
+
     # Let's make sure the process ↔ chain mapping is valid.
     numtemps = MCMCTempering.numtemps(sampler_tempered)
-    for state in states_tempered
-        for i = 1:numtemps
+    @test all(states_tempered) do state
+        all(1:numtemps) do i
             # These two should be inverses of each other.
-            @test MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i
+            MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i
         end
     end
 
     # Extract the states that were swapped.
-    states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered)
+    states_swapped = map(Base.Fix2(getproperty, :swapstate), states_tempered)
     # Swap acceptance ratios should be compared against the target acceptance in case of adaptation.
     swap_acceptance_ratios = mapreduce(
         collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios),
@@ -115,25 +102,39 @@ function test_and_sample_model(
     end
 
     # Extract the history of chain indices.
-    process_to_chain_history_list = map(states_tempered) do state
+    process_to_chain_history_list = map(states_swapped) do state
         state.process_to_chain
     end
     process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1))
 
     # Check that the swapping has been done correctly.
-    process_to_chain_uniqueness = map(states_tempered) do state
+    process_to_chain_uniqueness = map(states_swapped) do state
         length(unique(state.process_to_chain)) == length(state.process_to_chain)
     end
     @test all(process_to_chain_uniqueness)
 
-    # For the currently implemented strategies, the index process should not move by more than 1.
-    @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1)
+    # For every strategy except `RandomSwap`, the index process should not move by more than 1.
+    if !(swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.RandomSwap})
+        @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1)
+    end
 
-    chain_to_process_uniqueness = map(states_tempered) do state
+    chain_to_process_uniqueness = map(states_swapped) do state
         length(unique(state.chain_to_process)) == length(state.chain_to_process)
     end
     @test all(chain_to_process_uniqueness)
 
+    # Compare the tempered sampler to the untempered sampler.
+    state_tempered = states_tempered[end]
+    chain_tempered = AbstractMCMC.bundle_samples(
+        # TODO: Just use the underlying chain?
+        samples_tempered,
+        MCMCTempering.maybe_wrap_model(model),
+        sampler_tempered,
+        state_tempered,
+        MCMCChains.Chains;
+        param_names=param_names
+    )
+
     # Tests that we have at least swapped some times (say at least 10% of attempted swaps).
     swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row
         # Some of the strategies performs multiple swaps in a swap-iteration,
@@ -141,20 +142,12 @@ function test_and_sample_model(
         # i.e. only count non-zero elements in a row _once_. Hence the `min`.
         min(1, sum(abs, row))
     end
+
+    num_nonswap_steps_taken = length(chain_tempered)
+    @test num_nonswap_steps_taken == (num_iterations * steps_per_swap)
     @test compare_mean_swap_rate(
         sum(swap_success_indicators),
-        (num_iterations_tempered / swap_every) * mean_swap_rate_bound
-    )
-
-    # Compare the tempered sampler to the untempered sampler.
-    state_tempered = states_tempered[end]
-    chain_tempered = AbstractMCMC.bundle_samples(
-        samples_tempered[findall((!).(getproperty.(states_tempered, :is_swap)))],
-        MCMCTempering.maybe_wrap_model(model),
-        sampler_tempered.sampler,
-        MCMCTempering.state_for_chain(state_tempered),
-        MCMCChains.Chains;
-        param_names=param_names
+        (num_nonswap_steps_taken / steps_per_swap) * mean_swap_rate_bound
     )
 
     return chain_tempered
@@ -163,37 +156,29 @@ end
 function compare_chains(
     chain::MCMCChains.Chains, chain_tempered::MCMCChains.Chains;
     atol=1e-6, rtol=1e-6,
-    compare_std=true,
     compare_ess=true,
+    compare_ess_slack=0.5, # HACK: this is very low which is unnecessary in most cases, but it's too random
     isbroken=false
 )
-    desc = describe(chain)[1].nt
-    desc_tempered = describe(chain_tempered)[1].nt
+    mean = to_dict(MCMCChains.mean(chain))
+    mean_tempered = to_dict(MCMCChains.mean(chain_tempered))
 
     # Compare the means.
     if isbroken
-        @test_broken desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol
+        @test_broken all(isapprox(mean[sym], mean_tempered[sym]; atol, rtol) for sym in keys(mean))
     else
-        @test desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol
-    end
-
-    # Compare the std. of the chains.
-    if compare_std
-        if isbroken
-            @test_broken desc.std ≈ desc_tempered.std atol = atol rtol = rtol
-        else
-            @test desc.std ≈ desc_tempered.std atol = atol rtol = rtol
-        end
+        @test all(isapprox(mean[sym], mean_tempered[sym]; atol, rtol) for sym in keys(mean))
     end
 
     # Compare the ESS.
     if compare_ess
-        ess = MCMCChains.ess_rhat(chain).nt.ess
-        ess_tempered = MCMCChains.ess_rhat(chain_tempered).nt.ess
+        ess = to_dict(MCMCChains.ess(chain))
+        ess_tempered = to_dict(MCMCChains.ess(chain_tempered))
+        @info "" ess ess_tempered
         if isbroken
-            @test_broken all(ess_tempered .≥ ess)
+            @test_broken all(ess_tempered[sym] ≥ ess[sym] * compare_ess_slack for sym in keys(ess))
         else
-            @test all(ess_tempered .≥ ess)
+            @test all(ess_tempered[sym] ≥ ess[sym] * compare_ess_slack for sym in keys(ess))
         end
     end
 end
@@ -213,7 +198,7 @@ end
         chain_to_beta = [1.0, 0.75, 0.5, 0.25]
 
         # Make swap chain 1 (now on process 1) ↔ chain 2 (now on process 2)
-        MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1, 2)
+        MCMCTempering.swap!(chain_to_process, process_to_chain, 1, 2)
         # Expected result: chain 1 is now on process 2, chain 2 is now on process 1.
         target_process_to_chain = [2, 1, 3, 4]
         @test process_to_chain[chain_to_process] == 1:length(process_to_chain)
@@ -229,7 +214,7 @@ end
         end
 
         # Make swap chain 2 (now on process 1) ↔ chain 3 (now on process 3)
-        MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2, 3)
+        MCMCTempering.swap!(chain_to_process, process_to_chain, 2, 3)
         # Expected result: chain 3 is now on process 1, chain 2 is now on process 3.
         target_process_to_chain = [3, 1, 2, 4]
         @test process_to_chain[chain_to_process] == 1:length(process_to_chain)
@@ -246,7 +231,7 @@ end
     end
 
     @testset "Simple MvNormal with no expected swaps" begin
-        num_iterations = 10_000
+        num_iterations = 5_000
         d = 1
         model = DistributionLogDensity(MvNormal(ones(d), I))
 
@@ -259,25 +244,24 @@ end
             sampler_rwmh,
             [1.0, 1e-3],  # extreme temperatures -> don't exect much swapping to occur
             num_iterations=num_iterations,
-            swap_every=2,
             adapt=false,
-            init_params = [[0.0], [1000.0]],  # initialized far apart
-            # At most 1% of swaps should be successful.
+            init_params=[[0.0], [1000.0]],  # initialized far apart
+            # At MOST 1% of swaps should be successful.
             mean_swap_rate_bound=0.01,
             compare_mean_swap_rate=≤,
         )
         # `atol` is fairly high because we haven't run this for "too" long. 
-        @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.2
+        @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.3
     end
 
     @testset "GMM 1D" begin
-        num_iterations = 10_000
+        num_iterations = 1_000
         model = DistributionLogDensity(
             MixtureModel(Normal, [(-3, 1.5), (3, 1.5), (15, 1.5), (90, 1.5)], [0.175, 0.25, 0.275, 0.3])
         )
 
         # Setup non-tempered.
-        sampler_rwmh = RWMH(MvNormal(0.1 * ones(1)))
+        sampler_rwmh = RWMH(MvNormal(0.1 * Diagonal(Ones(1))))
 
         # Simple geometric ladder
         inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.95 .^ (0:20))
@@ -287,13 +271,15 @@ end
             model,
             sampler_rwmh,
             inverse_temperatures,
+            swap_strategy=MCMCTempering.NonReversibleSwap(),
             num_iterations=num_iterations,
-            swap_every=2,
             adapt=false,
             # At least 25% of swaps should be successful.
             mean_swap_rate_bound=0.25,
             compare_mean_swap_rate=≥,
             progress=false,
+            # Make sure we have _some_ roundtrips.
+            minimum_roundtrips=10,
         )
 
         # # Compare the chains.
@@ -302,8 +288,7 @@ end
 
     @testset "MvNormal 2D with different swap strategies" begin
         d = 2
-        num_iterations = 20_000
-        swap_every = 2
+        num_iterations = 5_000
 
         μ_true = [-5.0, 5.0]
         σ_true = [1.0, √(10.0)]
@@ -331,17 +316,19 @@ end
             MCMCTempering.RandomSwap()
         ]
 
-        @testset "$(swapstrategy)" for swapstrategy in swapstrategies
+        @testset "$(swap_strategy)" for swap_strategy in swapstrategies
             chain_tempered = test_and_sample_model(
                 model,
                 sampler,
                 inverse_temperatures,
                 num_iterations=num_iterations,
-                swap_every=swap_every,
-                swapstrategy=swapstrategy,
+                swap_strategy=swap_strategy,
                 adapt=false,
+                # Make sure we have _some_ roundtrips.
+                minimum_roundtrips=10,
             )
-            compare_chains(chain, chain_tempered, rtol=0.1, compare_std=false, compare_ess=true)
+
+            compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true)
         end
     end
 
@@ -368,15 +355,14 @@ end
         end
 
         @testset "AdvancedHMC.jl" begin
-            num_iterations = 2_000
+            num_iterations = 5_000
 
             # Set up HMC smpler.
             initial_ϵ = 0.1
             integrator = AdvancedHMC.Leapfrog(initial_ϵ)
             proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator)
             metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model))
-            adaptor = AdvancedHMC.StanHMCAdaptor(AdvancedHMC.MassMatrixAdaptor(metric), AdvancedHMC.StepSizeAdaptor(0.8, integrator))
-            sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, adaptor)
+            sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric)
 
             # Sample using HMC.
             samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false)
@@ -386,28 +372,47 @@ end
             )
             map_parameters!(b, chain_hmc)
 
+            # Make sure that we get the "same" result when only using the inverse temperature 1.
+            sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1])
+            chain_tempered = sample(
+                model, sampler_tempered, num_iterations;
+                init_params=copy(init_params),
+                chain_type=MCMCChains.Chains,
+                param_names=param_names,
+                progress=false,
+            )
+            map_parameters!(b, chain_tempered)
+            compare_chains(
+                chain_hmc, chain_tempered;
+                atol=0.2,
+                compare_ess=true,
+                isbroken=false
+            )
+
             # Sample using tempered HMC.
             chain_tempered = test_and_sample_model(
                 model,
                 sampler_hmc,
-                [1, 0.25, 0.1, 0.01],
+                [1, 0.75, 0.5, 0.25, 0.1, 0.01],
                 swap_strategy=MCMCTempering.ReversibleSwap(),
                 num_iterations=num_iterations,
-                swap_every=10,
                 adapt=false,
-                mean_swap_rate_bound=0,
+                mean_swap_rate_bound=0.1,
                 init_params=copy(init_params),
                 param_names=param_names,
                 progress=false
             )
             map_parameters!(b, chain_tempered)
-
-            # TODO: Make it not broken, i.e. produce reasonable results.
-            compare_chains(chain_hmc, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false)
+            compare_chains(
+                chain_hmc, chain_tempered;
+                atol=0.3,
+                compare_ess=true,
+                isbroken=false,
+            )
         end
-        
+
         @testset "AdvancedMH.jl" begin
-            num_iterations = 100_000
+            num_iterations = 10_000
             d = LogDensityProblems.dimension(model)
 
             # Set up MALA sampler.
@@ -415,33 +420,51 @@ end
             sampler_mh = MALA(∇ -> MvNormal(σ² * ∇, 2σ² * I))
 
             # Sample using MALA.
-            samples_mh = AbstractMCMC.sample(
+            chain_mh = AbstractMCMC.sample(
                 model, sampler_mh, num_iterations;
-                init_params=copy(init_params), progress=false
-            )
-            chain_mh = AbstractMCMC.bundle_samples(
-                samples_mh, MCMCTempering.maybe_wrap_model(model), sampler_mh, samples_mh[1], MCMCChains.Chains;
-                param_names=param_names
+                init_params=copy(init_params),
+                progress=false,
+                chain_type=MCMCChains.Chains,
+                param_names=param_names,
             )
             map_parameters!(b, chain_mh)
 
-            # Sample using tempered MALA.
+            # Make sure that we get the "same" result when only using the inverse temperature 1.
+            sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1])
+            chain_tempered = sample(
+                model, sampler_tempered, num_iterations;
+                init_params=copy(init_params),
+                chain_type=MCMCChains.Chains,
+                param_names=param_names,
+                progress=false,
+            )
+            map_parameters!(b, chain_tempered)
+            compare_chains(
+                chain_mh, chain_tempered;
+                atol=0.2,
+                compare_ess=true,
+                isbroken=false,
+            )
+
+            # Sample using actual tempering.
             chain_tempered = test_and_sample_model(
                 model,
                 sampler_mh,
                 [1, 0.9, 0.75, 0.5, 0.25, 0.1],
                 swap_strategy=MCMCTempering.ReversibleSwap(),
                 num_iterations=num_iterations,
-                swap_every=2,
                 adapt=false,
-                mean_swap_rate_bound=0,
+                mean_swap_rate_bound=0.1,
                 init_params=copy(init_params),
                 param_names=param_names
             )
             map_parameters!(b, chain_tempered)
             
             # Need a large atol as MH is not great on its own
-            compare_chains(chain_mh, chain_tempered, atol=0.4, compare_std=false, compare_ess=true, isbroken=false)
+            compare_chains(chain_mh, chain_tempered, atol=0.2, compare_ess=true, isbroken=false)
         end
     end
+
+    include("abstractmcmc.jl")
+    include("simple_gaussian.jl")
 end
diff --git a/test/setup.jl b/test/setup.jl
new file mode 100644
index 0000000..90195ca
--- /dev/null
+++ b/test/setup.jl
@@ -0,0 +1,22 @@
+using MCMCTempering
+using Test
+using Distributions
+using AdvancedMH
+using MCMCChains
+using Bijectors
+using LinearAlgebra
+using FillArrays
+using Setfield: Setfield
+using AbstractMCMC: AbstractMCMC, LogDensityModel
+using LogDensityProblems: LogDensityProblems, logdensity, logdensity_and_gradient, dimension
+using LogDensityProblemsAD
+using Random: Random
+using ForwardDiff: ForwardDiff
+using AdvancedMH: AdvancedMH
+using AdvancedHMC: AdvancedHMC
+using Turing: Turing, DynamicPPL
+
+
+include("utils.jl")
+include("test_utils.jl")
+include("compat.jl")
diff --git a/test/simple_gaussian.jl b/test/simple_gaussian.jl
new file mode 100644
index 0000000..1ff19b8
--- /dev/null
+++ b/test/simple_gaussian.jl
@@ -0,0 +1,74 @@
+@testset "Simple tempered Gaussian (closed form)" begin
+    μ = Zeros(1)
+    inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.8 .^ (0:10))
+    variances_true = inv.(inverse_temperatures)
+    std_true_dict = map(variances_true) do v
+        Dict(:param_1 => √v)
+    end
+    tempered_dists = [MvNormal(Zeros(1), I / β) for β in inverse_temperatures]
+    tempered_multimodel = MCMCTempering.MultiModel(map(LogDensityModel ∘ DistributionLogDensity, tempered_dists))
+
+    init_params = zeros(length(μ))
+
+    num_samples = 1_000
+    num_burnin = num_samples ÷ 2
+    thin = 10
+
+    # Samplers.
+    rwmh = RWMH(MvNormal(Zeros(1), I))
+    rwmh_tempered = TemperedSampler(rwmh, inverse_temperatures)
+    rwmh_product = MCMCTempering.MultiSampler(Fill(rwmh, length(tempered_dists)))
+    rwmh_product_with_swap = rwmh_product ∘ MCMCTempering.SwapSampler()
+
+    # Sample.
+    @testset "TemperedSampler" begin
+        chains_product = sample(
+            DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples;
+            init_params,
+            bundle_resolve_swaps=true,
+            chain_type=Vector{MCMCChains.Chains},
+            progress=false,
+            discard_initial=num_burnin,
+            thinning=thin,
+        )
+        test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
+    end
+
+    @testset "MultiSampler without swapping" begin
+        chains_product = sample(
+            tempered_multimodel, rwmh_product, num_samples;
+            init_params,
+            chain_type=Vector{MCMCChains.Chains},
+            progress=false,
+            discard_initial=num_burnin,
+            thinning=thin,
+        )
+        test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
+    end
+
+    @testset "MultiSampler with swapping (saveall=true)" begin
+        chains_product = sample(
+            tempered_multimodel, rwmh_product_with_swap, num_samples;
+            init_params,
+            bundle_resolve_swaps=true,
+            chain_type=Vector{MCMCChains.Chains},
+            progress=false,
+            discard_initial=num_burnin,
+            thinning=thin,
+        )
+        test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
+    end
+
+    @testset "MultiSampler with swapping (saveall=true)" begin
+        chains_product = sample(
+            tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples;
+            init_params,
+            chain_type=Vector{MCMCChains.Chains},
+            progress=false,
+            discard_initial=num_burnin,
+            thinning=thin,
+        )
+        test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
+    end
+end
+
diff --git a/test/test_utils.jl b/test/test_utils.jl
new file mode 100644
index 0000000..d343a8e
--- /dev/null
+++ b/test/test_utils.jl
@@ -0,0 +1,126 @@
+using MCMCDiagnosticTools, Statistics, DataFrames
+
+"""
+    to_dict(c::MCMCChains.Chains[, col::Symbol])
+
+Return a dictionary mapping parameter names to the values in column `col` of `c`.
+
+# Arguments
+- `c`: A `MCMCChains.Chains` object.
+- `col`: The column to extract values from. Defaults to the first column that is not `:parameters`.
+"""
+to_dict(c::MCMCChains.ChainDataFrame) = to_dict(c, first(filter(!=(:parameters), keys(c.nt))))
+function to_dict(c::MCMCChains.ChainDataFrame, col::Symbol)
+    df = DataFrame(c)
+    return Dict(sym => df[findfirst(==(sym), df[:, :parameters]), col] for sym in df.parameters)
+end
+
+"""
+    atol_for_chain(chain; significance=1e-3, kind=Statistics.mean)
+
+Return a dictionary of absolute tolerances for each parameter in `chain`, computed
+as the confidence interval width for the mean of the parameter with `significance`.
+"""
+function atol_for_chain(chain; significance=1e-3, kind=Statistics.mean)
+    param_names = names(chain, :parameters)
+    # Can reject H0 if, say, `abs(mean(chain2) - mean(chain1)) > confidence_width`.
+    # Or alternatively, compare means but with `atol` set to the `confidence_width`.
+    # NOTE: Failure to reject, i.e. passing the tests, does not imply that the means are equal.
+    mcse = to_dict(MCMCChains.mcse(chain; kind), :mcse)
+    return Dict(sym => quantile(Normal(0, mcse[sym]), 1 - significance/2) for sym in param_names)
+end
+
+thin_to(chain, n) = chain[1:length(chain) ÷ n:end]
+
+"""
+    test_means(chain, mean_true; kwargs...)
+
+Test that the mean of each parameter in `chain` is approximately `mean_true`.
+
+# Arguments
+- `chain`: A `MCMCChains.Chains` object.
+- `mean_true`: A `Real` or `AbstractDict` mapping parameter names to their true mean.
+- `kwargs...`: Passed to `atol_for_chain`.
+"""
+function test_means(chain::MCMCChains.Chains, mean_true::Real; kwargs...)
+    return test_means(chain, Dict(sym => mean_true for sym in names(chain, :parameters)); kwargs...)
+end
+function test_means(chain::MCMCChains.Chains, mean_true::AbstractDict; n=length(chain), kwargs...)
+    chain = thin_to(chain, n)
+    atol = atol_for_chain(chain; kwargs...)
+    @test all(isapprox(mean(chain[sym]), 0, atol=atol[sym]) for sym in names(chain, :parameters))
+end
+
+"""
+    test_std(chain, std_true; kwargs...)
+
+Test that the standard deviation of each parameter in `chain` is approximately `std_true`.
+
+# Arguments
+- `chain`: A `MCMCChains.Chains` object.
+- `std_true`: A `Real` or `AbstractDict` mapping parameter names to their true standard deviation.
+- `kwargs...`: Passed to `atol_for_chain`.
+"""
+function test_std(chain::MCMCChains.Chains, std_true::Real; kwargs...)
+    return test_std(chain, Dict(sym => std_true for sym in names(chain, :parameters)); kwargs...)
+end
+function test_std(chain::MCMCChains.Chains, std_true::AbstractDict; n=length(chain), kwargs...)
+    chain = thin_to(chain, n)
+    atol = atol_for_chain(chain; kind=Statistics.std, kwargs...)
+    @info "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)]
+    @test all(isapprox(std(chain[sym]), std_true[sym], atol=atol[sym]) for sym in names(chain, :parameters))
+end
+
+"""
+    test_std_monotonicity(chains; isbroken=false, kwargs...)
+
+Test that the standard deviation of each parameter in `chains` is monotonically increasing.
+
+# Arguments
+- `chains`: A vector of `MCMCChains.Chains` objects.
+- `isbroken`: If `true`, then the test will be marked as broken.
+- `kwargs...`: Passed to `atol_for_chain`.
+"""
+function test_std_monotonicity(chains::AbstractVector{<:MCMCChains.Chains}; isbroken::Bool=false, kwargs...)
+    param_names = names(first(chains), :parameters)
+    # We should technically use a Bonferroni-correction here, but whatever.
+    atols = [atol_for_chain(chain; kind=Statistics.std, kwargs...) for chain in chains]
+    stds = [Dict(sym => std(chain[sym]) for sym in param_names) for chain in chains]
+
+    num_chains = length(chains)
+    lbs = [Dict(sym => stds[i][sym] - atols[i][sym] for sym in param_names) for i in 1:num_chains]
+    ubs = [Dict(sym => stds[i][sym] + atols[i][sym] for sym in param_names) for i in 1:num_chains]
+
+    for i = 2:num_chains
+        for sym in param_names
+            # If the upper-bound of the current is smaller than the lower-bound of the previous, then
+            # we can reject the null hypothesis that they are orderd.
+            if isbroken
+                @test_broken ubs[i][sym] ≥ lbs[i - 1][sym]
+            else
+                @test ubs[i][sym] ≥ lbs[i - 1][sym]
+            end
+        end
+    end
+end
+
+"""
+    test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-3, kwargs...)
+
+Test that the mean and standard deviation of each parameter in `chains` is approximately `mean_true`
+and `std_true`, respectively. Also test that the standard deviation is monotonically increasing.
+
+# Arguments
+- `chains`: A vector of `MCMCChains.Chains` objects.
+- `mean_true`: A vector of `Real` or `AbstractDict` mapping parameter names to their true mean.
+- `std_true`: A vector of `Real` or `AbstractDict` mapping parameter names to their true standard deviation.
+- `significance`: The significance level of the test.
+- `kwargs...`: Passed to `atol_for_chain`.
+"""
+function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...)
+    @testset "chain $i" for i = 1:length(chains)
+        test_means(chains[i], mean_true[i]; kwargs...)
+        test_std(chains[i], std_true[i]; kwargs...)
+    end
+    test_std_monotonicity(chains; significance=0.05)
+end