From 28a7c22ec70b1a07a36baa343b722012c2d59396 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Thu, 17 Dec 2020 18:37:24 +0100
Subject: [PATCH] Update EllipticalSliceSampling (#1492)

---
 Project.toml         |  4 +-
 src/inference/ess.jl | 97 +++++++++++++++++++++++++-------------------
 2 files changed, 58 insertions(+), 43 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9d8dea064..c8f9a794b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.15.4"
+version = "0.15.5"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -44,7 +44,7 @@ Distributions = "0.23.3, 0.24"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8"
 DynamicPPL = "0.10.2"
-EllipticalSliceSampling = "0.3"
+EllipticalSliceSampling = "0.4"
 ForwardDiff = "0.10.3"
 Libtask = "0.4, 0.5"
 LogDensityProblems = "^0.9, 0.10"
diff --git a/src/inference/ess.jl b/src/inference/ess.jl
index fbe15bb07..d3c7021cf 100644
--- a/src/inference/ess.jl
+++ b/src/inference/ess.jl
@@ -64,67 +64,82 @@ function AbstractMCMC.step(
     end
 
     # define previous sampler state
-    oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi))
+    # (do not use cache to avoid in-place sampling from prior)
+    oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)
 
     # compute next state
-    _, state = AbstractMCMC.step(rng, ESSModel(model, spl, vi),
-                                 EllipticalSliceSampling.ESS(), oldstate)
+    sample, state = AbstractMCMC.step(
+        rng,
+        EllipticalSliceSampling.ESSModel(
+            ESSPrior(model, spl, vi), ESSLogLikelihood(model, spl, vi),
+        ),
+        EllipticalSliceSampling.ESS(),
+        oldstate,
+    )
 
     # update sample and log-likelihood
-    vi[spl] = state.sample
+    vi[spl] = sample
     setlogp!(vi, state.loglikelihood)
 
     return Transition(vi), vi
 end
 
-struct ESSModel{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} <: AbstractMCMC.AbstractModel
+# Prior distribution of considered random variable
+struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
     model::M
-    spl::S
-    vi::V
+    sampler::S
+    varinfo::V
     μ::T
-end
-
-function ESSModel(model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo)
-    vns = _getvns(vi, spl)
-    μ = mapreduce(vcat, vns[1]) do vn
-        dist = getdist(vi, vn)
-        vectorize(dist, mean(dist))
+
+    function ESSPrior{M,S,V}(model::M, sampler::S, varinfo::V) where {
+        M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo
+    }
+        vns = _getvns(varinfo, sampler)
+        μ = mapreduce(vcat, vns[1]) do vn
+            dist = getdist(varinfo, vn)
+            EllipticalSliceSampling.isgaussian(typeof(dist)) ||
+                error("[ESS] only supports Gaussian prior distributions")
+            vectorize(dist, mean(dist))
+        end
+        return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
     end
-
-    ESSModel(model, spl, vi, μ)
 end
 
-# sample from the prior
-function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel)
-    spl = model.spl
-    vi = model.vi
-    vns = _getvns(vi, spl)
-    set_flag!(vi, vns[1][1], "del")
-    model.model(rng, vi, spl)
-    return vi[spl]
+function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
+    return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(
+        model, sampler, varinfo,
+    )
 end
 
-# compute proposal and apply correction for distributions with nonzero mean
-function EllipticalSliceSampling.proposal(model::ESSModel, f, ν, θ)
-    sinθ, cosθ = sincos(θ)
-    a = 1 - (sinθ + cosθ)
-    return @. cosθ * f + sinθ * ν + a * model.μ
+# Ensure that the prior is a Gaussian distribution (checked in the constructor)
+EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
+
+# Only define out-of-place sampling
+function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
+    sampler = p.sampler
+    varinfo = p.varinfo
+    vns = _getvns(varinfo, sampler)
+    set_flag!(varinfo, vns[1][1], "del")
+    p.model(rng, varinfo, sampler)
+    return varinfo[sampler]
 end
 
-function EllipticalSliceSampling.proposal!(out, model::ESSModel, f, ν, θ)
-    sinθ, cosθ = sincos(θ)
-    a = 1 - (sinθ + cosθ)
-    @. out = cosθ * f + sinθ * ν + a * model.μ
-    return out
+# Mean of prior distribution
+Distributions.mean(p::ESSPrior) = p.μ
+
+# Evaluate log-likelihood of proposals
+struct ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
+    model::M
+    sampler::S
+    varinfo::V
 end
 
-# evaluate log-likelihood
-function Distributions.loglikelihood(model::ESSModel, f)
-    spl = model.spl
-    vi = model.vi
-    vi[spl] = f
-    model.model(vi, spl)
-    getlogp(vi)
+function (ℓ::ESSLogLikelihood)(f)
+    sampler = ℓ.sampler
+    varinfo = ℓ.varinfo
+    varinfo[sampler] = f
+    ℓ.model(varinfo, sampler)
+    return getlogp(varinfo)
 end
 
 function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
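A minimal usage sketch for context (illustrative only, not part of the patch; the model `demo` and its data are hypothetical, while `@model`, `ESS`, and `sample` are Turing's exported interface, unchanged by this update):

    using Turing

    # Toy model with a Gaussian prior on `m`, as ESS requires
    # (non-Gaussian priors now error in the ESSPrior constructor).
    @model function demo(x)
        m ~ Normal(0, 1)
        x ~ Normal(m, 1)
    end

    # Draw samples with elliptical slice sampling.
    chain = sample(demo(1.5), ESS(), 1_000)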