diff --git a/src/Turing.jl b/src/Turing.jl index 24f53e7f9..8e28ffe7e 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -24,30 +24,31 @@ using Markdown using Libtask using MacroTools -function __init__() - @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin - using CmdStan - import CmdStan: Adapt, Hmc - end +@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin + using CmdStan + import CmdStan: Adapt, Hmc +end - @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval begin - using DynamicHMC, LogDensityProblems - using LogDensityProblems: AbstractLogDensityProblem, ValueGradient +@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval begin + using .DynamicHMC: NUTS_init_tune_mcmc +end - struct FunctionLogDensity{F} <: AbstractLogDensityProblem - dimension::Int - f::F - end +@init @require LogDensityProblems="6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @eval begin + using .LogDensityProblems: AbstractLogDensityProblem, ValueGradient + struct FunctionLogDensity{F} <: AbstractLogDensityProblem + dimension::Int + f::F + end - LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension + LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension - LogDensityProblems.logdensity(::Type{ValueGradient}, ℓ::FunctionLogDensity, x) = ℓ.f(x)::ValueGradient - end + LogDensityProblems.logdensity(::Type{ValueGradient}, ℓ::FunctionLogDensity, x) = ℓ.f(x)::ValueGradient end + import Base: ~, convert, promote_rule, rand, getindex, setindex! import Distributions: sample -import ForwardDiff: gradient -using Flux.Tracker +using ForwardDiff: ForwardDiff +using Flux.Tracker: Tracker import MCMCChain: AbstractChains, Chains ############################## @@ -138,7 +139,7 @@ struct SampleFromPrior <: AbstractSampler end const AnySampler = Union{Nothing, AbstractSampler} include("utilities/resample.jl") -@static if isdefined(Turing, :CmdStan) +@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin include("utilities/stan-interface.jl") end include("utilities/helper.jl") diff --git a/src/core/ad.jl b/src/core/ad.jl index f0c4718e0..2b8da1731 100644 --- a/src/core/ad.jl +++ b/src/core/ad.jl @@ -1,5 +1,14 @@ abstract type ADBackend end struct ForwardDiffAD{chunk} <: ADBackend end +getchunksize(::T) where {T <: ForwardDiffAD} = getchunksize(T) +getchunksize(::Type{ForwardDiffAD{chunk}}) where chunk = chunk +getchunksize(::T) where {T <: Hamiltonian} = getchunksize(T) +getchunksize(::Type{<:Hamiltonian{AD}}) where AD = getchunksize(AD) +getchunksize(::T) where {T <: Sampler} = getchunksize(T) +getchunksize(::Type{<:Sampler{T}}) where {T <: Hamiltonian} = getchunksize(T) +getchunksize(::Nothing) = getchunksize(Nothing) +getchunksize(::Type{Nothing}) = CHUNKSIZE[] + struct FluxTrackerAD <: ADBackend end ADBackend() = ADBackend(ADBACKEND[]) @@ -12,14 +21,13 @@ function ADBackend(::Val{T}) where {T} end end -getchunksize(::Type{ForwardDiffAD{chunk}}) where chunk = Val(chunk) - - """ getADtype(alg) Finds the autodifferentiation type of the algorithm `alg`. """ +getADtype(::Nothing) = getADtype(Nothing) +getADtype(::Type{Nothing}) = getADtype() getADtype() = ADBackend() getADtype(s::Sampler) = getADtype(typeof(s)) getADtype(s::Type{<:Sampler{TAlg}}) where {TAlg} = getADtype(TAlg) @@ -37,30 +45,18 @@ gradient( Computes the gradient of the log joint of `θ` for the model specified by `(vi, sampler, model)` using whichever automatic differentation tool is currently active. """ -@generated function gradient( +function gradient( θ::AbstractVector{<:Real}, vi::VarInfo, model::Model, - sampler::TS=nothing, -) where {TS <: Union{Nothing, Sampler}} - if TS == Nothing - return quote - ad_type = getADtype() - if ad_type <: ForwardDiffAD - gradient_forward(θ, vi, model, sampler, getchunksize(ad_type)) - else - gradient_reverse(θ, vi, model, sampler) - end - end - else - ad_type = getADtype(TS) - @assert any(T -> ad_type <: T, (ForwardDiffAD, FluxTrackerAD)) - if ad_type <: ForwardDiffAD - chunk = getchunksize(ad_type) - return :(gradient_forward(θ, vi, model, sampler, $chunk)) - else ad_type <: FluxTrackerAD - return :(gradient_reverse(θ, vi, model, sampler)) - end + sampler::TS, +) where {TS <: Sampler{<:Hamiltonian}} + + ad_type = getADtype(TS) + if ad_type <: ForwardDiffAD + return gradient_forward(θ, vi, model, sampler) + else ad_type <: FluxTrackerAD + return gradient_reverse(θ, vi, model, sampler) end end @@ -70,7 +66,6 @@ gradient_forward( vi::VarInfo, model::Model, spl::Union{Nothing, Sampler}=nothing, - chunk_size::Int=CHUNKSIZE[], ) Computes the gradient of the log joint of `θ` for the model specified by `(vi, spl, model)` @@ -81,8 +76,7 @@ function gradient_forward( vi::VarInfo, model::Model, sampler::Union{Nothing, Sampler}=nothing, - ::Val{chunk_size}=Val(CHUNKSIZE[]), -) where chunk_size +) # Record old parameters. vals_old, logp_old = copy(vi.vals), copy(vi.logp) @@ -92,6 +86,7 @@ function gradient_forward( return -runmodel!(model, vi, sampler).logp end + chunk_size = getchunksize(sampler) # Set chunk size and do ForwardMode. chunk = ForwardDiff.Chunk(min(length(θ), chunk_size)) config = ForwardDiff.GradientConfig(f, θ, chunk) diff --git a/src/samplers/adapt/stan.jl b/src/samplers/adapt/stan.jl index d71657662..b18a0b97f 100644 --- a/src/samplers/adapt/stan.jl +++ b/src/samplers/adapt/stan.jl @@ -1,15 +1,15 @@ -@static if isdefined(Turing, :CmdStan) - function DualAveraging(spl::Sampler{<:AdaptiveHamiltonian}, adapt_conf::CmdStan.Adapt, ϵ::Real) - # Hyper parameters for dual averaging - γ = adapt_conf.gamma - t_0 = adapt_conf.t0 - κ = adapt_conf.kappa - δ = adapt_conf.delta - return DualAveraging(γ, t_0, κ, δ, DAState(ϵ)) - end +@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin + function DualAveraging(spl::Sampler{<:AdaptiveHamiltonian}, adapt_conf::CmdStan.Adapt, ϵ::Real) + # Hyper parameters for dual averaging + γ = adapt_conf.gamma + t_0 = adapt_conf.t0 + κ = adapt_conf.kappa + δ = adapt_conf.delta + return DualAveraging(γ, t_0, κ, δ, DAState(ϵ)) + end end -@static if isdefined(Turing, :CmdStan) +@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin function get_threephase_params(adapt_conf::CmdStan.Adapt) init_buffer = adapt_conf.init_buffer term_buffer = adapt_conf.term_buffer diff --git a/src/samplers/dynamichmc.jl b/src/samplers/dynamichmc.jl index 1b5e59027..c22da7d1f 100644 --- a/src/samplers/dynamichmc.jl +++ b/src/samplers/dynamichmc.jl @@ -1,4 +1,4 @@ -struct DynamicNUTS{T} <: Hamiltonian +struct DynamicNUTS{AD, T} <: Hamiltonian{AD} n_iters :: Integer # number of samples space :: Set{T} # sampling space, emtpy means all gid :: Integer # group ID @@ -27,9 +27,10 @@ end chn = sample(gdemo(1.5, 2.0), DynamicNUTS(2000)) ``` """ -function DynamicNUTS(n_iters::Integer, space...) +DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...) +function DynamicNUTS{AD}(n_iters::Integer, space...) where AD _space = isa(space, Symbol) ? Set([space]) : Set(space) - DynamicNUTS(n_iters, _space, 0) + DynamicNUTS{AD, eltype(_space)}(n_iters, _space, 0) end function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian @@ -37,15 +38,9 @@ function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian end function sample(model::Model, - alg::DynamicNUTS, - chunk_size=CHUNKSIZE[] - ) where T <: Hamiltonian + alg::DynamicNUTS{AD}, + ) where AD - if ADBACKEND[] == :forward_diff - default_chunk_size = CHUNKSIZE[] # record global chunk size - setchunksize(chunk_size) # set temp chunk size - end - spl = Sampler(alg) n = alg.n_iters @@ -75,9 +70,5 @@ function sample(model::Model, samples[i].value = Sample(vi, spl).value end - if ADBACKEND[] == :forward_diff - setchunksize(default_chunk_size) # revert global chunk size - end - return Chain(0, samples) end diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl index 0349859dd..231d5b409 100644 --- a/src/samplers/hmc.jl +++ b/src/samplers/hmc.jl @@ -81,7 +81,7 @@ end DEFAULT_ADAPT_CONF_TYPE = Nothing STAN_DEFAULT_ADAPT_CONF = nothing -@static if isdefined(Turing, :CmdStan) +@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin DEFAULT_ADAPT_CONF_TYPE = Union{DEFAULT_ADAPT_CONF_TYPE, CmdStan.Adapt} STAN_DEFAULT_ADAPT_CONF = CmdStan.Adapt() end @@ -101,17 +101,12 @@ Sampler(alg::Hamiltonian, adapt_conf::DEFAULT_ADAPT_CONF_TYPE) = begin end function sample(model::Model, alg::Hamiltonian; - chunk_size=CHUNKSIZE[], # set temporary chunk size save_state=false, # flag for state saving resume_from=nothing, # chain to continue reuse_spl_n=0, # flag for spl re-using adapt_conf=STAN_DEFAULT_ADAPT_CONF, # adapt configuration ) - if ADBACKEND[] == :forward_diff - default_chunk_size = CHUNKSIZE[] # record global chunk size - setchunksize(chunk_size) # set temp chunk size - end - + spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg, adapt_conf) @@ -190,10 +185,6 @@ function sample(model::Model, alg::Hamiltonian; println(" pre-cond. metric = $(std_str).") end - if ADBACKEND[] == :forward_diff - setchunksize(default_chunk_size) # revert global chunk size - end - if resume_from != nothing # concat samples pushfirst!(samples, resume_from.value2...) end diff --git a/test/compiler.jl/newinterface.jl b/test/compiler.jl/newinterface.jl index 286489ebe..1d0e74798 100644 --- a/test/compiler.jl/newinterface.jl +++ b/test/compiler.jl/newinterface.jl @@ -22,4 +22,4 @@ end # # chain = sample(newinterface, HMC(100, 1.5, 3)) -chain = sample(newinterface(obs), HMC(100, 0.75, 3, :p, :x); chunk_size=2) +chain = sample(newinterface(obs), HMC{Turing.ForwardDiffAD{2}}(100, 0.75, 3, :p, :x))