From 7d37e272fde91624c45a3de548b418fbfef6f490 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 4 Jul 2024 11:04:25 -0400 Subject: [PATCH 01/13] feat: support position-dependent kinetic Signed-off-by: Kai Xu --- src/hamiltonian.jl | 11 ++++++++--- src/metric.jl | 8 +++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index f2bbd230..da0be058 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -45,6 +45,11 @@ end ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = h.metric.M⁻¹ * r +# TODO make the order of θ and r consistent with neg_energy +# TODO add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic +∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) +∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) + struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue} θ::T # Position variables / model parameters. r::T # Momentum variables @@ -156,7 +161,7 @@ phasepoint( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, θ::AbstractVecOrMat{T}, h::Hamiltonian, -) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic)) +) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic, θ)) abstract type AbstractMomentumRefreshment end @@ -168,7 +173,7 @@ refresh( ::FullMomentumRefreshment, h::Hamiltonian, z::PhasePoint, -) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic)) +) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ)) """ $(TYPEDEF) @@ -196,4 +201,4 @@ refresh( ref::PartialMomentumRefreshment, h::Hamiltonian, z::PhasePoint, -) = phasepoint(h, z.θ, ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic)) +) = phasepoint(h, z.θ, ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ)) diff --git a/src/metric.jl b/src/metric.jl index f4585b62..a967e331 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -129,13 +129,15 @@ function _rand( return r end -# TODO The rand interface should be updated by rand from momentum distribution + optional affine transformation by metric -Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) = +# TODO The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric" +# ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent) +Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, _θ) = _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand( rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic::AbstractKinetic, + _θ, ) = _rand(rng, metric, kinetic) -Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) = +Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, _θ) = rand(GLOBAL_RNG, metric, kinetic) From 1c06c3a98fb1e320343de7b147e1d069ada4634b Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 5 Jul 2024 15:17:29 +0100 Subject: [PATCH 02/13] Update src/hamiltonian.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/hamiltonian.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index da0be058..21e62bc4 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -201,4 +201,8 @@ refresh( ref::PartialMomentumRefreshment, h::Hamiltonian, z::PhasePoint, -) = phasepoint(h, z.θ, ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ)) +) = phasepoint( + h, + z.θ, + ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ), +) From d0087a1f3568a7b1cbaf918fbd6cb86f4017491f Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 5 Jul 2024 16:15:44 -0400 Subject: [PATCH 03/13] Update src/hamiltonian.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/hamiltonian.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index 21e62bc4..8c7a8ec6 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -45,8 +45,8 @@ end ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = h.metric.M⁻¹ * r -# TODO make the order of θ and r consistent with neg_energy -# TODO add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic +# TODO (kai) make the order of θ and r consistent with neg_energy +# TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) ∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) From eb50e04bc39e5b3d28dd2c2c638b360e70386968 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 5 Jul 2024 16:16:01 -0400 Subject: [PATCH 04/13] Update src/hamiltonian.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/hamiltonian.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index 8c7a8ec6..b5546051 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -47,6 +47,7 @@ end # TODO (kai) make the order of θ and r consistent with neg_energy # TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic +# The gradient of a position-dependent Hamiltonian system depends on both θ and r. ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) ∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) From a4ade2954380af401215247350f539d0a623f05e Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 5 Jul 2024 16:16:44 -0400 Subject: [PATCH 05/13] Update src/metric.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index a967e331..dce1819b 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -131,7 +131,7 @@ end # TODO The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric" # ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent) -Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, _θ) = +Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, θ) = _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand( rng::AbstractVector{<:AbstractRNG}, From 1d75675784cef5da9e22df020a6955dd4bbdd8ba Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 5 Jul 2024 16:16:53 -0400 Subject: [PATCH 06/13] Update src/metric.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index dce1819b..a43cc0bb 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -137,7 +137,7 @@ Base.rand( rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic::AbstractKinetic, - _θ, + θ, ) = _rand(rng, metric, kinetic) Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, _θ) = rand(GLOBAL_RNG, metric, kinetic) From f4976db15d6eb362d8fed83de8bb5027620aac41 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 5 Jul 2024 16:16:58 -0400 Subject: [PATCH 07/13] Update src/metric.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index a43cc0bb..e80515b8 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -139,5 +139,5 @@ Base.rand( kinetic::AbstractKinetic, θ, ) = _rand(rng, metric, kinetic) -Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, _θ) = +Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ) = rand(GLOBAL_RNG, metric, kinetic) From e36258b2b14266999e4d5f89c8d06b9f251b8eaa Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 8 Jul 2024 10:26:19 -0400 Subject: [PATCH 08/13] fix: add position-independent methods back for leapfrog comptaibility Signed-off-by: Kai Xu --- src/metric.jl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/metric.jl b/src/metric.jl index e80515b8..f579fb2b 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -130,14 +130,24 @@ function _rand( end # TODO The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric" +Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) = + _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand +Base.rand( + rng::AbstractVector{<:AbstractRNG}, + metric::AbstractMetric, + kinetic::AbstractKinetic, +) = _rand(rng, metric, kinetic) +Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) = + rand(GLOBAL_RNG, metric, kinetic) + # ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent) Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, θ) = - _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand + rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand( rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic::AbstractKinetic, θ, -) = _rand(rng, metric, kinetic) +) = rand(rng, metric, kinetic) Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ) = - rand(GLOBAL_RNG, metric, kinetic) + rand(metric, kinetic) From ad2ad2ca5be6054939807e362f1af38bfaeaba67 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:17:57 +0100 Subject: [PATCH 09/13] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/metric.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/metric.jl b/src/metric.jl index f579fb2b..b6caee59 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -141,7 +141,7 @@ Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) = rand(GLOBAL_RNG, metric, kinetic) # ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent) -Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, θ) = +Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand( rng::AbstractVector{<:AbstractRNG}, @@ -149,5 +149,4 @@ Base.rand( kinetic::AbstractKinetic, θ, ) = rand(rng, metric, kinetic) -Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ) = - rand(metric, kinetic) +Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = rand(metric, kinetic) From 920b40979dc94f5f002154250582e938ec9bb5d3 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:18:32 +0100 Subject: [PATCH 10/13] Update src/metric.jl --- src/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index b6caee59..887f5d06 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -147,6 +147,6 @@ Base.rand( rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic::AbstractKinetic, - θ, + θ::AbstractVecOrMat, ) = rand(rng, metric, kinetic) Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = rand(metric, kinetic) From 73ce7313c8ff2314982e6d08ba18f9f7d8f53915 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:20:01 +0100 Subject: [PATCH 11/13] Update src/metric.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/metric.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index 887f5d06..8963e44e 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -149,4 +149,5 @@ Base.rand( kinetic::AbstractKinetic, θ::AbstractVecOrMat, ) = rand(rng, metric, kinetic) -Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = rand(metric, kinetic) +Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = + rand(metric, kinetic) From 3bc7efe48f03a292505a9efc56521352e62ca82c Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:20:07 +0100 Subject: [PATCH 12/13] Update src/metric.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/metric.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/metric.jl b/src/metric.jl index 8963e44e..8521eec9 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -141,8 +141,12 @@ Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) = rand(GLOBAL_RNG, metric, kinetic) # ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent) -Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) = - rand(rng, metric, kinetic) # this disambiguity is required by Random.rand +Base.rand( + rng::AbstractRNG, + metric::AbstractMetric, + kinetic::AbstractKinetic, + θ::AbstractVecOrMat, +) = rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand( rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, From 4f7d44eb9b0066a11fc48a9e1adfaae15d84d7ae Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Jul 2024 08:53:41 -0400 Subject: [PATCH 13/13] Update src/metric.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index 8521eec9..2afbd629 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -129,7 +129,7 @@ function _rand( return r end -# TODO The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric" +# TODO (kai) The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric" Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) = _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand Base.rand(