Skip to content

Commit

Permalink
feat: support position-dependent kinetic (#369)
Browse files Browse the repository at this point in the history
* feat: support position-dependent kinetic

Signed-off-by: Kai Xu <[email protected]>

* Update src/hamiltonian.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/hamiltonian.jl

Co-authored-by: Hong Ge <[email protected]>

* Update src/hamiltonian.jl

Co-authored-by: Hong Ge <[email protected]>

* Update src/metric.jl

Co-authored-by: Hong Ge <[email protected]>

* Update src/metric.jl

Co-authored-by: Hong Ge <[email protected]>

* Update src/metric.jl

Co-authored-by: Hong Ge <[email protected]>

* fix: add position-independent methods back for leapfrog comptaibility

Signed-off-by: Kai Xu <[email protected]>

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/metric.jl

* Update src/metric.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/metric.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/metric.jl

Co-authored-by: Hong Ge <[email protected]>

---------

Signed-off-by: Kai Xu <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 10, 2024
1 parent 6111133 commit ecc388a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
16 changes: 13 additions & 3 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ end
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) =
h.metric.M⁻¹ * r

# TODO (kai) make the order of θ and r consistent with neg_energy
# TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic
# The gradient of a position-dependent Hamiltonian system depends on both θ and r.
∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ)
∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r)

struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue}
θ::T # Position variables / model parameters.
r::T # Momentum variables
Expand Down Expand Up @@ -156,7 +162,7 @@ phasepoint(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
θ::AbstractVecOrMat{T},
h::Hamiltonian,
) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic))
) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic, θ))

abstract type AbstractMomentumRefreshment end

Expand All @@ -168,7 +174,7 @@ refresh(
::FullMomentumRefreshment,
h::Hamiltonian,
z::PhasePoint,
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic))
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ))

"""
$(TYPEDEF)
Expand Down Expand Up @@ -196,4 +202,8 @@ refresh(
ref::PartialMomentumRefreshment,
h::Hamiltonian,
z::PhasePoint,
) = phasepoint(h, z.θ, ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic))
) = phasepoint(
h,
z.θ,
ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ),
)
18 changes: 17 additions & 1 deletion src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function _rand(
return r
end

# TODO The rand interface should be updated by rand from momentum distribution + optional affine transformation by metric
# TODO (kai) The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric"
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) =
_rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
Base.rand(
Expand All @@ -139,3 +139,19 @@ Base.rand(
) = _rand(rng, metric, kinetic)
Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) =
rand(GLOBAL_RNG, metric, kinetic)

# ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent)
Base.rand(
rng::AbstractRNG,
metric::AbstractMetric,
kinetic::AbstractKinetic,
θ::AbstractVecOrMat,
) = rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
Base.rand(
rng::AbstractVector{<:AbstractRNG},
metric::AbstractMetric,
kinetic::AbstractKinetic,
θ::AbstractVecOrMat,
) = rand(rng, metric, kinetic)
Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) =
rand(metric, kinetic)

0 comments on commit ecc388a

Please sign in to comment.