Skip to content

Commit

Permalink
feat: port generalized leapfrog
Browse files Browse the repository at this point in the history
Signed-off-by: Kai Xu <[email protected]>
  • Loading branch information
xukai92 committed Jul 10, 2024
1 parent ecc388a commit 2cb2d9e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ export Hamiltonian

include("integrator.jl")
export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("trajectory.jl")
export Trajectory,
Expand Down
93 changes: 93 additions & 0 deletions src/riemannian/integrator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
$(TYPEDEF)
Generalized leapfrog integrator with fixed step size `ϵ`.
# Fields
$(TYPEDFIELDS)
"""
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
"Step size."
ϵ::T
n::Int
end
Base.show(io::IO, l::GeneralizedLeapfrog) =
print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))")

# fallback to ignore return_cache & cache kwargs for other ∂H∂θ
function ∂H∂θ_cache(h, θ, r; return_cache = false, cache = nothing)
dv = ∂H∂θ(h, θ, r)
return return_cache ? (dv, nothing) : dv
end

# TODO(Kai) make sure vectorization works
# TODO(Kai) check if tempering is valid
# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl`
function step(
lf::GeneralizedLeapfrog{T},
h::Hamiltonian,
z::P,
n_steps::Int = 1;
fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj} = Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
n_steps = abs(n_steps) # to support `n_steps < 0` cases

ϵ = fwd ? step_size(lf) : -step_size(lf)
ϵ = ϵ'

res = if FullTraj
Vector{P}(undef, n_steps)
else
z
end

for i = 1:n_steps
θ_init, r_init = z.θ, z.r
# Tempering
#r = temper(lf, r, (i=i, is_half=true), n_steps)
# eq (16) of Girolami & Calderhead (2011)
r_half = copy(r_init)
local cache
for j = 1:lf.n
# Reuse cache for the first iteration
if j == 1
@unpack value, gradient = z.ℓπ
elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged)
retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache = true)
@unpack value, gradient = retval
else # reuse cache
@unpack value, gradient = ∂H∂θ_cache(h, θ_init, r_half; cache = cache)
end
r_half = r_init - ϵ / 2 * gradient
end
# eq (17) of Girolami & Calderhead (2011)
θ_full = copy(θ_init)
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
for j = 1:lf.n
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
end
# eq (18) of Girolami & Calderhead (2011)
@unpack value, gradient = ∂H∂θ(h, θ_full, r_half)
r_full = r_half - ϵ / 2 * gradient
# Tempering
#r = temper(lf, r, (i=i, is_half=false), n_steps)
# Create a new phase point by caching the logdensity and gradient
z = phasepoint(h, θ_full, r_full; ℓπ = DualValue(value, gradient))
# Update result
if FullTraj
res[i] = z
else
res = z
end
if !isfinite(z)
# Remove undef
if FullTraj
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
end
return res
end

0 comments on commit 2cb2d9e

Please sign in to comment.