diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 9f1937ae4a..350054bf30 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -41,8 +41,8 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext context::C function OptimizationContext{C}(context::C) where {C<:AbstractContext} - if !(context isa Union{DefaultContext,LikelihoodContext}) - throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` (given: `$(typeof(context)))`")) + if !(context isa Union{DefaultContext,LikelihoodContext,PriorContext}) + throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`")) end return new{C}(context) end @@ -55,7 +55,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() # assume function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) r = vi[vn, dist] - lp = if ctx.context isa DefaultContext + lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext) # MAP Distributions.logpdf(dist, r) else @@ -73,7 +73,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, # affect anything. # TODO: Stop using `get_and_set_val!`. r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior()) - lp = if ctx.context isa DefaultContext + lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext) # MAP _loglikelihood(right, r) else @@ -83,6 +83,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, return r, lp, vi end +DynamicPPL.tilde_observe(::OptimizationContext{<:PriorContext}, right, left, vi) = + (0, vi) + +DynamicPPL.dot_tilde_observe(::OptimizationContext{<:PriorContext}, right, left, vi) = + (0, vi) + """ OptimLogDensity{M<:Model,C<:Context,V<:VarInfo} diff --git a/test/optimisation/OptimInterface.jl b/test/optimisation/OptimInterface.jl index 919de9702c..34e941d9f7 100644 --- a/test/optimisation/OptimInterface.jl +++ b/test/optimisation/OptimInterface.jl @@ -189,6 +189,18 @@ end x = 1.0 w = [1.0] + @testset "Default, Likelihood, Prior Contexts" begin + m1 = model1(x) + defctx = Turing.OptimizationContext(DynamicPPL.DefaultContext()) + llhctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext()) + prictx = Turing.OptimizationContext(DynamicPPL.PriorContext()) + a = [0.3] + + @test Turing.OptimLogDensity(m1, defctx)(a) == + Turing.OptimLogDensity(m1, llhctx)(a) + + Turing.OptimLogDensity(m1, prictx)(a) + end + @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x = x,)