Skip to content

Commit

Permalink
feat: port relativistic hmc from research repo
Browse files Browse the repository at this point in the history
Signed-off-by: Kai Xu <[email protected]>
  • Loading branch information
xukai92 committed Jul 25, 2024
1 parent 2b3814c commit 7fc04d2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 87 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.6.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdaptiveRejectionSampling = "c75e803d-635f-53bd-ab7d-544e482d8c75"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
Expand Down Expand Up @@ -31,23 +32,24 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[compat]
AbstractMCMC = "4.2, 5"
AdaptiveRejectionSampling = "0.1.1"
ArgCheck = "1, 2"
CUDA = "3, 4, 5"
DocStringExtensions = "0.8, 0.9"
InplaceOps = "0.3"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
MCMCChains = "5, 6"
OrdinaryDiffEq = "6"
ProgressMeter = "1"
Random = "1.6"
Requires = "0.5, 1"
Setfield = "0.7, 0.8, 1"
SimpleUnPack = "1.1"
Statistics = "1.6"
StatsBase = "0.31, 0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
LinearAlgebra = "1.6"
Random = "1.6"
julia = "1.6"

[extras]
Expand Down
83 changes: 0 additions & 83 deletions research/src/relativistic_hmc.jl

This file was deleted.

23 changes: 21 additions & 2 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ export Hamiltonian

include("integrator.jl")
export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("trajectory.jl")
export Trajectory,
Expand Down Expand Up @@ -128,6 +126,27 @@ export sample
include("constructors.jl")
export HMCSampler, HMC, NUTS, HMCDA

module Experimental
using Random, Statistics, LinearAlgebra
using ..AdvancedHMC

import ..AdvancedHMC: ∂H∂r, neg_energy, AbstractKinetic
import Random: AbstractRNG
include("relativistic/hamiltonian.jl")
export RelativisticKinetic, DimensionwiseRelativisticKinetic

using AdaptiveRejectionSampling: RejectionSampler, run_sampler!
import ..AdvancedHMC: _rand
include("relativistic/metric.jl")

using ..AdvancedHMC: @unpack, TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step, step_size
import ..AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step
include("riemannian/integrator.jl")
include("riemannian/hamiltonian.jl")
include("riemannian/metric.jl")
export GeneralizedLeapfrog
end

include("abstractmcmc.jl")

## Without explicit AD backend
Expand Down

0 comments on commit 7fc04d2

Please sign in to comment.