Skip to content

Commit

Permalink
Remove Contexts for the sack of simplicity (#229)
Browse files Browse the repository at this point in the history
The three evaluation functions for `BUGSModel` can be distinguished
without `Context`s.

This PR removes them, as needs arise, we can add them back.
  • Loading branch information
sunxd3 authored Oct 25, 2024
1 parent ce29cc1 commit f9c253b
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 58 deletions.
3 changes: 1 addition & 2 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ using AbstractMCMC
using AdvancedHMC
using AdvancedHMC: Transition, stat
using JuliaBUGS
using JuliaBUGS:
AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.BangBang
using JuliaBUGS.LogDensityProblems
Expand Down
2 changes: 1 addition & 1 deletion ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module JuliaBUGSAdvancedMHExt
using AbstractMCMC
using AdvancedMH
using JuliaBUGS
using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: BUGSModel, find_generated_vars, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
Expand Down
4 changes: 2 additions & 2 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module JuliaBUGSMCMCChainsExt

using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, evaluate!!
using JuliaBUGS.AbstractPPL
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
Expand Down Expand Up @@ -89,7 +89,7 @@ function JuliaBUGS.gen_chains(
param_vals = []
generated_quantities = []
for i in axes(samples)[1]
evaluation_env = first(evaluate!!(model, LogDensityContext(), samples[i]))
evaluation_env = first(evaluate!!(model, samples[i]))
push!(
param_vals,
[AbstractPPL.get(evaluation_env, param_var) for param_var in param_vars],
Expand Down
4 changes: 2 additions & 2 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ function gibbs_internal end

function gibbs_internal(rng::Random.AbstractRNG, cond_model::BUGSModel, ::MHFromPrior)
transformed_original = JuliaBUGS.getparams(cond_model)
values, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)
values_proposed, logp_proposed = evaluate!!(cond_model, SamplingContext())
values, logp = evaluate!!(cond_model, transformed_original)
values_proposed, logp_proposed = evaluate!!(rng, cond_model)

if logp_proposed - logp > log(rand(rng))
values = values_proposed
Expand Down
2 changes: 1 addition & 1 deletion src/logdensityproblems.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function LogDensityProblems.logdensity(model::AbstractBUGSModel, x::AbstractArray)
_, logp = evaluate!!(model, LogDensityContext(), x)
_, logp = evaluate!!(model, x)
return logp
end

Expand Down
59 changes: 16 additions & 43 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ end
Initialize the model with a vector of initial values, the values can be in transformed space if `model.transformed` is set to true.
"""
function initialize!(model::BUGSModel, initial_params::AbstractVector)
evaluation_env, _ = AbstractPPL.evaluate!!(model, LogDensityContext(), initial_params)
evaluation_env, _ = AbstractPPL.evaluate!!(model, initial_params)
return BangBang.setproperty!!(model, :evaluation_env, evaluation_env)
end

Expand Down Expand Up @@ -260,18 +260,23 @@ function getparams(model::BUGSModel)
return param_vals
end

function getparams_as_ordereddict(model::BUGSModel)
d = OrderedDict{VarName,Any}()
"""
getparams(T::Type{<:AbstractDict}, model::BUGSModel)
Extract the parameter values from the model into a dictionary of type T.
If model.transformed is true, returns parameters in transformed space.
"""
function getparams(T::Type{<:AbstractDict}, model::BUGSModel)
d = T()
for v in model.parameters
value = AbstractPPL.get(model.evaluation_env, v)
if !model.transformed
d[v] = AbstractPPL.get(model.evaluation_env, v)
d[v] = value
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
d[v] = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
)
d[v] = Bijectors.transform(Bijectors.bijector(dist), value)
end
end
return d
Expand Down Expand Up @@ -355,7 +360,7 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
new_model = BUGSModel(
model, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env
)
evaluate_env, _ = evaluate!!(new_model, DefaultContext())
evaluate_env, _ = evaluate!!(new_model)
return BangBang.setproperty!!(new_model, :evaluation_env, evaluate_env)
end

Expand All @@ -368,33 +373,7 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel)
)
end

"""
DefaultContext
Use values in varinfo to compute the log joint density.
"""
struct DefaultContext <: AbstractPPL.AbstractContext end

"""
SamplingContext
Do an ancestral sampling of the model parameters. Also accumulate log joint density.
"""
@kwdef struct SamplingContext{T<:Random.AbstractRNG} <: AbstractPPL.AbstractContext
rng::T = Random.default_rng()
end

"""
LogDensityContext
Use the given values to compute the log joint density.
"""
struct LogDensityContext <: AbstractPPL.AbstractContext end

function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG)
return evaluate!!(model, SamplingContext(rng))
end
function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
(; evaluation_env, g, sorted_nodes) = model
vi = deepcopy(evaluation_env)
logp = 0.0
Expand All @@ -406,7 +385,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
evaluation_env = setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
value = rand(ctx.rng, dist) # just sample from the prior
value = rand(rng, dist) # just sample from the prior
logp += logpdf(dist, value)
evaluation_env = setindex!!(evaluation_env, value, vn)
end
Expand All @@ -415,11 +394,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
end

function AbstractPPL.evaluate!!(model::BUGSModel)
return AbstractPPL.evaluate!!(model, DefaultContext())
end
function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
(; sorted_nodes, g, evaluation_env) = model
vi = deepcopy(evaluation_env)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
Expand All @@ -446,9 +421,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
return evaluation_env, logp
end

function AbstractPPL.evaluate!!(
model::BUGSModel, ::LogDensityContext, flattened_values::AbstractVector
)
function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
var_lengths = if model.transformed
model.transformed_var_lengths
else
Expand Down
8 changes: 2 additions & 6 deletions test/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@ mb_logp = begin
end

# order: b, l, c, a
@test mb_logp evaluate!!(cond_model, JuliaBUGS.LogDensityContext(), [c_value])[2] rtol =
1e-8
@test mb_logp evaluate!!(cond_model, [c_value])[2] rtol = 1e-8

# test LogDensityContext
@test begin
logp = 0
logp += logpdf(dnorm(1.0, 3.0), 1.0) # a, where f = 1.0
Expand All @@ -80,9 +78,7 @@ end
logp += logpdf(dnorm(2.0, 1.0), 4.0) # d, where g = 2.0
logp += logpdf(dnorm(4.0, 4.0), 5.0) # e, where h = 4.0
logp
end evaluate!!(
model, JuliaBUGS.LogDensityContext(), [-2.0, 4.0, 3.0, 2.0, 1.0, 4.0, 5.0]
)[2] atol = 1e-8
end evaluate!!(model, [-2.0, 4.0, 3.0, 2.0, 1.0, 4.0, 5.0])[2] atol = 1e-8

# AuxiliaryNodeInfo
test_model = @bugs begin
Expand Down
2 changes: 1 addition & 1 deletion test/log_density.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TODO: make this available in JuliaBUGS
function _logjoint(model::JuliaBUGS.BUGSModel)
return JuliaBUGS.evaluate!!(model, JuliaBUGS.DefaultContext())[2]
return JuliaBUGS.evaluate!!(model)[2]
end

@testset "Log density" begin
Expand Down

0 comments on commit f9c253b

Please sign in to comment.