Skip to content

Commit

Permalink
Get ready for AHMC v0.2.0 (#849)
Browse files Browse the repository at this point in the history
* grad -> (val, grad)

* use graident function from ad.jl

* pass progress kw argument to AHMC

* change default adaptor of NUTS to diag

* update transition for Gibbs impl

* update returns stats

* fix interface

* update AHMC in env

* default of NUTS stepsize to 0.0 to trigger find_good_eps

* trigger find_good_eps by default

* Update Project.toml
  • Loading branch information
xukai92 authored and yebai committed Jul 24, 2019
1 parent 2b4ca5d commit 3ee76b8
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 67 deletions.
41 changes: 20 additions & 21 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# This file is machine-generated - editing it directly is not advised

[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
deps = ["LinearAlgebra"]
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.2"
version = "1.0.0"

[[AdvancedHMC]]
deps = ["ArgCheck", "InplaceOps", "LazyArrays", "LinearAlgebra", "Parameters", "ProgressMeter", "Random", "Statistics", "StatsBase"]
git-tree-sha1 = "4db0bda4006fbb9e99e2b4e5e42b804104a6bfb5"
git-tree-sha1 = "66547521b1d25c2fc5af8076b5e55aa148eba033"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.1.8"
version = "0.2.0"

[[ArgCheck]]
deps = ["Random"]
Expand Down Expand Up @@ -53,9 +53,9 @@ version = "0.5.4"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "376a39f1862000442011390f1edf5e7f4dcc7142"
git-tree-sha1 = "0ff80f68f55fcde2ed98d7b24d7abaf20727f3f8"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.0"
version = "0.6.1"

[[CommonSubexpressions]]
deps = ["Test"]
Expand All @@ -76,10 +76,10 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
version = "0.17.0"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -146,10 +146,9 @@ uuid = "8197267c-284f-5f27-9208-e0e47529a953"
version = "0.3.1"

[[IterTools]]
deps = ["SparseArrays", "Test"]
git-tree-sha1 = "79246285c43602384e6f1943b3554042a3712056"
git-tree-sha1 = "2ebe60d7343962966d1779a74a760f13217a6901"
uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
version = "1.1.1"
version = "1.2.0"

[[LazyArrays]]
deps = ["FillArrays", "LinearAlgebra", "MacroTools", "StaticArrays", "Test"]
Expand Down Expand Up @@ -232,10 +231,10 @@ uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.7"

[[Parameters]]
deps = ["Markdown", "OrderedCollections", "REPL", "Test"]
git-tree-sha1 = "70bdbfb2bceabb15345c0b54be4544813b3444e4"
deps = ["OrderedCollections"]
git-tree-sha1 = "1dfd7cd50a8eb06ef693a4c2bbe945943cd000c5"
uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a"
version = "0.10.3"
version = "0.11.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand All @@ -252,10 +251,10 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "0.9.0"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "438630b843c210b375b2a246329200c113acc61b"
deps = ["DataStructures", "LinearAlgebra", "Test"]
git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.1.0"
version = "2.0.3"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
Expand Down Expand Up @@ -366,9 +365,9 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
git-tree-sha1 = "0de343efc07da00cd449d5b04e959ebaeeb3305d"
git-tree-sha1 = "c8a8b00ae44a94950814ff77850470711a360225"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.4"
version = "0.5.5"

[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.6.19"
version = "0.6.21"

[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Expand Down
3 changes: 0 additions & 3 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,6 @@ end
# VarInfo, combined with spl.info, to Sample
function Sample(vi::AbstractVarInfo, spl::Sampler)
s = Sample(vi)
if haskey(spl.info, :adaptor)
s.value[:lf_eps] = AHMC.getϵ(spl.info[:adaptor])
end
if haskey(spl.info, :eval_num)
s.value[:eval_num] = spl.info[:eval_num]
end
Expand Down
75 changes: 34 additions & 41 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function HMCDA{AD}(
n_iters::Int,
δ::Float64,
λ::Float64;
init_ϵ::Float64=0.1,
init_ϵ::Float64=0.0,
metricT=AHMC.UnitEuclideanMetric
) where AD
n_adapts_default = Int(round(n_iters / 2))
Expand All @@ -126,7 +126,7 @@ function HMCDA{AD}(
δ::Float64,
λ::Float64,
space::Symbol...;
init_ϵ::Float64=0.1,
init_ϵ::Float64=0.0,
metricT=AHMC.UnitEuclideanMetric
) where AD
return HMCDA{AD}(n_iters, n_adapts, δ, λ, init_ϵ, metricT, space)
Expand Down Expand Up @@ -193,8 +193,8 @@ function NUTS{AD}(
space::Symbol...;
max_depth::Int=5,
Δ_max::Float64=1000.0,
init_ϵ::Float64=0.1,
metricT=AHMC.DenseEuclideanMetric
init_ϵ::Float64=0.0,
metricT=AHMC.DiagEuclideanMetric
) where AD
NUTS{AD}(n_iters, n_adapts, δ, max_depth, Δ_max, init_ϵ, metricT, space)
end
Expand All @@ -204,8 +204,8 @@ function NUTS{AD}(
δ::Float64;
max_depth::Int=5,
Δ_max::Float64=1000.0,
init_ϵ::Float64=0.1,
metricT=AHMC.DenseEuclideanMetric
init_ϵ::Float64=0.0,
metricT=AHMC.DiagEuclideanMetric
) where AD
n_adapts_default = Int(round(n_iters / 2))
NUTS{AD}(n_iters, n_adapts_default > 1000 ?
Expand Down Expand Up @@ -242,6 +242,7 @@ function sample(
rng::AbstractRNG=GLOBAL_RNG,
discard_adapt::Bool=true,
verbose::Bool=true,
progress::Bool=false,
kwargs...
)
# Create sampler
Expand Down Expand Up @@ -296,8 +297,8 @@ function sample(
step(model, spl, vi, Val(true); adaptor=adaptor)

# Sampling using AHMC and store samples in `samples`
steps!(model, spl, vi, samples; rng=rng, verbose=verbose)
steps!(model, spl, vi, samples; rng=rng, verbose=verbose, progress=progress)

# Concatenate samples
if resume_from != nothing
pushfirst!(samples, resume_from.info[:samples]...)
Expand Down Expand Up @@ -362,6 +363,9 @@ function step(
init_ϵ = AHMC.find_good_eps(h, θ_init)
@info "Found initial step size" init_ϵ
end
if AHMC.getϵ(adaptor) == 0.0
adaptor = AHMCAdaptor(spl.alg; init_ϵ=init_ϵ)
end

spl.info[:h] = h
spl.info[:traj] = gen_traj(spl.alg, init_ϵ)
Expand Down Expand Up @@ -431,26 +435,30 @@ end


# Efficient multiple step sampling for adaptive HMC.
function steps!(model,
function steps!(
model,
spl::Sampler{<:AdaptiveHamiltonian},
vi,
samples;
rng::AbstractRNG=GLOBAL_RNG,
verbose::Bool=true
verbose::Bool=true,
progress::Bool=false
)
ahmc_samples = AHMC.sample(
ahmc_samples, stats = AHMC.sample(
rng,
spl.info[:h],
spl.info[:traj],
Vector{Float64}(vi[spl]),
spl.alg.n_iters,
spl.info[:adaptor],
spl.alg.n_adapts;
verbose=verbose
verbose=verbose,
progress=progress
)
for i = 1:length(samples)
vi[spl] = ahmc_samples[i]
samples[i].value = Sample(vi, spl).value
foreach(name -> samples[i].value[name] = stats[i][name], typeof(stats[i]).names)
end
end

Expand All @@ -461,19 +469,23 @@ function steps!(
vi,
samples;
rng::AbstractRNG=GLOBAL_RNG,
verbose::Bool=true
verbose::Bool=true,
progress::Bool=false
)
ahmc_samples = AHMC.sample(
ahmc_samples, stats = AHMC.sample(
rng,
spl.info[:h],
spl.info[:traj],
Vector{Float64}(vi[spl]),
spl.alg.n_iters;
verbose=verbose
verbose=verbose,
progress=progress
)

for i = 1:length(samples)
vi[spl] = ahmc_samples[i]
samples[i].value = Sample(vi, spl).value
foreach(name -> samples[i].value[name] = stats[i][name], typeof(stats[i]).names)
end
end

Expand All @@ -484,7 +496,8 @@ function steps!(
vi,
samples;
rng::AbstractRNG=GLOBAL_RNG,
verbose::Bool=true
verbose::Bool=true,
progress::Bool=false
)
# Init step
time_elapsed = @elapsed vi, is_accept = step(model, spl, vi, Val(true))
Expand Down Expand Up @@ -515,11 +528,7 @@ gradient at `θ` for the model specified by `(vi, spl, model)`.
"""
function gen_∂logπ∂θ(vi::VarInfo, spl::Sampler, model)
function ∂logπ∂θ(x)
x_old, lj_old = vi[spl], vi.logp
_, deriv = gradient_logp(x, vi, model, spl)
vi[spl] = x_old
setlogp!(vi, lj_old)
return deriv
return gradient_logp(x, vi, model, spl)
end
return ∂logπ∂θ
end
Expand Down Expand Up @@ -576,26 +585,10 @@ function hmc_step(
# Build phase point
z = AHMC.phasepoint(h, θ, r)

# TODO: remove below when we can get is_accept from AHMC.transition
H = AHMC.neg_energy(z) # NOTE: this a waste of computation

# Call AHMC to make one MCMC transition
z_new, α = AHMC.transition(traj, h, z)

# Compute new Hamiltonian energy
H_new = AHMC.neg_energy(z_new)
θ_new = z_new.θ

# NOTE: as `transition` doesn't return `is_accept`,
# I use `H == H_new` to check if the sample is accepted.
is_accept = H != H_new # If the new Hamiltonian enerygy is different
# from the old one, the sample was accepted.
alg isa NUTS && (is_accept = true) # we always accept in NUTS

# Compute updated log-joint probability
lj_new = logπ(θ_new)
z_new, stat = AHMC.transition(traj, h, z)

return θ_new, lj_new, is_accept, α
return z_new.θ, stat.log_density, stat.is_accept, stat.acceptance_rate
end

####
Expand Down Expand Up @@ -669,9 +662,9 @@ observe(spl::Sampler{<:Hamiltonian},
#### Default HMC stepsize and mass matrix adaptor
####

function AHMCAdaptor(alg::AdaptiveHamiltonian)
function AHMCAdaptor(alg::AdaptiveHamiltonian; init_ϵ=alg.init_ϵ)
p = AHMC.Preconditioner(getmetricT(alg))
nda = AHMC.NesterovDualAveraging(alg.δ, alg.init_ϵ)
nda = AHMC.NesterovDualAveraging(alg.δ, init_ϵ)
if getmetricT(alg) == AHMC.UnitEuclideanMetric
adaptor = AHMC.NaiveHMCAdaptor(p, nda)
else
Expand Down
5 changes: 4 additions & 1 deletion src/utilities/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ end
#########

# Variables to put in the Chains :internal section.
const _internal_vars = ["elapsed", "eval_num", "lf_eps", "lp"]
const _internal_vars = [
"elapsed", "eval_num", "lf_eps", "lp",
"acceptance_rate", "hamiltonian_energy", "is_accept", "log_density", "n_steps", "numerical_error", "step_size", "tree_depth",
]

function Chain(w::Real, s::AbstractArray{Sample})
samples = flatten.(s)
Expand Down

0 comments on commit 3ee76b8

Please sign in to comment.