From d46825be2f9e93c234a8dfd319933f0630762e50 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Tue, 4 Jun 2024 22:22:57 -0700 Subject: [PATCH] reenable PriorContext for Optimization Co-authored-by: Markus Hauru --- src/optimisation/Optimisation.jl | 18 ++++++++++++------ test/optimisation/Optimisation.jl | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index faa8a38f3..7a700f241 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -42,11 +42,11 @@ struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.Abstract context::C function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext} - if !(context isa Union{DynamicPPL.DefaultContext,DynamicPPL.LikelihoodContext}) + if !(context isa Union{DynamicPPL.DefaultContext,DynamicPPL.LikelihoodContext,DynamicPPL.PriorContext}) msg = """ - `OptimizationContext` supports only leaf contexts of type - `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` - (given: `$(typeof(context)))` + `OptimizationContext` supports only leaf contexts of type + `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, + and `DynamicPPL.PriorContext` (given: `$(typeof(context)))` """ throw(ArgumentError(msg)) end @@ -60,7 +60,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) r = vi[vn, dist] - lp = if ctx.context isa DynamicPPL.DefaultContext + lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} # MAP Distributions.logpdf(dist, r) else @@ -83,7 +83,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, r = DynamicPPL.get_and_set_val!( Random.default_rng(), vi, vns, right, DynamicPPL.SampleFromPrior() ) - lp = if ctx.context isa DynamicPPL.DefaultContext + lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} # MAP _loglikelihood(right, r) else @@ -93,6 +93,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, return r, lp, vi end +DynamicPPL.tilde_observe(ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args...) = + DynamicPPL.tilde_observe(ctx.context, args...) + +DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args...) = + DynamicPPL.dot_tilde_observe(ctx.context, args...) + """ OptimLogDensity{M<:DynamicPPL.Model,C<:Context,V<:DynamicPPL.VarInfo} diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 5f38b6ff1..da948590d 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -77,6 +77,22 @@ @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == Turing.Optimisation.OptimLogDensity(m2, ctx)(w) end + + @testset "Default, Likelihood, Prior Contexts" begin + m1 = model1(x) + defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) + llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) + prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext()) + a = [0.3] + + @test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) == + Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) + + Turing.Optimisation.OptimLogDensity(m1, prictx)(a) + + # test that PriorContext is calculating the right thing + @test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) + @test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) + end end @numerical_testset "gdemo" begin