
Unify log density function types (#1846)
* Unify log density function types

* Some fixes

* More fixes

* Some more fixes

* Another fix

* Apply suggestions from code review

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update src/contrib/inference/dynamichmc.jl

* Update OptimInterface.jl

* Fix implementation of Optim interface

* Update ModeEstimation.jl

* Fix tests

* Update mh.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>
devmotion and torfjelde authored Jun 29, 2022
1 parent bab91b3 commit 9f482f3
Showing 9 changed files with 88 additions and 138 deletions.
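
For orientation before the per-file diffs: this commit replaces several bespoke log-density structs (`DynamicHMCLogDensity`, `ESSLogLikelihood`, `MHLogDensityFunction`, `OptimLogDensity`, and the `gen_logπ` closure) with aliases of a single `Turing.LogDensityFunction` type. That type's definition is not part of this diff; judging from the constructor calls and field accesses below, it has roughly the following shape (a sketch only — the type constraints are inferred, not copied from the source):

# Sketch inferred from call sites in this diff; not the actual definition in Turing.jl.
struct LogDensityFunction{V<:AbstractVarInfo,M<:Model,S<:AbstractSampler,C<:AbstractContext}
    varinfo::V   # parameter container, accessed below as ℓ.varinfo
    model::M     # the DynamicPPL model, accessed as ℓ.model
    sampler::S   # accessed as ℓ.sampler
    context::C   # evaluation context, accessed as ℓ.context
end

# Constructed throughout the diff as:
# Turing.LogDensityFunction(varinfo, model, sampler, context)
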
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.21.6"
version = "0.21.7"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -37,7 +37,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractMCMC = "4"
AdvancedHMC = "0.3.0"
AdvancedMH = "0.6"
AdvancedMH = "0.6.8"
AdvancedPS = "0.3.4"
AdvancedVI = "0.1"
BangBang = "0.3"
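
(Under Pkg's default caret semantics, the compat entry `AdvancedMH = "0.6.8"` allows any version in `[0.6.8, 0.7.0)`; the bump raises only the lower bound, presumably to pick up an AdvancedMH release this commit depends on.)
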
16 changes: 7 additions & 9 deletions src/contrib/inference/dynamichmc.jl
@@ -19,11 +19,9 @@ DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()

DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

-struct DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo}
-    model::M
-    sampler::S
-    varinfo::V
-end
+# Only define traits for `DynamicNUTS` sampler to avoid type piracy and surprises
+# TODO: Implement generally with `LogDensityProblems`
+const DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}

function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity)
    return length(ℓ.varinfo[ℓ.sampler])
@@ -37,7 +35,7 @@ function DynamicHMC.logdensity_and_gradient(
    ℓ::DynamicHMCLogDensity,
    x::AbstractVector,
)
-    return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler)
+    return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler, ℓ.context)
end

"""
@@ -64,7 +62,7 @@ function gibbs_state(
    varinfo::AbstractVarInfo,
)
    # Update the previous evaluation.
-    ℓ = DynamicHMCLogDensity(model, spl, varinfo)
+    ℓ = Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext())
    Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
    return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize)
end
@@ -87,7 +85,7 @@ function DynamicPPL.initialstep(
    # Perform initial step.
    results = DynamicHMC.mcmc_keep_warmup(
        rng,
-        DynamicHMCLogDensity(model, spl, vi),
+        Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
        0;
        initialization = (q = vi[spl],),
        reporter = DynamicHMC.NoProgressReport(),
@@ -115,7 +113,7 @@ function AbstractMCMC.step(
)
    # Compute next sample.
    vi = state.vi
-    ℓ = DynamicHMCLogDensity(model, spl, vi)
+    ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
    steps = DynamicHMC.mcmc_steps(
        rng,
        DynamicHMC.NUTS(),
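
The `const DynamicHMCLogDensity{...} = Turing.LogDensityFunction{...}` line above uses Julia's parametric type aliases: methods defined against the alias dispatch on the underlying type with some parameters constrained, so no new struct is needed and (because `Sampler{<:DynamicNUTS}` is owned by Turing) no type piracy occurs. A toy illustration of the pattern, with invented names:

# Toy example of a parametric type alias — not code from this commit.
const IntPair{T<:Integer} = Tuple{T,T}

# Dispatches on Tuple{T,T} where T<:Integer, i.e. on the aliased type.
halves(p::IntPair) = (p[1] ÷ 2, p[2] ÷ 2)

halves((4, 10))    # (2, 5)
# halves((1.0, 2.0)) would throw a MethodError: Float64 is not <: Integer
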
2 changes: 1 addition & 1 deletion src/inference/emcee.jl
@@ -74,7 +74,7 @@ function AbstractMCMC.step(
)
    # Generate a log joint function.
    vi = state.vi
-    densitymodel = AMH.DensityModel(gen_logπ(vi, SampleFromPrior(), model))
+    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()))

    # Compute the next states.
    states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
10 changes: 3 additions & 7 deletions src/inference/ess.jl
@@ -64,7 +64,7 @@ function AbstractMCMC.step(
    sample, state = AbstractMCMC.step(
        rng,
        EllipticalSliceSampling.ESSModel(
-            ESSPrior(model, spl, vi), ESSLogLikelihood(model, spl, vi),
+            ESSPrior(model, spl, vi), Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
        ),
        EllipticalSliceSampling.ESS(),
        oldstate,
@@ -124,13 +124,9 @@ end
Distributions.mean(p::ESSPrior) = p.μ

# Evaluate log-likelihood of proposals
-struct ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
-    model::M
-    sampler::S
-    varinfo::V
-end
+const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}

-function (ℓ::ESSLogLikelihood)(f)
+function (ℓ::ESSLogLikelihood)(f::AbstractVector)
    sampler = ℓ.sampler
    varinfo = setindex!!(ℓ.varinfo, f, sampler)
    varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))
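
Note the tightened argument types on the functor methods in this file and in mh.jl below (`::AbstractVector` here, `::NamedTuple` for MH). Since the sampler-specific aliases are now all instances of the same underlying `Turing.LogDensityFunction`, annotating the argument presumably keeps the vector-based and NamedTuple-based call paths unambiguous.
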
28 changes: 3 additions & 25 deletions src/inference/hmc.jl
@@ -160,7 +160,7 @@ function DynamicPPL.initialstep(
    metricT = getmetricT(spl.alg)
    metric = metricT(length(theta))
    ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
-    logπ = gen_logπ(vi, spl, model)
+    logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
    hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)

    # Compute phase point z.
@@ -262,7 +262,7 @@ end

function get_hamiltonian(model, spl, vi, state, n)
    metric = gen_metric(n, spl, state)
-    ℓπ = gen_logπ(vi, spl, model)
+    ℓπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
    ∂ℓπ∂θ = gen_∂logπ∂θ(vi, spl, model)
    return AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
end
@@ -435,28 +435,6 @@ function gen_∂logπ∂θ(vi, spl::Sampler, model)
    return ∂logπ∂θ
end

"""
gen_logπ(vi, spl::Sampler, model)
Generate a function that takes `θ` and returns logpdf at `θ` for the model specified by
`(vi, spl, model)`.
"""
function gen_logπ(vi_base, spl::AbstractSampler, model)
function logπ(x)::Float64
vi = vi_base
x_old, lj_old = vi[spl], getlogp(vi)
vi = setindex!!(vi, x, spl)
vi = last(DynamicPPL.evaluate!!(model, vi, spl))
lj = getlogp(vi)
# Don't really need to capture these will only be
# necessary if `vi` is indeed mutable.
setindex!!(vi, x_old, spl)
setlogp!!(vi, lj_old)
return lj
end
return logπ
end

gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim)
function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
    return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
@@ -567,7 +545,7 @@ function HMCState(

    # Get the initial log pdf and gradient functions.
    ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
-    logπ = gen_logπ(vi, spl, model)
+    logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())

    # Get the metric type.
    metricT = getmetricT(spl.alg)
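
Handing a `Turing.LogDensityFunction` straight to `AHMC.Hamiltonian` (and to `AMH.DensityModel` in emcee.jl and mh.jl) implies the type is callable on a parameter vector, replacing the deleted `gen_logπ` closure. That method is not shown in this diff; mirroring the ESS functor above, it plausibly looks something like the following (a hedged sketch — the real method lives elsewhere in Turing.jl and may differ in how it updates the varinfo):

# Hedged sketch of the generic functor; not verbatim from Turing.jl.
function (f::Turing.LogDensityFunction)(x::AbstractVector)
    varinfo = setindex!!(f.varinfo, x, f.sampler)
    varinfo = last(DynamicPPL.evaluate!!(f.model, varinfo, f.sampler, f.context))
    return getlogp(varinfo)
end
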
16 changes: 6 additions & 10 deletions src/inference/mh.jl
@@ -242,19 +242,15 @@ A log density function for the MH sampler.
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
"""
-struct MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} <: Function # Relax AMH.DensityModel?
-    model::M
-    sampler::S
-    vi::V
-end
+const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}

-function (f::MHLogDensityFunction)(x)
+function (f::MHLogDensityFunction)(x::NamedTuple)
    sampler = f.sampler
-    vi = f.vi
+    vi = f.varinfo

    x_old, lj_old = vi[sampler], getlogp(vi)
    set_namedtuple!(vi, x)
-    vi_new = last(DynamicPPL.evaluate!!(f.model, vi, DynamicPPL.DefaultContext()))
+    vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context))
    lj = getlogp(vi_new)

    # Reset old `vi`.
@@ -376,7 +372,7 @@ function propose!(
    prev_trans = AMH.Transition(vt, getlogp(vi))

    # Make a new transition.
-    densitymodel = AMH.DensityModel(MHLogDensityFunction(model, spl, vi))
+    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
    trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

    # TODO: Make this compatible with immutable `VarInfo`.
@@ -404,7 +400,7 @@ function propose!(
    prev_trans = AMH.Transition(vals, getlogp(vi))

    # Make a new transition.
-    densitymodel = AMH.DensityModel(gen_logπ(vi, spl, model))
+    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
    trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

    # TODO: Make this compatible with immutable `VarInfo`.
92 changes: 37 additions & 55 deletions src/modes/ModeEstimation.jl
@@ -78,74 +78,57 @@ end
"""
    OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}

-A struct that stores the log density function of a `DynamicPPL` model.
+A struct that stores the negative log density function of a `DynamicPPL` model.
"""
-struct OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}
-    "A `DynamicPPL.Model` constructed either with the `@model` macro or manually."
-    model::M
-    "A `DynamicPPL.AbstractContext` used to evaluate the model. `LikelihoodContext` or `DefaultContext` are typical for MAP/MLE."
-    context::C
-    "A `DynamicPPL.VarInfo` struct that will be used to update model parameters."
-    vi::V
-end
+const OptimLogDensity{M<:Model,C<:OptimizationContext,V<:VarInfo} = Turing.LogDensityFunction{V,M,DynamicPPL.SampleFromPrior,C}

"""
-    OptimLogDensity(model::Model, context::AbstractContext)
+    OptimLogDensity(model::Model, context::OptimizationContext)

Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
"""
-function OptimLogDensity(model::Model, context::AbstractContext)
+function OptimLogDensity(model::Model, context::OptimizationContext)
    init = VarInfo(model)
-    return OptimLogDensity(model, context, init)
+    return Turing.LogDensityFunction(init, model, DynamicPPL.SampleFromPrior(), context)
end

"""
    (f::OptimLogDensity)(z)

-Evaluate the log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
+Evaluate the negative log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
at the array `z`.
"""
-function (f::OptimLogDensity)(z)
-    spl = DynamicPPL.SampleFromPrior()
-
-    varinfo = DynamicPPL.VarInfo(f.vi, spl, z)
-    f.model(varinfo, spl, f.context)
-    return -DynamicPPL.getlogp(varinfo)
+function (f::OptimLogDensity)(z::AbstractVector)
+    sampler = f.sampler
+    varinfo = DynamicPPL.VarInfo(f.varinfo, sampler, z)
+    return -getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, sampler, f.context)))
end

-function (f::OptimLogDensity)(F, G, H, z)
-    # Throw an error if a second order method was used.
-    if H !== nothing
-        error("Second order optimization is not yet supported.")
-    end
-
-    spl = DynamicPPL.SampleFromPrior()
-
+function (f::OptimLogDensity)(F, G, z)
    if G !== nothing
-        # Calculate log joint and the gradient
-        l, g = Turing.gradient_logp(
+        # Calculate negative log joint and its gradient.
+        sampler = f.sampler
+        neglogp, ∇neglogp = Turing.gradient_logp(
            z,
-            DynamicPPL.VarInfo(f.vi, spl, z),
+            DynamicPPL.VarInfo(f.varinfo, sampler, z),
            f.model,
-            spl,
-            f.context
+            sampler,
+            f.context,
        )

-        # Use the negative gradient because we are minimizing.
-        G[:] = -g
+        # Save the gradient to the pre-allocated array.
+        copyto!(G, ∇neglogp)

-        # If F is something, return that since we already have the
-        # log joint.
+        # If F is something, the negative log joint is requested as well.
+        # We have already computed it as a by-product above and hence return it directly.
        if F !== nothing
-            F = -l
-            return F
+            return neglogp
        end
    end

-    # No gradient necessary, just return the log joint.
+    # Only negative log joint requested but no gradient.
    if F !== nothing
-        F = f(z)
-        return F
+        return f(z)
    end

    return nothing
@@ -158,16 +141,16 @@
#################################################

function transform!(f::OptimLogDensity)
-    spl = DynamicPPL.SampleFromPrior()
+    spl = f.sampler

    ## Check link status of vi in OptimLogDensity
-    linked = DynamicPPL.islinked(f.vi, spl)
+    linked = DynamicPPL.islinked(f.varinfo, spl)

    ## transform into constrained or unconstrained space depending on current state of vi
    if !linked
-        DynamicPPL.link!(f.vi, spl)
+        DynamicPPL.link!(f.varinfo, spl)
    else
-        DynamicPPL.invlink!(f.vi, spl)
+        DynamicPPL.invlink!(f.varinfo, spl)
    end

    return nothing
@@ -249,8 +232,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MAP, ::constrained_space{false})
    obj = OptimLogDensity(model, ctx)

    transform!(obj)
-    init = Init(obj.vi, constrained_space{false}())
-    t = ParameterTransform(obj.vi, constrained_space{true}())
+    init = Init(obj.varinfo, constrained_space{false}())
+    t = ParameterTransform(obj.varinfo, constrained_space{true}())

    return (obj=obj, init = init, transform=t)
end
@@ -259,8 +242,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MAP, ::constrained_space{true})
    ctx = OptimizationContext(DynamicPPL.DefaultContext())
    obj = OptimLogDensity(model, ctx)

-    init = Init(obj.vi, constrained_space{true}())
-    t = ParameterTransform(obj.vi, constrained_space{true}())
+    init = Init(obj.varinfo, constrained_space{true}())
+    t = ParameterTransform(obj.varinfo, constrained_space{true}())

    return (obj=obj, init = init, transform=t)
end
@@ -270,8 +253,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MLE, ::constrained_space{false})
    obj = OptimLogDensity(model, ctx)

    transform!(obj)
-    init = Init(obj.vi, constrained_space{false}())
-    t = ParameterTransform(obj.vi, constrained_space{true}())
+    init = Init(obj.varinfo, constrained_space{false}())
+    t = ParameterTransform(obj.varinfo, constrained_space{true}())

    return (obj=obj, init = init, transform=t)
end
@@ -280,8 +263,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MLE, ::constrained_space{true})
    ctx = OptimizationContext(DynamicPPL.LikelihoodContext())
    obj = OptimLogDensity(model, ctx)

-    init = Init(obj.vi, constrained_space{true}())
-    t = ParameterTransform(obj.vi, constrained_space{true}())
+    init = Init(obj.varinfo, constrained_space{true}())
+    t = ParameterTransform(obj.varinfo, constrained_space{true}())

    return (obj=obj, init = init, transform=t)
end
@@ -309,8 +292,7 @@ function optim_function(
    else
        OptimizationFunction(
            l;
-            grad = (G,x,p) -> obj(nothing, G, nothing, x),
-            hess = (H,x,p) -> obj(nothing, nothing, H, x),
+            grad = (G,x,p) -> obj(nothing, G, x),
        )
    end

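
The reworked `(f::OptimLogDensity)(F, G, z)` method above follows the `fg!` convention of NLSolversBase/Optim.jl: write the gradient into `G` in place when `G` is provided, and return the objective value when `F` is requested, sharing one evaluation between the two. A hypothetical direct use with Optim.jl (`demo_model` is a placeholder for any `DynamicPPL.Model`, and the link transform that MAP/MLE estimation normally applies is skipped here):

# Hypothetical usage sketch; demo_model and the sampler choice are illustrative.
using Optim

ctx = OptimizationContext(DynamicPPL.DefaultContext())
obj = OptimLogDensity(demo_model, ctx)                 # negative log joint
x0 = obj.varinfo[DynamicPPL.SampleFromPrior()]         # current parameter vector

# only_fg! tells Optim that objective and gradient come from one callback.
result = optimize(Optim.only_fg!(obj), x0, LBFGS())
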

2 comments on commit 9f482f3

@devmotion (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/63322

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.21.7 -m "<description of version>" 9f482f3c65c52485c16f511b56f622198aee3d2d
git push origin v0.21.7
