diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl
index 9f1937ae4a..ffa94a3768 100644
--- a/src/optimisation/Optimisation.jl
+++ b/src/optimisation/Optimisation.jl
@@ -41,8 +41,8 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext
     context::C
 
     function OptimizationContext{C}(context::C) where {C<:AbstractContext}
-        if !(context isa Union{DefaultContext,LikelihoodContext})
-            throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` (given: `$(typeof(context)))`"))
+        if !(context isa Union{DefaultContext,LikelihoodContext,PriorContext})
+            throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`"))
         end
         return new{C}(context)
     end
@@ -55,7 +55,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
 # assume
 function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
     r = vi[vn, dist]
-    lp = if ctx.context isa DefaultContext
+    lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext)
         # MAP
         Distributions.logpdf(dist, r)
     else
@@ -73,7 +73,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
     # affect anything.
     # TODO: Stop using `get_and_set_val!`.
     r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
-    lp = if ctx.context isa DefaultContext
+    lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext)
         # MAP
         _loglikelihood(right, r)
     else
@@ -83,6 +83,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
     return r, lp, vi
 end
 
+DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, args...) =
+    DynamicPPL.tilde_observe(ctx.context, args...)
+
+DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, args...) =
+    DynamicPPL.dot_tilde_observe(ctx.context, args...)
+
 """
     OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
 
diff --git a/test/optimisation/OptimInterface.jl b/test/optimisation/OptimInterface.jl
index 919de9702c..611d49a068 100644
--- a/test/optimisation/OptimInterface.jl
+++ b/test/optimisation/OptimInterface.jl
@@ -83,18 +83,18 @@ end
         mu = x*beta
         y ~ MvNormal(mu, I)
     end
-    
+
     Random.seed!(987)
     true_beta = [1.0, -2.2]
     x = rand(40, 2)
     y = x*true_beta
-    
+
     model = regtest(x, y)
     mle = optimize(model, MLE())
-    
+
     vcmat = inv(x'x)
     vcmat_mle = informationmatrix(mle).array
-    
+
     @test isapprox(mle.values.array, true_beta)
     @test isapprox(vcmat, vcmat_mle)
 end
@@ -103,10 +103,10 @@ end
     @model function dot_gdemo(x)
         s ~ InverseGamma(2,3)
         m ~ Normal(0, sqrt(s))
-        
+
         (.~)(x, Normal(m, sqrt(s)))
     end
-    
+
     model_dot = dot_gdemo([1.5, 2.0])
 
     mle1 = optimize(gdemo_default, MLE())
@@ -189,6 +189,22 @@ end
     x = 1.0
     w = [1.0]
 
+    @testset "Default, Likelihood, Prior Contexts" begin
+        m1 = model1(x)
+        defctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
+        llhctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
+        prictx = Turing.OptimizationContext(DynamicPPL.PriorContext())
+        a = [0.3]
+
+        @test Turing.OptimLogDensity(m1, defctx)(a) ==
+            Turing.OptimLogDensity(m1, llhctx)(a) +
+            Turing.OptimLogDensity(m1, prictx)(a)
+
+        # test that PriorContext is calculating the right thing
+        @test Turing.OptimLogDensity(m1, prictx)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3)
+        @test Turing.OptimLogDensity(m1, prictx)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3)
+    end
+
     @testset "With ConditionContext" begin
         m1 = model1(x)
         m2 = model2() | (x = x,)