Skip to content

Commit

Permalink
reenable PriorContext for Optimization
Browse files Browse the repository at this point in the history
Co-authored-by: Markus Hauru <[email protected]>
  • Loading branch information
alyst and mhauru committed May 30, 2024
1 parent 39f5d5b commit e7c7246
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
14 changes: 10 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext
context::C

function OptimizationContext{C}(context::C) where {C<:AbstractContext}
if !(context isa Union{DefaultContext,LikelihoodContext})
throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` (given: `$(typeof(context)))`"))
if !(context isa Union{DefaultContext,LikelihoodContext,PriorContext})
throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`"))
end
return new{C}(context)
end
Expand All @@ -55,7 +55,7 @@ DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
# assume
function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
r = vi[vn, dist]
lp = if ctx.context isa DefaultContext
lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext)
# MAP
Distributions.logpdf(dist, r)
else
Expand All @@ -73,7 +73,7 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
# affect anything.
# TODO: Stop using `get_and_set_val!`.
r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
lp = if ctx.context isa DefaultContext
lp = if (ctx.context isa DefaultContext) || (ctx.context isa PriorContext)
# MAP
_loglikelihood(right, r)
else
Expand All @@ -83,6 +83,12 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns,
return r, lp, vi
end

DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, args...) =
DynamicPPL.tilde_observe(ctx.context, args...)

DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, args...) =
DynamicPPL.dot_tilde_observe(ctx.context, args...)

"""
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
Expand Down
28 changes: 22 additions & 6 deletions test/optimisation/OptimInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,18 @@ end
mu = x*beta
y ~ MvNormal(mu, I)
end

Random.seed!(987)
true_beta = [1.0, -2.2]
x = rand(40, 2)
y = x*true_beta

model = regtest(x, y)
mle = optimize(model, MLE())

vcmat = inv(x'x)
vcmat_mle = informationmatrix(mle).array

@test isapprox(mle.values.array, true_beta)
@test isapprox(vcmat, vcmat_mle)
end
Expand All @@ -103,10 +103,10 @@ end
@model function dot_gdemo(x)
s ~ InverseGamma(2,3)
m ~ Normal(0, sqrt(s))

(.~)(x, Normal(m, sqrt(s)))
end

model_dot = dot_gdemo([1.5, 2.0])

mle1 = optimize(gdemo_default, MLE())
Expand Down Expand Up @@ -189,6 +189,22 @@ end
x = 1.0
w = [1.0]

@testset "Default, Likelihood, Prior Contexts" begin
m1 = model1(x)
defctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
llhctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
prictx = Turing.OptimizationContext(DynamicPPL.PriorContext())
a = [0.3]

@test Turing.OptimLogDensity(m1, defctx)(a) ==
Turing.OptimLogDensity(m1, llhctx)(a) +
Turing.OptimLogDensity(m1, prictx)(a)

# test that PriorContext is calculating the right thing
@test Turing.OptimLogDensity(m1, prictx)([0.3]) -Distributions.logpdf(Uniform(0, 2), 0.3)
@test Turing.OptimLogDensity(m1, prictx)([-0.3]) -Distributions.logpdf(Uniform(0, 2), -0.3)
end

@testset "With ConditionContext" begin
m1 = model1(x)
m2 = model2() | (x = x,)
Expand Down

0 comments on commit e7c7246

Please sign in to comment.