diff --git a/Project.toml b/Project.toml
index ae21f7286..2e1b5e3d3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,7 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.27.1"
+version = "0.28.0"
+
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -48,6 +49,7 @@ DataStructures = "0.18"
 Distributions = "0.23.3, 0.24, 0.25"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
+DynamicHMC = "3.4"
 DynamicPPL = "0.23"
 EllipticalSliceSampling = "0.5, 1"
 ForwardDiff = "0.10.3"
@@ -68,11 +70,14 @@ StatsFuns = "0.8, 0.9, 1"
 Tracker = "0.2.3"
 julia = "1.7"
 
-[weakdeps]
-Optim = "429524aa-4258-5aef-a3af-852621145aeb"
-
 [extensions]
+TuringDynamicHMCExt = "DynamicHMC"
 TuringOptimExt = "Optim"
 
 [extras]
+DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+
+[weakdeps]
+DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
diff --git a/src/contrib/inference/dynamichmc.jl b/ext/TuringDynamicHMCExt.jl
similarity index 61%
rename from src/contrib/inference/dynamichmc.jl
rename to ext/TuringDynamicHMCExt.jl
index 8324cd1a0..6c97d7949 100644
--- a/src/contrib/inference/dynamichmc.jl
+++ b/ext/TuringDynamicHMCExt.jl
@@ -1,7 +1,21 @@
+module TuringDynamicHMCExt
 ###
 ### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
 ###
+
+if isdefined(Base, :get_extension)
+    import DynamicHMC
+    using Turing
+    using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
+    using Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
+else
+    import ..DynamicHMC
+    using ..Turing
+    using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
+    using ..Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
+end
+
 
 """
     DynamicNUTS
 
@@ -12,10 +26,15 @@ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
 using DynamicHMC
 ```
 """
-struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end
+struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian{AD}
+    sampler::T
+end
 
-DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
-DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()
+DynamicNUTS(args...) = DynamicNUTS{Turing.ADBackend()}(args...)
+DynamicNUTS{AD}(spl::DynamicHMC.NUTS, space::Tuple) where AD = DynamicNUTS{AD, space, typeof(spl)}(spl)
+DynamicNUTS{AD}(spl::DynamicHMC.NUTS) where AD = DynamicNUTS{AD}(spl, ())
+DynamicNUTS{AD}() where AD = DynamicNUTS{AD}(DynamicHMC.NUTS())
+Turing.externalsampler(spl::DynamicHMC.NUTS) = DynamicNUTS(spl)
 
 DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
 
@@ -27,7 +46,7 @@ State of the [`DynamicNUTS`](@ref) sampler.
 # Fields
 $(TYPEDFIELDS)
 """
-struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S}
+struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
     logdensity::L
     vi::V
     "Cache of sample, log density, and gradient of log density evaluation."
@@ -36,26 +55,13 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
     stepsize::S
 end
 
-# Implement interface of `Gibbs` sampler
-function gibbs_state(
-    model::Model,
-    spl::Sampler{<:DynamicNUTS},
-    state::DynamicNUTSState,
-    varinfo::AbstractVarInfo,
-)
-    # Update the log density function and its cached evaluation.
-    ℓ = LogDensityProblemsAD.ADgradient(Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext()))
-    Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
-    return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize)
-end
-
-DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
+DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS}) = DynamicPPL.SampleFromUniform()
 
 function DynamicPPL.initialstep(
-    rng::AbstractRNG,
-    model::Model,
-    spl::Sampler{<:DynamicNUTS},
-    vi::AbstractVarInfo;
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    spl::DynamicPPL.Sampler{<:DynamicNUTS},
+    vi::DynamicPPL.AbstractVarInfo;
     kwargs...
 )
     # Ensure that initial sample is in unconstrained space.
@@ -83,16 +89,16 @@ function DynamicPPL.initialstep(
     vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
 
     # Create first sample and state.
-    sample = Transition(vi)
+    sample = Turing.Inference.Transition(vi)
     state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
 
     return sample, state
 end
 
 function AbstractMCMC.step(
-    rng::AbstractRNG,
-    model::Model,
-    spl::Sampler{<:DynamicNUTS},
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    spl::DynamicPPL.Sampler{<:DynamicNUTS},
     state::DynamicNUTSState;
     kwargs...
 )
@@ -101,7 +107,7 @@ function AbstractMCMC.step(
     ℓ = state.logdensity
     steps = DynamicHMC.mcmc_steps(
         rng,
-        DynamicHMC.NUTS(),
+        spl.alg.sampler,
         state.metric,
         ℓ,
         state.stepsize,
@@ -113,8 +119,10 @@ function AbstractMCMC.step(
     vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
 
     # Create next sample and state.
-    sample = Transition(vi)
+    sample = Turing.Inference.Transition(vi)
     newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
 
     return sample, newstate
 end
+
+end
\ No newline at end of file
diff --git a/src/Turing.jl b/src/Turing.jl
index 23bfc16b8..11dcbdb6f 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -135,21 +135,15 @@ export @model, # modelling
     optim_function,
     optim_problem
 
+if !isdefined(Base, :get_extension)
+    using Requires
+end
+
 function __init__()
     @static if !isdefined(Base, :get_extension)
         @require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("../ext/TuringOptimExt.jl")
-    end
-    @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
-        @eval Inference begin
-            import ..DynamicHMC
-
-            if isdefined(DynamicHMC, :mcmc_with_warmup)
-                include("contrib/inference/dynamichmc.jl")
-            else
-                error("Please update DynamicHMC, v1.x is no longer supported")
-            end
-        end
-    end
+        @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" include("../ext/TuringDynamicHMCExt.jl")
+    end
 end
 
 end
diff --git a/test/contrib/inference/dynamichmc.jl b/test/contrib/inference/dynamichmc.jl
index f5a3ce377..61027a196 100644
--- a/test/contrib/inference/dynamichmc.jl
+++ b/test/contrib/inference/dynamichmc.jl
@@ -1,15 +1,10 @@
-@stage_testset "dynamichmc" "dynamichmc.jl" begin
+@testset "TuringDynamicHMCExt" begin
     import DynamicHMC
     Random.seed!(100)
 
-    @test DynamicPPL.alg_str(Sampler(DynamicNUTS(), gdemo_default)) == "DynamicNUTS"
+    @test DynamicPPL.alg_str(Sampler(externalsampler(DynamicHMC.NUTS()))) == "DynamicNUTS"
 
-    chn = sample(gdemo_default, DynamicNUTS(), 10_000)
+    spl = externalsampler(DynamicHMC.NUTS())
+    chn = sample(gdemo_default, spl, 10_000)
     check_gdemo(chn)
-
-    chn2 = sample(gdemo_default, Gibbs(PG(15, :s), DynamicNUTS(:m)), 10_000)
-    check_gdemo(chn2)
-
-    chn3 = sample(gdemo_default, Gibbs(DynamicNUTS(:s), ESS(:m)), 10_000)
-    check_gdemo(chn3)
 end