Skip to content

Commit

Permalink
Improve numerical stability (#619)
Browse files Browse the repository at this point in the history
* remove unnecessary lower bound check

* make default precond to diagonal

* add numerical stable Binomial with logit

* fix DA complete adapt

* add test of BinomialLogit

* turn off pre-cond adapt by default
  • Loading branch information
xukai92 authored and yebai committed Dec 9, 2018
1 parent f27f2c9 commit 4dcea02
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 14 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Libtask 0.1.1
Flux 0.6.7
MacroTools
StatsFuns 0.7.0
SpecialFunctions
Bijectors

ProgressMeter 0.6.0
Expand Down
3 changes: 2 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using ForwardDiff
using Bijectors
@reexport using MCMCChain
using StatsFuns
using SpecialFunctions
using Statistics
using LinearAlgebra
using ProgressMeter
Expand Down Expand Up @@ -169,7 +170,7 @@ export TArray, tzeros, localcopy, IArray

export @sym_str

export Flat, FlatPos
export Flat, FlatPos, BinomialLogit, VecBinomialLogit

##################
# Inference code #
Expand Down
24 changes: 24 additions & 0 deletions src/models/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,27 @@ Distributions.rand(d::FlatPos, n::Int) = Vector([rand() for _ = 1:n] .+ d.l)
function Distributions.logpdf(d::FlatPos, x::AbstractVector{<:Real})
return any(x .<= d.l) ? -Inf : zero(x)
end

# Binomial with logit
struct BinomialLogit{T<:Real} <: DiscreteUnivariateDistribution
n::Int64
logitp::T
end

struct VecBinomialLogit{T<:Real} <: DiscreteUnivariateDistribution
n::Vector{Int64}
logitp::Vector{T}
end

function logpdf_binomial_logit(n, logitp, k)
logcomb = -StatsFuns.log1p(n) - SpecialFunctions.lbeta(n - k + 1, k + 1)
return logcomb + k * logitp - n * StatsFuns.log1pexp(logitp)
end

function Distributions.logpdf(d::BinomialLogit{<:Real}, k::Int64)
return logpdf_binomial_logit(d.n, d.logitp, k)
end

function Distributions.logpdf(d::VecBinomialLogit{<:Real}, ks::Vector{Int64})
return sum(logpdf_binomial_logit.(d.n, d.logitp, ks))
end
12 changes: 8 additions & 4 deletions src/samplers/adapt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ end

function ThreePhaseAdapter(spl::Sampler{<:AdaptiveHamiltonian}, ϵ::Real, dim::Integer)
# Diagonal pre-conditioner
# pc = DiagPreConditioner(dim)
pc = DensePreConditioner(dim)
# pc = UnitPreConditioner()
pc = DiagPreConditioner(dim)
# pc = DensePreConditioner(dim)
# Dual averaging for step size
ssa = DualAveraging(spl, spl.info[:adapt_conf], ϵ)
# Window parameters
Expand Down Expand Up @@ -84,11 +85,14 @@ function adapt!(tp::ThreePhaseAdapter, stats::Real, θ; adapt_ϵ=false, adapt_M=
if tp.state.n < tp.n_adapts
tp.state.n += 1
if tp.state.n == tp.n_adapts
if adapt_ϵ
tp.ssa.state.ϵ = exp(tp.ssa.state.x_bar)
end
@info " Adapted ϵ = $(getss(tp)), std = $(string(tp.pc)); $(tp.state.n) iterations is used for adaption."
else
if adapt_ϵ
is_updateϵ = is_windowend(tp) || tp.state.n == tp.n_adapts
adapt!(tp.ssa, stats, is_updateϵ)
is_updateμ = is_windowend(tp)# || tp.state.n == tp.n_adapts
adapt!(tp.ssa, stats, is_updateμ)
end

# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Expand Down
15 changes: 7 additions & 8 deletions src/samplers/adapt/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ end
function adapt_stepsize!(da::DualAveraging, stats::Real)
@debug "adapting step size ϵ..."
@debug "current α = $(stats)"
da.state.m = da.state.m + 1
da.state.m += 1
m = da.state.m

# Clip average MH acceptance probability.
Expand All @@ -90,20 +90,19 @@ function adapt_stepsize!(da::DualAveraging, stats::Real)
ϵ = exp(x)
@debug "new ϵ = $(ϵ), old ϵ = $(da.state.ϵ)"

if isnan(ϵ) || isinf(ϵ) || ϵ <= 1e-3
if isnan(ϵ) || isinf(ϵ)
@warn "Incorrect ϵ = ; ϵ_previous = $(da.state.ϵ) is used instead."
else
da.state.ϵ = ϵ
da.state.x_bar, da.state.H_bar = x_bar, H_bar
end
da.state.x_bar = x_bar
da.state.H_bar = H_bar
end

function adapt!(da::DualAveraging, stats::Real, is_updateϵ::Bool)
function adapt!(da::DualAveraging, stats::Real, is_updateμ::Bool)
adapt_stepsize!(da, stats)
if is_updateϵ
ϵ = exp(da.state.x_bar)
da.state.ϵ = ϵ
da.state.μ = computeμ(ϵ)
if is_updateμ
da.state.μ = computeμ(da.state.ϵ)
reset!(da.state)
end
end
2 changes: 1 addition & 1 deletion src/samplers/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
end

if spl.alg isa AdaptiveHamiltonian
adapt!(spl.info[:wum], α, vi[spl], adapt_M=true, adapt_ϵ=true)
adapt!(spl.info[:wum], α, vi[spl], adapt_M=false, adapt_ϵ=true)
end

@debug "R -> X..."
Expand Down
11 changes: 11 additions & 0 deletions test/models.jl/distributions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Test
using Turing: BinomialLogit
using Distributions: Binomial, logpdf
using StatsFuns: logistic

n = 10
logitp = randn()
d1 = BinomialLogit(n, logitp)
d2 = Binomial(n, logistic(logitp))
k = 3
@test logpdf(d1, k) logpdf(d2, k)
1 change: 1 addition & 0 deletions test/models.jl/skip_tests
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests to skip

0 comments on commit 4dcea02

Please sign in to comment.