Skip to content

Commit

Permalink
Adopt AbstractMCMC.jl interface (#259)
Browse files Browse the repository at this point in the history
* initial work on adopting the AbstractMCMC interface

* removed old sample method since is now redundant

* AbstractMCMCKernel is now a subtype of AbstractSampler

* fixed impl and added back old one

* overload AbstractMCMC.sample to make interface a bit nicer

* reverted renaming of abstract kernel type

* replicate logging behavior of current AHMC when using AbstractMCMC

* initial work on tests, but require update to testing deps

* moved abstractmcmc interface into separate file

* version bump

* only bump patch version to see if tests succeed

* move away from using extras in Project.toml

* added integration tests for Turing.jl

* removed usage of Turing.jl and MCMCDebugging.jl in main testsuite

* fixed bug in deprecated HMCDA constructor

* allow specification of which testing suites to run

* added Turing.jl integration tests to CI

* fixed name for integration tests

* added using AdvancedHMC in runtests.jl

* removed some now unnecessary usings

* fixed a bug in the downstream testing

* give integration tests a separate CI

* forgot to remove the continue-on-error from CI

* added convenient constructor for DifferentiableDensityModel using Hamiltonians defaults

* fixed tests for AbstractMCMC interface

* added a bunch of docstrings

* bumped minor version

* increased number of samples used in abstractmcmc tests

* remove thinning from tests

* make initial Leapfrog step size smaller

* mistakenly removed AbstractMCMC as a test dep in previous commit

* increase adaptation to see if it helps

* ensure we drop the adaptation samples in the test

* made a mistake apparently

* think I finally fixed the tests

* disable progress in test
  • Loading branch information
torfjelde authored Jul 15, 2021
1 parent 564af93 commit 7cad9f0
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 6 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.2.28"
version = "0.3.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
Expand All @@ -17,6 +18,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractMCMC = "3"
ArgCheck = "1, 2"
DocStringExtensions = "0.8"
InplaceOps = "0.3"
Expand Down
6 changes: 6 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const DEBUG = convert(Bool, parse(Int, get(ENV, "DEBUG_AHMC", "0")))
using Statistics: mean, var, middle
using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
using StatsFuns: logaddexp, logsumexp
import Random
using Random: GLOBAL_RNG, AbstractRNG
using ProgressMeter: ProgressMeter
using UnPack: @unpack
Expand All @@ -16,6 +17,8 @@ using ArgCheck: @argcheck

using DocStringExtensions

import AbstractMCMC

import StatsBase: sample

include("utilities.jl")
Expand Down Expand Up @@ -128,6 +131,9 @@ include("diagnosis.jl")
include("sampler.jl")
export sample

include("abstractmcmc.jl")
export DifferentiableDensityModel

include("contrib/ad.jl")

### Init
Expand Down
293 changes: 293 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
"""
HMCSampler
A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl.
# Fields
$(FIELDS)
# Notes
Note that all the fields have the prefix `initial_` to indicate
that these will not necessarily correspond to the `kernel`, `metric`,
and `adaptor` after sampling.
To access the updated fields use the resulting [`HMCState`](@ref).
"""
struct HMCSampler{K, M, A} <: AbstractMCMC.AbstractSampler
"Initial [`AbstractMCMCKernel`](@ref)."
initial_kernel::K
"Initial [`AbstractMetric`](@ref)."
initial_metric::M
"Initial [`AbstractAdaptor`](@ref)."
initial_adaptor::A
end
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation())

"""
DifferentiableDensityModel(ℓπ, ∂ℓπ∂θ)
DifferentiableDensityModel(ℓπ, m::Module)
A `AbstractMCMC.AbstractMCMCModel` representing a differentiable log-density.
If a module `m` is given as the second argument, then `m` is assumed to be an
automatic-differentiation package and this will be used to compute the gradients.
Note that the module `m` must be imported before usage, e.g.
```julia
using Zygote: Zygote
model = DifferentiableDensityModel(ℓπ, Zygote)
```
results in a `model` which will use Zygote.jl as its AD-backend.
# Fields
$(FIELDS)
"""
struct DifferentiableDensityModel{Tlogπ, T∂logπ∂θ} <: AbstractMCMC.AbstractModel
"Log-density. Maps `AbstractArray` to value of the log-density."
ℓπ::Tlogπ
"Gradient of log-density. Returns a tuple of `ℓπ` and the gradient evaluated at the given point."
∂ℓπ∂θ::T∂logπ∂θ
end

struct DummyMetric <: AbstractMetric end
function DifferentiableDensityModel(ℓπ, m::Module)
h = Hamiltonian(DummyMetric(), ℓπ, m)
return DifferentiableDensityModel(h.ℓπ, h.∂ℓπ∂θ)
end

"""
HMCState
Represents the state of a [`HMCSampler`](@ref).
# Fields
$(FIELDS)
"""
struct HMCState{
TTrans<:Transition,
TMetric<:AbstractMetric,
TKernel<:AbstractMCMCKernel,
TAdapt<:Adaptation.AbstractAdaptor
}
"Index of current iteration."
i::Int
"Current [`Transition`](@ref)."
transition::TTrans
"Current [`AbstractMetric`](@ref), possibly adapted."
metric::TMetric
"Current [`AbstractMCMCKernel`](@ref)."
κ::TKernel
"Current [`AbstractAdaptor`](@ref)."
adaptor::TAdapt
end

"""
$(TYPEDSIGNATURES)
A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref).
"""
function AbstractMCMC.sample(
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
N::Integer;
kwargs...
)
return AbstractMCMC.sample(Random.GLOBAL_RNG, model, kernel, metric, adaptor, N; kwargs...)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
N::Integer;
progress = true,
verbose = false,
callback = nothing,
kwargs...
)
sampler = HMCSampler(kernel, metric, adaptor)
if callback === nothing
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
progress = false # don't use AMCMC's progress-funtionality
end

return AbstractMCMC.mcmcsample(
rng, model, sampler, N;
progress = progress,
verbose = verbose,
callback = callback,
kwargs...
)
end

function AbstractMCMC.sample(
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
parallel::AbstractMCMC.AbstractMCMCParallel,
N::Integer,
nchains::Integer;
kwargs...
)
return AbstractMCMC.sample(
Random.GLOBAL_RNG, model, kernel, metric, adaptor, N, nchains;
kwargs...
)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
parallel::AbstractMCMC.AbstractMCMCParallel,
N::Integer,
nchains::Integer;
progress = true,
verbose = false,
callback = nothing,
kwargs...
)
sampler = HMCSampler(kernel, metric, adaptor)
if callback === nothing
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
progress = false # don't use AMCMC's progress-funtionality
end

return AbstractMCMC.mcmcsample(
rng, model, sampler, parallel, N, nchains;
progress = progress,
verbose = verbose,
callback = callback,
kwargs...
)
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
spl::HMCSampler;
init_params = nothing,
kwargs...
)
metric = spl.initial_metric
κ = spl.initial_kernel
adaptor = spl.initial_adaptor

if init_params === nothing
init_params = randn(size(metric, 1))
end

# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)

# Get an initial sample.
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params)

# Compute next transition and state.
state = HMCState(0, t, h.metric, κ, adaptor)

# Take actual first step.
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
spl::HMCSampler,
state::HMCState;
nadapts::Int = 0,
kwargs...
)
# Get step size
@debug "current ϵ" getstepsize(spl, state)

# Compute transition.
i = state.i + 1
t_old = state.transition
adaptor = state.adaptor
κ = state.κ
metric = state.metric

# Reconstruct hamiltonian.
h = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)

# Make new transition.
t = transition(rng, h, κ, t_old.z)

# Adapt h and spl.
tstat = stat(t)
h, κ, isadapted = adapt!(h, κ, adaptor, i, nadapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))

# Compute next transition and state.
newstate = HMCState(i, t, h.metric, κ, adaptor)

# Return `Transition` with additional stats added.
return Transition(t.z, tstat), newstate
end


################
### Callback ###
################
"""
HMCProgressCallback
A callback to be used with AbstractMCMC.jl's interface, replicating the
logging behavior of the non-AbstractMCMC [`sample`](@ref).
# Fields
$(FIELDS)
"""
struct HMCProgressCallback{P}
"`Progress` meter from ProgressMeters.jl."
pm::P
"Specifies whether or not to use display a progress bar."
progress::Bool
"If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation."
verbose::Bool
end

function HMCProgressCallback(n_samples; progress=true, verbose=false)
pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing
HMCProgressCallback(pm, progress, verbose)
end

function (cb::HMCProgressCallback)(
rng, model, spl, t, state, i;
nadapts = 0,
kwargs...
)
progress = cb.progress
verbose = cb.verbose
pm = cb.pm

metric = state.metric
adaptor = state.adaptor
κ = state.κ
tstat = t.stat
isadapted = tstat.is_adapt

# Update progress meter
if progress
# Do include current iteration and mass matrix
pm_next!(
pm,
(iterations=i, tstat..., mass_matrix=metric)
)
# Report finish of adapation
elseif verbose && isadapted && i == nadapts
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric
end
end
4 changes: 2 additions & 2 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat}, V<:DualValue}
@warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ))
# NOTE eltype has to be inlined to avoid type stability issue; see #267
ℓπ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
ℓπ.gradient
)
ℓκ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value),
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value),
ℓκ.gradient
)
end
Expand Down
2 changes: 0 additions & 2 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ end
##
## Interface functions
##

function sample_init(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
Expand Down Expand Up @@ -143,7 +142,6 @@ sample(
verbose::Bool=true,
progress::Bool=false
)
Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
- The randomness is controlled by `rng`.
- If `rng` is not provided, `GLOBAL_RNG` will be used.
Expand Down
2 changes: 1 addition & 1 deletion src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ nsteps(τ::Trajectory{TS, I, TC}) where {TS, I, TC<:FixedIntegrationTime} =
## Kernel interface
##

struct HMCKernel{R, T<:Trajectory} <: AbstractMCMCKernel
struct HMCKernel{R, T<:Trajectory} <: AbstractMCMCKernel
refreshment::R
τ::T
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand Down
Loading

2 comments on commit 7cad9f0

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/40946

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 7cad9f02e6eb4095ad8b2e7ef06d3474dfb19135
git push origin v0.3.0

Please sign in to comment.