Skip to content

Commit

Permalink
reenable PriorContext for Optimization
Browse files Browse the repository at this point in the history
Co-authored-by: Markus Hauru <[email protected]>
  • Loading branch information
alyst and mhauru committed Jun 5, 2024
1 parent 87a040e commit d46825b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.Abstract
context::C

function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext}
if !(context isa Union{DynamicPPL.DefaultContext,DynamicPPL.LikelihoodContext})
if !(context isa Union{DynamicPPL.DefaultContext,DynamicPPL.LikelihoodContext,DynamicPPL.PriorContext})
msg = """
`OptimizationContext` supports only leaf contexts of type
`DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext`
(given: `$(typeof(context)))`
`OptimizationContext` supports only leaf contexts of type
`DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`,
and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`
"""
throw(ArgumentError(msg))
end
Expand All @@ -60,7 +60,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()

function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
r = vi[vn, dist]
lp = if ctx.context isa DynamicPPL.DefaultContext
lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext}
# MAP
Distributions.logpdf(dist, r)
else
Expand All @@ -83,7 +83,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
r = DynamicPPL.get_and_set_val!(
Random.default_rng(), vi, vns, right, DynamicPPL.SampleFromPrior()
)
lp = if ctx.context isa DynamicPPL.DefaultContext
lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext}
# MAP
_loglikelihood(right, r)
else
Expand All @@ -93,6 +93,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
return r, lp, vi
end

DynamicPPL.tilde_observe(ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args...) =
DynamicPPL.tilde_observe(ctx.context, args...)

DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args...) =
DynamicPPL.dot_tilde_observe(ctx.context, args...)

"""
OptimLogDensity{M<:DynamicPPL.Model,C<:Context,V<:DynamicPPL.VarInfo}
Expand Down
16 changes: 16 additions & 0 deletions test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@
@test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) ==
Turing.Optimisation.OptimLogDensity(m2, ctx)(w)
end

@testset "Default, Likelihood, Prior Contexts" begin
m1 = model1(x)
defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext())
a = [0.3]

@test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) ==
Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) +
Turing.Optimisation.OptimLogDensity(m1, prictx)(a)

# test that PriorContext is calculating the right thing
@test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) -Distributions.logpdf(Uniform(0, 2), 0.3)
@test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) -Distributions.logpdf(Uniform(0, 2), -0.3)
end
end

@numerical_testset "gdemo" begin
Expand Down

0 comments on commit d46825b

Please sign in to comment.