diff --git a/ext/JuliaBUGSMCMCChainsExt.jl b/ext/JuliaBUGSMCMCChainsExt.jl index 8def70a38..8b1a1a12c 100644 --- a/ext/JuliaBUGSMCMCChainsExt.jl +++ b/ext/JuliaBUGSMCMCChainsExt.jl @@ -84,7 +84,7 @@ function JuliaBUGS.gen_chains( g = model.g generated_vars = find_generated_vars(g) - generated_vars = [v for v in model.sorted_nodes if v in generated_vars] # keep the order + generated_vars = [v for v in model.eval_cache.sorted_nodes if v in generated_vars] # keep the order param_vals = [] generated_quantities = [] diff --git a/src/gibbs.jl b/src/gibbs.jl index 4d76dcc00..ab447d2d7 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -13,7 +13,7 @@ abstract type AbstractGibbsState end struct GibbsState{T,S,C} <: AbstractGibbsState values::T conditioning_schedule::S - sorted_nodes_cache::C + cached_eval_caches::C end ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x @@ -25,17 +25,17 @@ function AbstractMCMC.step( model=l_model.logdensity, kwargs..., ) where {N,S} - sorted_nodes_cache, conditioning_schedule = OrderedDict(), OrderedDict() + cached_eval_caches, conditioning_schedule = OrderedDict(), OrderedDict() for variable_group in keys(sampler.sampler_map) variable_to_condition_on = setdiff(model.parameters, ensure_vector(variable_group)) conditioning_schedule[variable_to_condition_on] = sampler.sampler_map[variable_group] conditioned_model = AbstractPPL.condition( model, variable_to_condition_on, model.evaluation_env ) - sorted_nodes_cache[variable_to_condition_on] = conditioned_model.sorted_nodes + cached_eval_caches[variable_to_condition_on] = conditioned_model.eval_cache end param_values = JuliaBUGS.getparams(model) - return param_values, GibbsState(param_values, conditioning_schedule, sorted_nodes_cache) + return param_values, GibbsState(param_values, conditioning_schedule, cached_eval_caches) end function AbstractMCMC.step( @@ -50,12 +50,12 @@ function AbstractMCMC.step( for vs in keys(state.conditioning_schedule) model = initialize!(model, param_values) cond_model = AbstractPPL.condition( - model, vs, model.evaluation_env, state.sorted_nodes_cache[vs] + model, vs, model.evaluation_env, state.cached_eval_caches[vs] ) param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs]) end return param_values, - GibbsState(param_values, state.conditioning_schedule, state.sorted_nodes_cache) + GibbsState(param_values, state.conditioning_schedule, state.cached_eval_caches) end function gibbs_internal end diff --git a/src/model.jl b/src/model.jl index 8445b9a92..b3400f4bc 100644 --- a/src/model.jl +++ b/src/model.jl @@ -3,14 +3,53 @@ # instead of https://github.com/TuringLang/AbstractMCMC.jl/blob/d7c549fe41a80c1f164423c7ac458425535f624b/src/logdensityproblems.jl#L90 abstract type AbstractBUGSModel end +""" + EvalCache{TNF,TNA,TV} + +Pre-compute the values of the nodes in the model to avoid lookups from MetaGraph. +""" +struct EvalCache{TNF,TNA,TV} + sorted_nodes::Vector{<:VarName} + is_stochastic_vals::Vector{Bool} + is_observed_vals::Vector{Bool} + node_function_vals::TNF + node_args_vals::TNA + loop_vars_vals::TV +end + +function EvalCache(sorted_nodes::Vector{<:VarName}, g::BUGSGraph) + is_stochastic_vals = Array{Bool}(undef, length(sorted_nodes)) + is_observed_vals = Array{Bool}(undef, length(sorted_nodes)) + node_function_vals = [] + node_args_vals = [] + loop_vars_vals = [] + for (i, vn) in enumerate(sorted_nodes) + (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] + is_stochastic_vals[i] = is_stochastic + is_observed_vals[i] = is_observed + push!(node_function_vals, node_function) + push!(node_args_vals, Val(node_args)) + push!(loop_vars_vals, loop_vars) + end + return EvalCache( + sorted_nodes, + is_stochastic_vals, + is_observed_vals, + node_function_vals, + node_args_vals, + loop_vars_vals, + ) +end + """ BUGSModel The `BUGSModel` object is used for inference and represents the output of compilation. It implements the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface. """ -struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple} <: - AbstractBUGSModel +struct BUGSModel{ + base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TNA,TV +} <: AbstractBUGSModel " Indicates whether the model parameters are in the transformed space. " transformed::Bool @@ -27,8 +66,8 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple} evaluation_env::T "A vector containing the names of the model parameters (unobserved stochastic variables)." parameters::Vector{<:VarName} - "A vector containing the names of all the variables in the model, sorted in topological order." - sorted_nodes::Vector{<:VarName} + "An `EvalCache` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed." + eval_cache::EvalCache{TNF,TNA,TV} "An instance of `BUGSGraph`, representing the dependency graph of the model." g::BUGSGraph @@ -144,7 +183,7 @@ function BUGSModel( transformed_var_lengths, evaluation_env, parameters, - sorted_nodes, + EvalCache(sorted_nodes, g), g, nothing, ) @@ -152,6 +191,7 @@ end function BUGSModel( model::BUGSModel, + g::BUGSGraph, parameters::Vector{<:VarName}, sorted_nodes::Vector{<:VarName}, evaluation_env::NamedTuple=model.evaluation_env, @@ -164,8 +204,8 @@ function BUGSModel( model.transformed_var_lengths, evaluation_env, parameters, - sorted_nodes, - model.g, + EvalCache(sorted_nodes, g), + g, isnothing(model.base_model) ? model : model.base_model, ) end @@ -177,9 +217,13 @@ Initialize the model with a NamedTuple of initial values, the values are expecte """ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) - for vn in model.sorted_nodes - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn] - args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars) + for (i, vn) in enumerate(model.eval_cache.sorted_nodes) + is_stochastic = model.eval_cache.is_stochastic_vals[i] + is_observed = model.eval_cache.is_observed_vals[i] + node_function = model.eval_cache.node_function_vals[i] + node_args = model.eval_cache.node_args_vals[i] + loop_vars = model.eval_cache.loop_vars_vals[i] + args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) if !is_stochastic value = Base.invokelatest(node_function; args...) BangBang.@set!! model.evaluation_env = setindex!!( @@ -318,11 +362,11 @@ function AbstractPPL.condition( new_parameters = setdiff(model.parameters, var_group) sorted_blanket_with_vars = if sorted_nodes isa Nothing - sorted_nodes + model.eval_cache.sorted_nodes else filter( vn -> vn in union(markov_blanket(model.g, new_parameters), new_parameters), - model.sorted_nodes, + model.eval_cache.sorted_nodes, ) end @@ -338,7 +382,9 @@ function AbstractPPL.condition( end end - new_model = BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env) + new_model = BUGSModel( + model, g, new_parameters, sorted_blanket_with_vars, evaluation_env + ) return BangBang.setproperty!!(new_model, :g, g) end @@ -347,18 +393,19 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) base_model = model.base_model isa Nothing ? model : model.base_model new_parameters = [ - v for v in base_model.sorted_nodes if v in union(model.parameters, var_group) + v for + v in base_model.eval_cache.sorted_nodes if v in union(model.parameters, var_group) ] # keep the order markov_blanket_with_vars = union( markov_blanket(base_model.g, new_parameters), new_parameters ) sorted_blanket_with_vars = filter( - vn -> vn in markov_blanket_with_vars, base_model.sorted_nodes + vn -> vn in markov_blanket_with_vars, base_model.eval_cache.sorted_nodes ) new_model = BUGSModel( - model, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env + model, model.g, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env ) evaluate_env, _ = evaluate!!(new_model) return BangBang.setproperty!!(new_model, :evaluation_env, evaluate_env) @@ -374,12 +421,15 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) end function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) - (; evaluation_env, g, sorted_nodes) = model + (; evaluation_env, g) = model vi = deepcopy(evaluation_env) logp = 0.0 - for vn in sorted_nodes - (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) + for (i, vn) in enumerate(model.eval_cache.sorted_nodes) + is_stochastic = model.eval_cache.is_stochastic_vals[i] + node_function = model.eval_cache.node_function_vals[i] + node_args = model.eval_cache.node_args_vals[i] + loop_vars = model.eval_cache.loop_vars_vals[i] + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -394,11 +444,14 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) end function AbstractPPL.evaluate!!(model::BUGSModel) - (; sorted_nodes, g, evaluation_env) = model logp = 0.0 - for vn in sorted_nodes - (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) + evaluation_env = deepcopy(model.evaluation_env) + for (i, vn) in enumerate(model.eval_cache.sorted_nodes) + is_stochastic = model.eval_cache.is_stochastic_vals[i] + node_function = model.eval_cache.node_function_vals[i] + node_args = model.eval_cache.node_args_vals[i] + loop_vars = model.eval_cache.loop_vars_vals[i] + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -428,13 +481,16 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect model.untransformed_var_lengths end - g = model.g evaluation_env = deepcopy(model.evaluation_env) current_idx = 1 logp = 0.0 - for vn in model.sorted_nodes - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) + for (i, vn) in enumerate(model.eval_cache.sorted_nodes) + is_stochastic = model.eval_cache.is_stochastic_vals[i] + is_observed = model.eval_cache.is_observed_vals[i] + node_function = model.eval_cache.node_function_vals[i] + node_args = model.eval_cache.node_args_vals[i] + loop_vars = model.eval_cache.loop_vars_vals[i] + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) diff --git a/test/graphs.jl b/test/graphs.jl index ab7b56582..bfb9dae12 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -47,11 +47,11 @@ c = @varname c cond_model = AbstractPPL.condition(model, setdiff(model.parameters, [c])) # tests for MarkovBlanketBUGSModel constructor @test cond_model.parameters == [c] -@test Set(Symbol.(cond_model.sorted_nodes)) == Set([:l, :a, :b, :f, :c]) +@test Set(Symbol.(cond_model.eval_cache.sorted_nodes)) == Set([:l, :a, :b, :f, :c]) decond_model = AbstractPPL.decondition(cond_model, [a, l]) @test Set(Symbol.(decond_model.parameters)) == Set([:a, :c, :l]) -@test Set(Symbol.(decond_model.sorted_nodes)) == +@test Set(Symbol.(decond_model.eval_cache.sorted_nodes)) == Set([:l, :b, :f, :a, :d, :e, :c, :h, :g, :i]) c_value = 4.0