Skip to content

Commit

Permalink
DynamicHMC (#2045)
Browse files Browse the repository at this point in the history
* first draft

* bug

* DynamicNUTS

* working

* bring back all tests

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

* extension

* extension

* Ext in name + __init__

* different name

* end module

* compiles but still doesnt find the methods

* Compat + no __init__

* working

* better imports inside Ext

* Apply suggestions from code review

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>

* DocStringExtensions

* DocExtensions

* bring back previous interface

* bring back previous interface

* bring back previous interface --> working

* missing end

* remove gibbs support

* Update Project.toml

Co-authored-by: David Widmann <[email protected]>

* Update test/contrib/inference/dynamichmc.jl

Co-authored-by: David Widmann <[email protected]>

* rename

* type for DynamicNUTS

* rename

* Parametric sampler

* bug

* Apply suggestions from code review

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* fixing tests

* Update Project.toml

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
4 people authored Jul 29, 2023
1 parent fa3a6c1 commit 4d41959
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 53 deletions.
13 changes: 9 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.27.1"
version = "0.28"


[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -48,6 +49,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.23"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Expand All @@ -68,11 +70,14 @@ StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
julia = "1.7"

[weakdeps]
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[extensions]
TuringDynamicHMCExt = "DynamicHMC"
TuringOptimExt = "Optim"

[extras]
DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[weakdeps]
DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
64 changes: 36 additions & 28 deletions src/contrib/inference/dynamichmc.jl → ext/TuringDynamicHMCExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
module TuringDynamicHMCExt
###
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
###


if isdefined(Base, :get_extension)
import DynamicHMC
using Turing
using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
else
import ..DynamicHMC
using ..Turing
using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using ..Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
end

"""
DynamicNUTS
Expand All @@ -12,10 +26,15 @@ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
using DynamicHMC
```
"""
struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end
struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian{AD}
sampler::T
end

DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()
DynamicNUTS(args...) = DynamicNUTS{Turing.ADBackend()}(args...)
DynamicNUTS{AD}(spl::DynamicHMC.NUTS, space::Tuple) where AD = DynamicNUTS{AD, space, typeof(spl)}(spl)
DynamicNUTS{AD}(spl::DynamicHMC.NUTS) where AD = DynamicNUTS{AD}(spl, ())
DynamicNUTS{AD}() where AD = DynamicNUTS{AD}(DynamicHMC.NUTS())
Turing.externalsampler(spl::DynamicHMC.NUTS) = DynamicNUTS(spl)

DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

Expand All @@ -27,7 +46,7 @@ State of the [`DynamicNUTS`](@ref) sampler.
# Fields
$(TYPEDFIELDS)
"""
struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S}
struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
logdensity::L
vi::V
"Cache of sample, log density, and gradient of log density evaluation."
Expand All @@ -36,26 +55,13 @@ struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S}
stepsize::S
end

# Implement interface of `Gibbs` sampler
function gibbs_state(
model::Model,
spl::Sampler{<:DynamicNUTS},
state::DynamicNUTSState,
varinfo::AbstractVarInfo,
)
# Update the log density function and its cached evaluation.
= LogDensityProblemsAD.ADgradient(Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext()))
Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize)
end

DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS}) = DynamicPPL.SampleFromUniform()

function DynamicPPL.initialstep(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:DynamicNUTS},
vi::AbstractVarInfo;
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
vi::DynamicPPL.AbstractVarInfo;
kwargs...
)
# Ensure that initial sample is in unconstrained space.
Expand Down Expand Up @@ -83,16 +89,16 @@ function DynamicPPL.initialstep(
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create first sample and state.
sample = Transition(vi)
sample = Turing.Inference.Transition(vi)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

return sample, state
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:DynamicNUTS},
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
state::DynamicNUTSState;
kwargs...
)
Expand All @@ -101,7 +107,7 @@ function AbstractMCMC.step(
= state.logdensity
steps = DynamicHMC.mcmc_steps(
rng,
DynamicHMC.NUTS(),
spl.alg.sampler,
state.metric,
ℓ,
state.stepsize,
Expand All @@ -113,8 +119,10 @@ function AbstractMCMC.step(
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create next sample and state.
sample = Transition(vi)
sample = Turing.Inference.Transition(vi)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

return sample, newstate
end

end
18 changes: 6 additions & 12 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,21 +135,15 @@ export @model, # modelling
optim_function,
optim_problem

if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
@require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("../ext/TuringOptimExt.jl")
end
@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
@eval Inference begin
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
end
end
end
@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" include("../ext/TuringDynamicHMCExt.jl")
end
end

end
13 changes: 4 additions & 9 deletions test/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
@stage_testset "dynamichmc" "dynamichmc.jl" begin
@testset "TuringDynamicHMCExt" begin
import DynamicHMC
Random.seed!(100)

@test DynamicPPL.alg_str(Sampler(DynamicNUTS(), gdemo_default)) == "DynamicNUTS"
@test DynamicPPL.alg_str(Sampler(externalsampler(DynamicHMC.NUTS()))) == "DynamicNUTS"

chn = sample(gdemo_default, DynamicNUTS(), 10_000)
spl = externalsampler(DynamicHMC.NUTS())
chn = sample(gdemo_default, spl, 10_000)
check_gdemo(chn)

chn2 = sample(gdemo_default, Gibbs(PG(15, :s), DynamicNUTS(:m)), 10_000)
check_gdemo(chn2)

chn3 = sample(gdemo_default, Gibbs(DynamicNUTS(:s), ESS(:m)), 10_000)
check_gdemo(chn3)
end

2 comments on commit 4d41959

@yebai
Copy link
Member

@yebai yebai commented on 4d41959 Jul 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88639

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.28.0 -m "<description of version>" 4d41959fb95b6f9417cb1474c11dd4d66298f963
git push origin v0.28.0

Please sign in to comment.