From 2cb2d9e770697de1a3a12a8bd58c477068dc5107 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Jul 2024 15:07:16 -0400 Subject: [PATCH] feat: port generalized leapfrog Signed-off-by: Kai Xu --- src/AdvancedHMC.jl | 2 + src/riemannian/integrator.jl | 93 ++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 src/riemannian/integrator.jl diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 42d52767..fcaab095 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -52,6 +52,8 @@ export Hamiltonian include("integrator.jl") export Leapfrog, JitteredLeapfrog, TemperedLeapfrog +include("riemannian/integrator.jl") +export GeneralizedLeapfrog include("trajectory.jl") export Trajectory, diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl new file mode 100644 index 00000000..718193e8 --- /dev/null +++ b/src/riemannian/integrator.jl @@ -0,0 +1,93 @@ +""" +$(TYPEDEF) + +Generalized leapfrog integrator with fixed step size `ϵ`. + +# Fields + +$(TYPEDFIELDS) +""" +struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} + "Step size." + ϵ::T + n::Int +end +Base.show(io::IO, l::GeneralizedLeapfrog) = + print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") + +# fallback to ignore return_cache & cache kwargs for other ∂H∂θ +function ∂H∂θ_cache(h, θ, r; return_cache = false, cache = nothing) + dv = ∂H∂θ(h, θ, r) + return return_cache ? (dv, nothing) : dv +end + +# TODO(Kai) make sure vectorization works +# TODO(Kai) check if tempering is valid +# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` +function step( + lf::GeneralizedLeapfrog{T}, + h::Hamiltonian, + z::P, + n_steps::Int = 1; + fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0 + full_trajectory::Val{FullTraj} = Val(false), +) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} + n_steps = abs(n_steps) # to support `n_steps < 0` cases + + ϵ = fwd ? step_size(lf) : -step_size(lf) + ϵ = ϵ' + + res = if FullTraj + Vector{P}(undef, n_steps) + else + z + end + + for i = 1:n_steps + θ_init, r_init = z.θ, z.r + # Tempering + #r = temper(lf, r, (i=i, is_half=true), n_steps) + # eq (16) of Girolami & Calderhead (2011) + r_half = copy(r_init) + local cache + for j = 1:lf.n + # Reuse cache for the first iteration + if j == 1 + @unpack value, gradient = z.ℓπ + elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) + retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache = true) + @unpack value, gradient = retval + else # reuse cache + @unpack value, gradient = ∂H∂θ_cache(h, θ_init, r_half; cache = cache) + end + r_half = r_init - ϵ / 2 * gradient + end + # eq (17) of Girolami & Calderhead (2011) + θ_full = copy(θ_init) + term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop + for j = 1:lf.n + θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) + end + # eq (18) of Girolami & Calderhead (2011) + @unpack value, gradient = ∂H∂θ(h, θ_full, r_half) + r_full = r_half - ϵ / 2 * gradient + # Tempering + #r = temper(lf, r, (i=i, is_half=false), n_steps) + # Create a new phase point by caching the logdensity and gradient + z = phasepoint(h, θ_full, r_full; ℓπ = DualValue(value, gradient)) + # Update result + if FullTraj + res[i] = z + else + res = z + end + if !isfinite(z) + # Remove undef + if FullTraj + res = res[isassigned.(Ref(res), 1:n_steps)] + end + break + end + end + return res +end \ No newline at end of file