Skip to content

Commit

Permalink
Deepcopy adaptor before starting sampling
Browse files Browse the repository at this point in the history
This avoids the unintuitive behaviour seen in #379
  • Loading branch information
penelopeysm committed Nov 5, 2024
1 parent a6f0621 commit 2a5731c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
36 changes: 36 additions & 0 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,42 @@ function sample(
verbose::Bool = true,
progress::Bool = false,
(pm_next!)::Function = pm_next!,
) where {T<:AbstractVecOrMat{<:AbstractFloat}}
# Prevent adaptor from being mutated
adaptor = deepcopy(adaptor)
# Then call sample_mutating_adaptor with the same arguments
return sample_mutating_adaptor(
rng,
h,
κ,
θ,
n_samples,
adaptor,
n_adapts;
drop_warmup = drop_warmup,
verbose = verbose,
progress = progress,
(pm_next!) = pm_next!,
)
end

"""
sample_mutating_adaptor(args...; kwargs...)
The same as `sample`, but mutates the `adaptor` argument.
"""
function sample_mutating_adaptor(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
κ::HMCKernel,
θ::T,
n_samples::Int,
adaptor::AbstractAdaptor = NoAdaptation(),
n_adapts::Int = min(div(n_samples, 10), 1_000);
drop_warmup = false,
verbose::Bool = true,
progress::Bool = false,
(pm_next!)::Function = pm_next!,
) where {T<:AbstractVecOrMat{<:AbstractFloat}}
@assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase."
# Prepare containers to store sampling results
Expand Down
3 changes: 2 additions & 1 deletion test/adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function runnuts(ℓπ, metric; n_samples = 3_000)
integrator = AdvancedHMC.make_integrator(nuts, step_size)
κ = AdvancedHMC.make_kernel(nuts, integrator)
adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator)
samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose = false)
# Use mutating version of sample() here
samples, stats = AdvancedHMC.sample_mutating_adaptor(rng, h, κ, θ_init, n_samples, adaptor, n_adapts; verbose = false)
return (samples = samples, stats = stats, adaptor = adaptor)
end

Expand Down
30 changes: 30 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ end
n_steps = 10
n_samples = 22_000
n_adapts = 4_000

@testset "$metricsym" for (metricsym, metric) in Dict(
:UnitEuclideanMetric => UnitEuclideanMetric(D),
:DiagEuclideanMetric => DiagEuclideanMetric(D),
Expand Down Expand Up @@ -157,6 +158,7 @@ end
end
end
end

@testset "drop_warmup" begin
nuts = NUTS(0.8)
metric = DiagEuclideanMetric(D)
Expand Down Expand Up @@ -191,4 +193,32 @@ end
@test length(samples) == n_samples
@test length(stats) == n_samples
end

@testset "reproducibility" begin
# Multiple calls to sample() should yield the same results
nuts = NUTS(0.8)
metric = DiagEuclideanMetric(D)
h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
integrator = Leapfrog(ϵ)
κ = AdvancedHMC.make_kernel(nuts, integrator)
adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator)

all_samples = []
for i = 1:5
samples, stats = sample(
Random.MersenneTwister(42),
h,
κ,
θ_init,
100, # n_samples -- don't need so many
adaptor,
50, # n_adapts -- likewise
verbose = false,
progress = false,
drop_warmup = true,
)
push!(all_samples, samples)
end
@test all(map(s -> s all_samples[1], all_samples[2:end]))
end
end

0 comments on commit 2a5731c

Please sign in to comment.