Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalized leapfrog integrator #370

Merged
merged 6 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ export Hamiltonian

include("integrator.jl")
export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("trajectory.jl")
export Trajectory,
Expand Down
93 changes: 93 additions & 0 deletions src/riemannian/integrator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
$(TYPEDEF)

Generalized leapfrog integrator with fixed step size `ϵ`.

# Fields

$(TYPEDFIELDS)
"""
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
"Step size."
ϵ::T
n::Int
end
Base.show(io::IO, l::GeneralizedLeapfrog) =
print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))")

# fallback to ignore return_cache & cache kwargs for other ∂H∂θ
function ∂H∂θ_cache(h, θ, r; return_cache = false, cache = nothing)
dv = ∂H∂θ(h, θ, r)
return return_cache ? (dv, nothing) : dv
end

# TODO(Kai) make sure vectorization works
Copy link
Member

@yebai yebai Jul 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we throw an error when users try vectorisation and tempering with GeneralisedLeapFrog?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i added an error handling for vectorization.
there is no way to do tempering until we define TemperedGeneralizedLeapfrog.

# TODO(Kai) check if tempering is valid
# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl`
function step(
lf::GeneralizedLeapfrog{T},
h::Hamiltonian,
z::P,
n_steps::Int = 1;
fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj} = Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
n_steps = abs(n_steps) # to support `n_steps < 0` cases

ϵ = fwd ? step_size(lf) : -step_size(lf)
ϵ = ϵ'
xukai92 marked this conversation as resolved.
Show resolved Hide resolved

res = if FullTraj
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
Vector{P}(undef, n_steps)
else
z
end

for i = 1:n_steps
θ_init, r_init = z.θ, z.r
# Tempering
#r = temper(lf, r, (i=i, is_half=true), n_steps)
# eq (16) of Girolami & Calderhead (2011)
r_half = r_init
local cache
for j = 1:lf.n
# Reuse cache for the first iteration
if j == 1
@unpack value, gradient = z.ℓπ
elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged)
retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache = true)
@unpack value, gradient = retval
else # reuse cache
@unpack value, gradient = ∂H∂θ_cache(h, θ_init, r_half; cache = cache)
end
r_half = r_init - ϵ / 2 * gradient
end
# eq (17) of Girolami & Calderhead (2011)
θ_full = θ_init
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
for j = 1:lf.n
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
end
# eq (18) of Girolami & Calderhead (2011)
@unpack value, gradient = ∂H∂θ(h, θ_full, r_half)
r_full = r_half - ϵ / 2 * gradient
# Tempering
#r = temper(lf, r, (i=i, is_half=false), n_steps)
# Create a new phase point by caching the logdensity and gradient
z = phasepoint(h, θ_full, r_full; ℓπ = DualValue(value, gradient))
# Update result
if FullTraj
res[i] = z
else
res = z
end
if !isfinite(z)
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
# Remove undef
if FullTraj
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
end
return res
end
Loading