diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index f2bbd230..b5546051 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -45,6 +45,12 @@ end ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = h.metric.M⁻¹ * r +# TODO (kai) make the order of θ and r consistent with neg_energy +# TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic +# The gradient of a position-dependent Hamiltonian system depends on both θ and r. +∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) +∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) + struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue} θ::T # Position variables / model parameters. r::T # Momentum variables @@ -156,7 +162,7 @@ phasepoint( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, θ::AbstractVecOrMat{T}, h::Hamiltonian, -) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic)) +) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic, θ)) abstract type AbstractMomentumRefreshment end @@ -168,7 +174,7 @@ refresh( ::FullMomentumRefreshment, h::Hamiltonian, z::PhasePoint, -) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic)) +) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ)) """ $(TYPEDEF) @@ -196,4 +202,8 @@ refresh( ref::PartialMomentumRefreshment, h::Hamiltonian, z::PhasePoint, -) = phasepoint(h, z.θ, ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic)) +) = phasepoint( + h, + z.θ, + ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ), +) diff --git a/src/metric.jl b/src/metric.jl index f4585b62..2afbd629 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -129,7 +129,7 @@ function _rand( return r end -# TODO The rand interface should be updated by rand from momentum distribution + optional affine transformation by metric +# TODO (kai) The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric" Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) = _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand( @@ -139,3 +139,19 @@ Base.rand( ) = _rand(rng, metric, kinetic) Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) = rand(GLOBAL_RNG, metric, kinetic) + +# ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent) +Base.rand( + rng::AbstractRNG, + metric::AbstractMetric, + kinetic::AbstractKinetic, + θ::AbstractVecOrMat, +) = rand(rng, metric, kinetic) # this disambiguity is required by Random.rand +Base.rand( + rng::AbstractVector{<:AbstractRNG}, + metric::AbstractMetric, + kinetic::AbstractKinetic, + θ::AbstractVecOrMat, +) = rand(rng, metric, kinetic) +Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = + rand(metric, kinetic)