Skip to content

Commit

Permalink
Update EllipticalSliceSampling (#1492)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Dec 17, 2020
1 parent 6d1562a commit 28a7c22
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 43 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.15.4"
version = "0.15.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -44,7 +44,7 @@ Distributions = "0.23.3, 0.24"
DistributionsAD = "0.6"
DocStringExtensions = "0.8"
DynamicPPL = "0.10.2"
EllipticalSliceSampling = "0.3"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5"
LogDensityProblems = "^0.9, 0.10"
Expand Down
97 changes: 56 additions & 41 deletions src/inference/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,67 +64,82 @@ function AbstractMCMC.step(
end

# define previous sampler state
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi))
# (do not use cache to avoid in-place sampling from prior)
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)

# compute next state
_, state = AbstractMCMC.step(rng, ESSModel(model, spl, vi),
EllipticalSliceSampling.ESS(), oldstate)
sample, state = AbstractMCMC.step(
rng,
EllipticalSliceSampling.ESSModel(
ESSPrior(model, spl, vi), ESSLogLikelihood(model, spl, vi),
),
EllipticalSliceSampling.ESS(),
oldstate,
)

# update sample and log-likelihood
vi[spl] = state.sample
vi[spl] = sample
setlogp!(vi, state.loglikelihood)

return Transition(vi), vi
end

struct ESSModel{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} <: AbstractMCMC.AbstractModel
# Prior distribution of considered random variable
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
model::M
spl::S
vi::V
sampler::S
varinfo::V
μ::T
end

function ESSModel(model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo)
vns = _getvns(vi, spl)
μ = mapreduce(vcat, vns[1]) do vn
dist = getdist(vi, vn)
vectorize(dist, mean(dist))

function ESSPrior{M,S,V}(model::M, sampler::S, varinfo::V) where {
M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo
}
vns = _getvns(varinfo, sampler)
μ = mapreduce(vcat, vns[1]) do vn
dist = getdist(varinfo, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("[ESS] only supports Gaussian prior distributions")
vectorize(dist, mean(dist))
end
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
end

ESSModel(model, spl, vi, μ)
end

# sample from the prior
function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel)
spl = model.spl
vi = model.vi
vns = _getvns(vi, spl)
set_flag!(vi, vns[1][1], "del")
model.model(rng, vi, spl)
return vi[spl]
function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(
model, sampler, varinfo,
)
end

# compute proposal and apply correction for distributions with nonzero mean
function EllipticalSliceSampling.proposal(model::ESSModel, f, ν, θ)
sinθ, cosθ = sincos(θ)
a = 1 - (sinθ + cosθ)
return @. cosθ * f + sinθ * ν + a * model.μ
# Ensure that the prior is a Gaussian distribution (checked in the constructor)
EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true

# Only define out-of-place sampling
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
vns = _getvns(varinfo, sampler)
set_flag!(varinfo, vns[1][1], "del")
p.model(rng, varinfo, sampler)
return varinfo[sampler]
end

function EllipticalSliceSampling.proposal!(out, model::ESSModel, f, ν, θ)
sinθ, cosθ = sincos(θ)
a = 1 - (sinθ + cosθ)
@. out = cosθ * f + sinθ * ν + a * model.μ
return out
# Mean of prior distribution
Distributions.mean(p::ESSPrior) = p.μ

# Evaluate log-likelihood of proposals
struct ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
model::M
sampler::S
varinfo::V
end

# evaluate log-likelihood
function Distributions.loglikelihood(model::ESSModel, f)
spl = model.spl
vi = model.vi
vi[spl] = f
model.model(vi, spl)
getlogp(vi)
function (ℓ::ESSLogLikelihood)(f)
sampler =.sampler
varinfo =.varinfo
varinfo[sampler] = f
.model(varinfo, sampler)
return getlogp(varinfo)
end

function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
Expand Down

2 comments on commit 28a7c22

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/26568

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.15.5 -m "<description of version>" 28a7c22ec70b1a07a36baa343b722012c2d59396
git push origin v0.15.5

Please sign in to comment.