Fix DynamicHMC support and related issues (#648)
* fix DynamicHMC support and related issues

* use Requires for LogDensityProblems
mohamed82008 authored and yebai committed Jan 19, 2019
1 parent a73e2ba commit d7bf549
Showing 6 changed files with 60 additions and 82 deletions.
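The heart of the patch is the switch to Requires.jl's `@init @require` pattern, so the optional dependencies (CmdStan, DynamicHMC, LogDensityProblems) are only wired up when the user actually loads them. A minimal sketch of the pattern as the diff below uses it — the host module name here is hypothetical; in the diff the host is Turing itself:

```julia
module MyGlue  # hypothetical host module

using Requires

# `@init` appends this registration to the module's `__init__`, and the
# `@require` body is evaluated only once the user runs `using CmdStan`.
@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
    using CmdStan
    import CmdStan: Adapt, Hmc
end

end # module
```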
37 changes: 19 additions & 18 deletions src/Turing.jl
@@ -24,30 +24,31 @@ using Markdown
using Libtask
using MacroTools

function __init__()
@require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
using CmdStan
import CmdStan: Adapt, Hmc
end
@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
using CmdStan
import CmdStan: Adapt, Hmc
end

@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval begin
using DynamicHMC, LogDensityProblems
using LogDensityProblems: AbstractLogDensityProblem, ValueGradient
@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval begin
using .DynamicHMC: NUTS_init_tune_mcmc
end

struct FunctionLogDensity{F} <: AbstractLogDensityProblem
dimension::Int
f::F
end
@init @require LogDensityProblems="6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @eval begin
using .LogDensityProblems: AbstractLogDensityProblem, ValueGradient
struct FunctionLogDensity{F} <: AbstractLogDensityProblem
dimension::Int
f::F
end

LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension
LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension

LogDensityProblems.logdensity(::Type{ValueGradient}, ℓ::FunctionLogDensity, x) = ℓ.f(x)::ValueGradient
end
LogDensityProblems.logdensity(::Type{ValueGradient}, ℓ::FunctionLogDensity, x) = ℓ.f(x)::ValueGradient
end

import Base: ~, convert, promote_rule, rand, getindex, setindex!
import Distributions: sample
import ForwardDiff: gradient
using Flux.Tracker
using ForwardDiff: ForwardDiff
using Flux.Tracker: Tracker
import MCMCChain: AbstractChains, Chains

##############################
@@ -138,7 +139,7 @@ struct SampleFromPrior <: AbstractSampler end
const AnySampler = Union{Nothing, AbstractSampler}

include("utilities/resample.jl")
@static if isdefined(Turing, :CmdStan)
@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
include("utilities/stan-interface.jl")
end
include("utilities/helper.jl")
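For orientation, a hedged sketch of how the `FunctionLogDensity` glue defined above is meant to be used: it adapts a plain closure returning a `ValueGradient` to the `AbstractLogDensityProblem` interface that DynamicHMC consumes. The standard-normal target and the v1-era `NUTS_init_tune_mcmc` call are illustrative assumptions, not taken from the commit:

```julia
using Turing, DynamicHMC, LogDensityProblems  # loading these fires the @require blocks
using LogDensityProblems: ValueGradient

# 2-dimensional standard normal: logp(x) = -x'x/2, with gradient -x.
ℓ = Turing.FunctionLogDensity(2, x -> ValueGradient(-sum(abs2, x) / 2, -x))

# DynamicHMC's init/tune/sample driver only needs `dimension` and
# `logdensity`, which `FunctionLogDensity` forwards to the wrapped closure.
samples, tuned_sampler = NUTS_init_tune_mcmc(ℓ, 1000)
```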
49 changes: 22 additions & 27 deletions src/core/ad.jl
@@ -1,5 +1,14 @@
abstract type ADBackend end
struct ForwardDiffAD{chunk} <: ADBackend end
getchunksize(::T) where {T <: ForwardDiffAD} = getchunksize(T)
getchunksize(::Type{ForwardDiffAD{chunk}}) where chunk = chunk
getchunksize(::T) where {T <: Hamiltonian} = getchunksize(T)
getchunksize(::Type{<:Hamiltonian{AD}}) where AD = getchunksize(AD)
getchunksize(::T) where {T <: Sampler} = getchunksize(T)
getchunksize(::Type{<:Sampler{T}}) where {T <: Hamiltonian} = getchunksize(T)
getchunksize(::Nothing) = getchunksize(Nothing)
getchunksize(::Type{Nothing}) = CHUNKSIZE[]

struct FluxTrackerAD <: ADBackend end

ADBackend() = ADBackend(ADBACKEND[])
@@ -12,14 +21,13 @@ function ADBackend(::Val{T}) where {T}
end
end

getchunksize(::Type{ForwardDiffAD{chunk}}) where chunk = Val(chunk)


"""
getADtype(alg)
Finds the autodifferentiation type of the algorithm `alg`.
"""
getADtype(::Nothing) = getADtype(Nothing)
getADtype(::Type{Nothing}) = getADtype()
getADtype() = ADBackend()
getADtype(s::Sampler) = getADtype(typeof(s))
getADtype(s::Type{<:Sampler{TAlg}}) where {TAlg} = getADtype(TAlg)
@@ -37,30 +45,18 @@ gradient(
Computes the gradient of the log joint of `θ` for the model specified by
`(vi, sampler, model)` using whichever automatic differentiation tool is currently active.
"""
@generated function gradient(
function gradient(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::TS=nothing,
) where {TS <: Union{Nothing, Sampler}}
if TS == Nothing
return quote
ad_type = getADtype()
if ad_type <: ForwardDiffAD
gradient_forward(θ, vi, model, sampler, getchunksize(ad_type))
else
gradient_reverse(θ, vi, model, sampler)
end
end
else
ad_type = getADtype(TS)
@assert any(T -> ad_type <: T, (ForwardDiffAD, FluxTrackerAD))
if ad_type <: ForwardDiffAD
chunk = getchunksize(ad_type)
return :(gradient_forward(θ, vi, model, sampler, $chunk))
elseif ad_type <: FluxTrackerAD
return :(gradient_reverse(θ, vi, model, sampler))
end
sampler::TS,
) where {TS <: Sampler{<:Hamiltonian}}

ad_type = getADtype(TS)
if ad_type <: ForwardDiffAD
return gradient_forward(θ, vi, model, sampler)
elseif ad_type <: FluxTrackerAD
return gradient_reverse(θ, vi, model, sampler)
end
end

@@ -70,7 +66,6 @@ gradient_forward(
vi::VarInfo,
model::Model,
spl::Union{Nothing, Sampler}=nothing,
chunk_size::Int=CHUNKSIZE[],
)
Computes the gradient of the log joint of `θ` for the model specified by `(vi, spl, model)`
@@ -81,8 +76,7 @@ function gradient_forward(
vi::VarInfo,
model::Model,
sampler::Union{Nothing, Sampler}=nothing,
::Val{chunk_size}=Val(CHUNKSIZE[]),
) where chunk_size
)
# Record old parameters.
vals_old, logp_old = copy(vi.vals), copy(vi.logp)

@@ -92,6 +86,7 @@ function gradient_forward(
return -runmodel!(model, vi, sampler).logp
end

chunk_size = getchunksize(sampler)
# Set chunk size and do ForwardMode.
chunk = ForwardDiff.Chunk(min(length(θ), chunk_size))
config = ForwardDiff.GradientConfig(f, θ, chunk)
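The `getchunksize` chain added at the top of this file resolves the ForwardDiff chunk size from the sampler's AD type parameter (e.g. `ForwardDiffAD{2}` yields `2`) instead of reading and mutating a global. What `gradient_forward` then effectively does is sketched below against the public ForwardDiff API, with `f` standing in for the negative-log-joint closure built from `(vi, sampler, model)`:

```julia
using ForwardDiff

function chunked_gradient(f, θ::AbstractVector{<:Real}, chunk_size::Int)
    # Cap the chunk at the dimension of θ, as the patched gradient_forward does.
    chunk = ForwardDiff.Chunk(min(length(θ), chunk_size))
    config = ForwardDiff.GradientConfig(f, θ, chunk)
    return ForwardDiff.gradient(f, θ, config)
end

# For f(x) = x'x/2 the gradient is x itself:
chunked_gradient(x -> sum(abs2, x) / 2, randn(10), 2)
```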
20 changes: 10 additions & 10 deletions src/samplers/adapt/stan.jl
@@ -1,15 +1,15 @@
@static if isdefined(Turing, :CmdStan)
function DualAveraging(spl::Sampler{<:AdaptiveHamiltonian}, adapt_conf::CmdStan.Adapt, ϵ::Real)
# Hyper parameters for dual averaging
γ = adapt_conf.gamma
t_0 = adapt_conf.t0
κ = adapt_conf.kappa
δ = adapt_conf.delta
return DualAveraging(γ, t_0, κ, δ, DAState(ϵ))
end
@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
function DualAveraging(spl::Sampler{<:AdaptiveHamiltonian}, adapt_conf::CmdStan.Adapt, ϵ::Real)
# Hyper parameters for dual averaging
γ = adapt_conf.gamma
t_0 = adapt_conf.t0
κ = adapt_conf.kappa
δ = adapt_conf.delta
return DualAveraging(γ, t_0, κ, δ, DAState(ϵ))
end
end

@static if isdefined(Turing, :CmdStan)
@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
function get_threephase_params(adapt_conf::CmdStan.Adapt)
init_buffer = adapt_conf.init_buffer
term_buffer = adapt_conf.term_buffer
21 changes: 6 additions & 15 deletions src/samplers/dynamichmc.jl
@@ -1,4 +1,4 @@
struct DynamicNUTS{T} <: Hamiltonian
struct DynamicNUTS{AD, T} <: Hamiltonian{AD}
n_iters :: Integer # number of samples
space :: Set{T} # sampling space, empty means all
gid :: Integer # group ID
@@ -27,25 +27,20 @@ end
chn = sample(gdemo(1.5, 2.0), DynamicNUTS(2000))
```
"""
function DynamicNUTS(n_iters::Integer, space...)
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
function DynamicNUTS{AD}(n_iters::Integer, space...) where AD
_space = isa(space, Symbol) ? Set([space]) : Set(space)
DynamicNUTS(n_iters, _space, 0)
DynamicNUTS{AD, eltype(_space)}(n_iters, _space, 0)
end

function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian
return Sampler(alg, Dict{Symbol,Any}())
end

function sample(model::Model,
alg::DynamicNUTS,
chunk_size=CHUNKSIZE[]
) where T <: Hamiltonian
alg::DynamicNUTS{AD},
) where AD

if ADBACKEND[] == :forward_diff
default_chunk_size = CHUNKSIZE[] # record global chunk size
setchunksize(chunk_size) # set temp chunk size
end

spl = Sampler(alg)

n = alg.n_iters
@@ -75,9 +70,5 @@ function sample(model::Model,
samples[i].value = Sample(vi, spl).value
end

if ADBACKEND[] == :forward_diff
setchunksize(default_chunk_size) # revert global chunk size
end

return Chain(0, samples)
end
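Since the AD backend now rides on the sampler type, a per-sampler chunk size replaces the removed `chunk_size` argument. A usage sketch built from the constructors above (`gdemo` is the demo model from the docstring):

```julia
# Default AD backend, exactly as in the docstring:
chn = sample(gdemo(1.5, 2.0), DynamicNUTS(2000))

# Explicitly request ForwardDiff with a chunk size of 2:
chn = sample(gdemo(1.5, 2.0), DynamicNUTS{Turing.ForwardDiffAD{2}}(2000))
```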
13 changes: 2 additions & 11 deletions src/samplers/hmc.jl
@@ -81,7 +81,7 @@ end
DEFAULT_ADAPT_CONF_TYPE = Nothing
STAN_DEFAULT_ADAPT_CONF = nothing

@static if isdefined(Turing, :CmdStan)
@init @require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" @eval begin
DEFAULT_ADAPT_CONF_TYPE = Union{DEFAULT_ADAPT_CONF_TYPE, CmdStan.Adapt}
STAN_DEFAULT_ADAPT_CONF = CmdStan.Adapt()
end
@@ -101,17 +101,12 @@ Sampler(alg::Hamiltonian, adapt_conf::DEFAULT_ADAPT_CONF_TYPE) = begin
end

function sample(model::Model, alg::Hamiltonian;
chunk_size=CHUNKSIZE[], # set temporary chunk size
save_state=false, # flag for state saving
resume_from=nothing, # chain to continue
reuse_spl_n=0, # flag for spl re-using
adapt_conf=STAN_DEFAULT_ADAPT_CONF, # adapt configuration
)
if ADBACKEND[] == :forward_diff
default_chunk_size = CHUNKSIZE[] # record global chunk size
setchunksize(chunk_size) # set temp chunk size
end


spl = reuse_spl_n > 0 ?
resume_from.info[:spl] :
Sampler(alg, adapt_conf)
@@ -190,10 +185,6 @@ function sample(model::Model, alg::Hamiltonian;
println(" pre-cond. metric = $(std_str).")
end

if ADBACKEND[] == :forward_diff
setchunksize(default_chunk_size) # revert global chunk size
end

if resume_from != nothing # concat samples
pushfirst!(samples, resume_from.value2...)
end
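The same migration applies to `HMC`: the deleted `chunk_size` keyword is now expressed through the algorithm's AD type parameter, as the one-line test change below confirms. Schematically (`model` stands for any compiled Turing model):

```julia
# Before this commit: chunk size passed as a sampling keyword.
# sample(model, HMC(100, 0.75, 3, :p, :x); chunk_size=2)

# After: chunk size carried in the AD backend type parameter.
sample(model, HMC{Turing.ForwardDiffAD{2}}(100, 0.75, 3, :p, :x))
```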
2 changes: 1 addition & 1 deletion test/compiler.jl/newinterface.jl
@@ -22,4 +22,4 @@ end
#
# chain = sample(newinterface, HMC(100, 1.5, 3))

chain = sample(newinterface(obs), HMC(100, 0.75, 3, :p, :x); chunk_size=2)
chain = sample(newinterface(obs), HMC{Turing.ForwardDiffAD{2}}(100, 0.75, 3, :p, :x))
