Skip to content

Commit

Permalink
reenable PriorContext for Optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Stukalov committed Feb 3, 2024
1 parent 39f5d5b commit c9d8a9f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext
context::C

function OptimizationContext{C}(context::C) where {C<:AbstractContext}
if !(context isa Union{DefaultContext,LikelihoodContext})
throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` (given: `$(typeof(context)))`"))
if !(context isa Union{DefaultContext,LikelihoodContext,PriorContext})
throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`"))
end
return new{C}(context)
end
Expand All @@ -55,7 +55,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
# assume
function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
r = vi[vn, dist]
lp = if ctx.context isa DefaultContext
lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext)
# MAP
Distributions.logpdf(dist, r)
else
Expand All @@ -73,7 +73,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
# affect anything.
# TODO: Stop using `get_and_set_val!`.
r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
lp = if ctx.context isa DefaultContext
lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext)
# MAP
_loglikelihood(right, r)
else
Expand All @@ -83,6 +83,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
return r, lp, vi
end

DynamicPPL.tilde_observe(::OptimizationContext{<:PriorContext}, right, left, vi) =
(0, vi)

DynamicPPL.dot_tilde_observe(::OptimizationContext{<:PriorContext}, right, left, vi) =
(0, vi)

"""
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
Expand Down
12 changes: 12 additions & 0 deletions test/optimisation/OptimInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ end
x = 1.0
w = [1.0]

@testset "Default, Likelihood, Prior Contexts" begin
m1 = model1(x)
defctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
llhctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
prictx = Turing.OptimizationContext(DynamicPPL.PriorContext())
a = [0.3]

@test Turing.OptimLogDensity(m1, defctx)(a) ==
Turing.OptimLogDensity(m1, llhctx)(a) +
Turing.OptimLogDensity(m1, prictx)(a)
end

@testset "With ConditionContext" begin
m1 = model1(x)
m2 = model2() | (x = x,)
Expand Down

0 comments on commit c9d8a9f

Please sign in to comment.