diff --git a/Project.toml b/Project.toml index 523c075e..fb7080ef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.7.2" +version = "0.7.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/model.jl b/src/model.jl index 2607eaa9..b843801b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -539,3 +539,67 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect end return evaluation_env, logp end + +""" + _tempered_evaluate!!(model::BUGSModel, flattened_values::AbstractVector; temperature=1.0) + +Evaluating the model with the given model parameter values, returns updated evaluation environment +and a NamedTuple of logprior, loglikelihood and tempered logjoint (where tempered logjoint is the logjoint +whose loglikelihood component scaled by the given temperature). +""" +function _tempered_evaluate!!( + model::BUGSModel, flattened_values::AbstractVector; temperature=1.0 +) + var_lengths = if model.transformed + model.transformed_var_lengths + else + model.untransformed_var_lengths + end + + evaluation_env = deepcopy(model.evaluation_env) + current_idx = 1 + logprior, loglikelihood = 0.0, 0.0 + for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) + is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] + is_observed = model.flattened_graph_node_data.is_observed_vals[i] + node_function = model.flattened_graph_node_data.node_function_vals[i] + loop_vars = model.flattened_graph_node_data.loop_vars_vals[i] + if !is_stochastic + value = node_function(evaluation_env, loop_vars) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) + else + dist = node_function(evaluation_env, loop_vars) + if !is_observed + l = var_lengths[vn] + if model.transformed + b = Bijectors.bijector(dist) + b_inv = Bijectors.inverse(b) + reconstructed_value = reconstruct( + b_inv, + dist, + view(flattened_values, current_idx:(current_idx + l - 1)), + ) + value, logjac = Bijectors.with_logabsdet_jacobian( + b_inv, reconstructed_value + ) + else + value = reconstruct( + dist, view(flattened_values, current_idx:(current_idx + l - 1)) + ) + logjac = 0.0 + end + current_idx += l + logprior += logpdf(dist, value) + logjac + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) + else + loglikelihood += logpdf(dist, AbstractPPL.get(evaluation_env, vn)) + end + end + end + return evaluation_env, + ( + logprior=logprior, + loglikelihood=loglikelihood, + tempered_logjoint=logprior + temperature * loglikelihood, + ) +end diff --git a/test/model.jl b/test/model.jl new file mode 100644 index 00000000..a062a55c --- /dev/null +++ b/test/model.jl @@ -0,0 +1,49 @@ +@testset "logprior and loglikelihood" begin + @testset "Complex model with transformations" begin + model_def = @bugs begin + s[1] ~ InverseGamma(2, 3) + s[2] ~ InverseGamma(2, 3) + m[1] ~ Normal(0, sqrt(s[1])) + m[2] ~ Normal(0, sqrt(s[2])) + x[1:2] ~ MvNormal(m[1:2], Diagonal(s[1:2])) + end + + data = (; x=[1.0, 2.0]) + + model = compile(model_def, data) + + params = rand(4) + + b = Bijectors.bijector(InverseGamma(2, 3)) + b_inv = Bijectors.inverse(b) + + log_prior_true = begin + # parameter sorted: s[2], m[2], s[1], m[1] + s1_inversed, logjac1 = Bijectors.with_logabsdet_jacobian(b_inv, params[3]) + s2_inversed, logjac2 = Bijectors.with_logabsdet_jacobian(b_inv, params[1]) + logpdf(InverseGamma(2, 3), s1_inversed) + + logjac1 + + logpdf(InverseGamma(2, 3), s2_inversed) + + logjac2 + + logpdf(Normal(0, sqrt(s1_inversed)), params[4]) + + logpdf(Normal(0, sqrt(s2_inversed)), params[2]) + end + + log_likelihood_true = begin + s1_inversed = b_inv(params[3]) + s2_inversed = b_inv(params[1]) + logpdf( + MvNormal([params[4], params[2]], Diagonal([s1_inversed, s2_inversed])), + data.x, + ) + end + + _, (logprior, loglikelihood, tempered_logjoint) = JuliaBUGS._tempered_evaluate!!( + model, params; temperature=2.0 + ) + + @test logprior ≈ log_prior_true rtol = 1E-6 + @test loglikelihood ≈ log_likelihood_true rtol = 1E-6 + @test tempered_logjoint ≈ log_prior_true + 2.0 * log_likelihood_true rtol = 1E-6 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ceccfae5..fab36c47 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,6 +60,7 @@ end if test_group == "log_density" || test_group == "all" include("log_density.jl") + include("model.jl") end if test_group == "gibbs" || test_group == "all"