diff --git a/Project.toml b/Project.toml index bd7c6ec7b..aad80bd58 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" version = "0.6.2" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -41,7 +42,7 @@ JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"] JuliaBUGSDynamicPPLExt = ["DynamicPPL"] JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"] JuliaBUGSGraphPlotExt = ["GraphPlot"] -JuliaBUGSMCMCChainsExt = ["DynamicPPL", "MCMCChains"] +JuliaBUGSMCMCChainsExt = ["MCMCChains"] [compat] ADTypes = "1.6" diff --git a/ext/JuliaBUGSAdvancedHMCExt.jl b/ext/JuliaBUGSAdvancedHMCExt.jl index 0d3503949..96632be06 100644 --- a/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/ext/JuliaBUGSAdvancedHMCExt.jl @@ -1,30 +1,22 @@ module JuliaBUGSAdvancedHMCExt -using AbstractMCMC -using AdvancedHMC -using AdvancedHMC: Transition, stat -using JuliaBUGS +using AbstractMCMC: AbstractMCMC +using AdvancedHMC: AdvancedHMC +using MCMCChains: MCMCChains using JuliaBUGS: - AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!! -using JuliaBUGS.BUGSPrimitives -using JuliaBUGS.BangBang -using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD -using JuliaBUGS.Bijectors -using JuliaBUGS.Random -using MCMCChains: Chains -import JuliaBUGS: gibbs_internal + JuliaBUGS, Accessors, ADTypes, LogDensityProblems, LogDensityProblemsAD, Random function AbstractMCMC.bundle_samples( - ts::Vector{<:Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, - sampler::AdvancedHMC.AbstractHMCSampler, + ts::Vector{<:AdvancedHMC.Transition}, + logdensitymodel, + sampler, state, - chain_type::Type{Chains}; + chain_type::Type{MCMCChains.Chains}; discard_initial=0, thinning=1, kwargs..., ) + params = [t.z.θ for t in ts] stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1])))) stats_values = [ vcat([ts[i].z.ℓπ.value..., collect(values(AdvancedHMC.stat(ts[i])))...]) for @@ -33,7 +25,7 @@ function AbstractMCMC.bundle_samples( return JuliaBUGS.gen_chains( logdensitymodel, - [t.z.θ for t in ts], + params, stats_names, stats_values; discard_initial=discard_initial, @@ -43,24 +35,23 @@ function AbstractMCMC.bundle_samples( end function JuliaBUGS.gibbs_internal( - rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC + rng::Random.AbstractRNG, + sub_model::JuliaBUGS.BUGSModel, + sampler::AdvancedHMC.HMC, + state::AdvancedHMC.HMCState, + adtype::ADTypes.AbstractADType, ) - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model) - ) - t, s = AbstractMCMC.step( - rng, - logdensitymodel, - sampler; - n_adapts=0, - initial_params=JuliaBUGS.getparams(cond_model), + # update the log density in the state + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, sub_model) + state = Accessors.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, state.transition.z.θ, state.transition.z.r ) - updated_model = initialize!(cond_model, t.z.θ) - return JuliaBUGS.getparams( - BangBang.setproperty!!( - updated_model.base_model, :evaluation_env, updated_model.evaluation_env - ), + + logdensitymodel = AbstractMCMC.LogDensityModel( + LogDensityProblemsAD.ADgradient(adtype, sub_model) ) + _, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0) + return initialize!(sub_model, s.transition.z.θ).evaluation_env, s end end diff --git a/ext/JuliaBUGSAdvancedMHExt.jl b/ext/JuliaBUGSAdvancedMHExt.jl index 5fb594030..4c4009b42 100644 --- a/ext/JuliaBUGSAdvancedMHExt.jl +++ b/ext/JuliaBUGSAdvancedMHExt.jl @@ -1,35 +1,30 @@ module JuliaBUGSAdvancedMHExt -using AbstractMCMC -using AdvancedMH -using JuliaBUGS -using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!! -using JuliaBUGS.BUGSPrimitives -using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD -using JuliaBUGS.Random -using JuliaBUGS.Bijectors -using MCMCChains: Chains -import JuliaBUGS: gibbs_internal +using AbstractMCMC: AbstractMCMC +using AdvancedMH: AdvancedMH +using MCMCChains: MCMCChains +using JuliaBUGS: JuliaBUGS +using JuliaBUGS: Accessors, ADTypes, LogDensityProblems, LogDensityProblemsAD, Random function AbstractMCMC.bundle_samples( - ts::Vector{<:AdvancedMH.AbstractTransition}, - logdensitymodel::Union{ - AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, - AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, - }, - sampler::AdvancedMH.MHSampler, + ts::Vector{<:AdvancedMH.Transition}, + logdensitymodel, + sampler, state, - chain_type::Type{Chains}; + chain_type::Type{MCMCChains.Chains}; discard_initial=0, thinning=1, kwargs..., ) + params = [t.params for t in ts] + stats_names = [:lp] + stats_values = [t.lp for t in ts] + return JuliaBUGS.gen_chains( logdensitymodel, - [t.params for t in ts], - [:lp], - [t.lp for t in ts]; + params, + stats_names, + stats_values; discard_initial=discard_initial, thinning=thinning, kwargs..., @@ -37,24 +32,19 @@ function AbstractMCMC.bundle_samples( end function JuliaBUGS.gibbs_internal( - rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::AdvancedMH.MHSampler + rng::Random.AbstractRNG, + sub_model::JuliaBUGS.BUGSModel, + sampler::AdvancedMH.MHSampler, + state::AdvancedMH.Transition, + adtype::ADTypes.AbstractADType, ) + state = Accessors.@set state.lp = LogDensityProblems.logdensity(sub_model, state.params) + logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model) - ) - t, s = AbstractMCMC.step( - rng, - logdensitymodel, - sampler; - n_adapts=0, - initial_params=JuliaBUGS.getparams(cond_model), - ) - updated_model = initialize!(cond_model, t.params) - return JuliaBUGS.getparams( - BangBang.setproperty!!( - updated_model.base_model, :evaluation_env, updated_model.evaluation_env - ), + LogDensityProblemsAD.ADgradient(adtype, sub_model) ) + _, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state) + return JuliaBUGS.initialize!(sub_model, s.params).evaluation_env, s end end diff --git a/ext/JuliaBUGSMCMCChainsExt.jl b/ext/JuliaBUGSMCMCChainsExt.jl index 8febc4f40..1223d4bdc 100644 --- a/ext/JuliaBUGSMCMCChainsExt.jl +++ b/ext/JuliaBUGSMCMCChainsExt.jl @@ -1,37 +1,37 @@ module JuliaBUGSMCMCChainsExt -using JuliaBUGS -using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!! -using JuliaBUGS.AbstractPPL -using JuliaBUGS.BUGSPrimitives -using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD -using DynamicPPL -using AbstractMCMC +using AbstractMCMC: AbstractMCMC using MCMCChains: Chains +using JuliaBUGS: + JuliaBUGS, AbstractPPL, BUGSPrimitives, LogDensityProblems, LogDensityProblemsAD -function JuliaBUGS.gen_chains( - model::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, - samples, - stats_names, - stats_values; +function AbstractMCMC.bundle_samples( + ts, + logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, + sampler::JuliaBUGS.Gibbs, + state, + ::Type{Chains}; discard_initial=0, - thinning=1, kwargs..., ) return JuliaBUGS.gen_chains( - model.logdensity, - samples, - stats_names, - stats_values; - discard_initial=discard_initial, - thinning=thinning, - kwargs..., + logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs... ) end +function get_bugsmodel(model::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}) + return model.logdensity +end + +function get_bugsmodel( + model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper} +) + ad_wrapper = model.logdensity + return Base.parent(ad_wrapper)::JuliaBUGS.BUGSModel +end + function JuliaBUGS.gen_chains( - model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + model::AbstractMCMC.LogDensityModel, samples, stats_names, stats_values; @@ -40,7 +40,7 @@ function JuliaBUGS.gen_chains( kwargs..., ) return JuliaBUGS.gen_chains( - model.logdensity.ℓ, + get_bugsmodel(model), samples, stats_names, stats_values; @@ -62,13 +62,15 @@ function JuliaBUGS.gen_chains( param_vars = model.parameters g = model.g - generated_vars = find_generated_vars(g) + generated_vars = JuliaBUGS.find_generated_quantities_variables(g) generated_vars = [v for v in model.sorted_nodes if v in generated_vars] # keep the order param_vals = [] generated_quantities = [] for i in axes(samples)[1] - evaluation_env = first(evaluate!!(model, LogDensityContext(), samples[i])) + evaluation_env = first( + JuliaBUGS.evaluate!!(model, JuliaBUGS.LogDensityContext(), samples[i]) + ) push!( param_vals, [AbstractPPL.get(evaluation_env, param_var) for param_var in param_vars], @@ -84,13 +86,13 @@ function JuliaBUGS.gen_chains( param_name_leaves = collect( Iterators.flatten([ - collect(DynamicPPL.varname_leaves(vn, param_vals[1][i])) for + collect(varname_leaves(vn, param_vals[1][i])) for (i, vn) in enumerate(param_vars) ],), ) generated_varname_leaves = collect( Iterators.flatten([ - collect(DynamicPPL.varname_leaves(vn, generated_quantities[1][i])) for + collect(varname_leaves(vn, generated_quantities[1][i])) for (i, vn) in enumerate(generated_vars) ],), ) @@ -129,4 +131,28 @@ function JuliaBUGS.gen_chains( ) end +# utils: copied from DynamicPPL + +varname_leaves(vn::JuliaBUGS.VarName, ::Real) = [vn] +function varname_leaves(vn::JuliaBUGS.VarName, val::AbstractArray{<:Union{Real,Missing}}) + return ( + JuliaBUGS.VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for + I in CartesianIndices(val) + ) +end +function varname_leaves(vn::JuliaBUGS.VarName, val::AbstractArray) + return Iterators.flatten( + varname_leaves( + JuliaBUGS.VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] + ) for I in CartesianIndices(val) + ) +end +function varname_leaves(vn::JuliaBUGS.VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do sym + optic = Accessors.PropertyLens{sym}() + varname_leaves(JuliaBUGS.VarName(vn, optic ∘ getoptic(vn)), optic(val)) + end + return Iterators.flatten(iter) +end + end diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 8071b5362..c860afaa5 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -3,6 +3,7 @@ module JuliaBUGS using AbstractMCMC using AbstractPPL using Accessors +using ADTypes using BangBang using Bijectors: Bijectors using Distributions diff --git a/src/gibbs.jl b/src/gibbs.jl index 12da4b5b4..e4110b254 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -1,89 +1,159 @@ -struct Gibbs{N,S} <: AbstractMCMC.AbstractSampler - sampler_map::OrderedDict{N,S} +struct Gibbs{ODT<:OrderedDict,ADT<:ADTypes.AbstractADType} <: AbstractMCMC.AbstractSampler + sampler_map::ODT + adtype::ADT end -function Gibbs(model::BUGSModel, s::AbstractMCMC.AbstractSampler) - return Gibbs(OrderedDict([v => s for v in model.parameters])) +function Gibbs(sampler_map::ODT) where {ODT<:OrderedDict} + return Gibbs(sampler_map, ADTypes.AutoReverseDiff(; compile=false)) end -struct MHFromPrior <: AbstractMCMC.AbstractSampler end +function verify_sampler_map(model::BUGSModel, sampler_map::OrderedDict) + all_variables_in_keys = Set(vcat(keys(sampler_map)...)) + model_parameters = Set(model.parameters) + + # Check for extra variables in sampler_map that are not in model parameters + extra_variables = setdiff(all_variables_in_keys, model_parameters) + if !isempty(extra_variables) + throw( + ArgumentError( + "Sampler map contains variables not in the model: $extra_variables" + ), + ) + end + + # Check for model parameters not covered by sampler_map + left_over_variables = setdiff(model_parameters, all_variables_in_keys) + if !isempty(left_over_variables) + throw( + ArgumentError( + "Some model parameters are not covered by the sampler map: $left_over_variables", + ), + ) + end + + return true +end -abstract type AbstractGibbsState end +""" + _create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::Vector{<:VarName}) -struct GibbsState{T,S,C} <: AbstractGibbsState - values::T - conditioning_schedule::S - sorted_nodes_cache::C +Internal function to create a conditioned model for Gibbs sampling. This is different from conditioning, because conditioning +only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the +Markov blanket of the variables that are being updated. +""" +function _create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName) + return _create_submodel_for_gibbs_sampling(model, [variables_to_update]) +end +function _create_submodel_for_gibbs_sampling( + model::BUGSModel, variables_to_update::NTuple{N,<:VarName} +) where {N} + return _create_submodel_for_gibbs_sampling(model, collect(variables_to_update)) +end +function _create_submodel_for_gibbs_sampling( + model::BUGSModel, variables_to_update::Vector{<:VarName} +) + _markov_blanket = markov_blanket(model.g, variables_to_update) + mb_without_variables_to_update = setdiff(_markov_blanket, variables_to_update) + model_parameters_in_mb = filter( + v -> is_stochastic(model.g, v) && !is_observation(model.g, v), + mb_without_variables_to_update, + ) + sub_model = BUGSModel( + model; parameters=variables_to_update, sorted_nodes=collect(_markov_blanket) + ) + return condition(sub_model, collect(model_parameters_in_mb)) end -ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x +struct GibbsState{T,S,C} + evaluation_env::T + sub_model_cache::C + sub_states::S +end + +""" + gibbs_internal(rng, sub_model, sampler, state, adtype) + +Internal function to perform Gibbs sampling. This function should first update the +sampler state with the correct log density and then do a single step of the sampler. +It should return the `evaluation_env` and the updated sampler state. +""" +function gibbs_internal end function AbstractMCMC.step( rng::Random.AbstractRNG, - l_model::AbstractMCMC.LogDensityModel{<:BUGSModel}, - sampler::Gibbs{N,S}; - model=l_model.logdensity, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModel}, + sampler::Gibbs; + model=logdensitymodel.logdensity, kwargs..., -) where {N,S} - sorted_nodes_cache, conditioning_schedule = OrderedDict(), OrderedDict() - for variable_group in keys(sampler.sampler_map) - variable_to_condition_on = setdiff(model.parameters, ensure_vector(variable_group)) - conditioning_schedule[variable_to_condition_on] = sampler.sampler_map[variable_group] - conditioned_model = AbstractPPL.condition( - model, variable_to_condition_on, model.evaluation_env - ) - sorted_nodes_cache[variable_to_condition_on] = conditioned_model.sorted_nodes +) + verify_sampler_map(model, sampler.sampler_map) + + submodel_cache = Vector{BUGSModel}(undef, length(sampler.sampler_map)) + sub_states = Any[] + for (i, variable_group) in enumerate(keys(sampler.sampler_map)) + local_sampler = sampler.sampler_map[variable_group] + submodel = _create_submodel_for_gibbs_sampling(model, variable_group) + if local_sampler isa MHFromPrior + evaluation_env, logp = evaluate!!(submodel, DefaultContext()) + state = MHState(evaluation_env, logp) + else + sublogdensitymodel = AbstractMCMC.LogDensityModel( + LogDensityProblemsAD.ADgradient(sampler.adtype, submodel) + ) + _, state = AbstractMCMC.step(rng, sublogdensitymodel, local_sampler) + end + submodel_cache[i] = submodel + push!(sub_states, state) end - param_values = JuliaBUGS.getparams(model) - return param_values, GibbsState(param_values, conditioning_schedule, sorted_nodes_cache) + + return getparams(model), + GibbsState(model.evaluation_env, submodel_cache, map(identity, sub_states)) end function AbstractMCMC.step( rng::Random.AbstractRNG, - l_model::AbstractMCMC.LogDensityModel{<:BUGSModel}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModel}, sampler::Gibbs, - state::AbstractGibbsState; - model=l_model.logdensity, + state::GibbsState; + model=logdensitymodel.logdensity, kwargs..., ) - param_values = state.values - for vs in keys(state.conditioning_schedule) - model = initialize!(model, param_values) - cond_model = AbstractPPL.condition( - model, vs, model.evaluation_env, state.sorted_nodes_cache[vs] + evaluation_env = state.evaluation_env + for (i, vs) in enumerate(keys(sampler.sampler_map)) + sub_model = BangBang.setproperty!!( + state.sub_model_cache[i], :evaluation_env, evaluation_env + ) + evaluation_env, new_sub_state = gibbs_internal( + rng, sub_model, sampler.sampler_map[vs], state.sub_states[i], sampler.adtype ) - param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs]) + state.sub_states[i] = new_sub_state end - return param_values, - GibbsState(param_values, state.conditioning_schedule, state.sorted_nodes_cache) + model = BangBang.setproperty!!(model, :evaluation_env, evaluation_env) + return getparams(model), + GibbsState(evaluation_env, state.sub_model_cache, state.sub_states) end -function gibbs_internal end +struct MHFromPrior <: AbstractMCMC.AbstractSampler end -function gibbs_internal(rng::Random.AbstractRNG, cond_model::BUGSModel, ::MHFromPrior) - transformed_original = JuliaBUGS.getparams(cond_model) - values, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original) - values_proposed, logp_proposed = evaluate!!(cond_model, SamplingContext()) +struct MHState{T} + evaluation_env::T + logp::Float64 +end + +function gibbs_internal( + rng::Random.AbstractRNG, + sub_model::BUGSModel, + ::MHFromPrior, + state::MHState, + adtype::ADTypes.AbstractADType, +) + evaluation_env, logp = evaluate!!(sub_model, DefaultContext()) + proposed_evaluation_env, logp_proposed = evaluate!!(sub_model, SamplingContext()) if logp_proposed - logp > log(rand(rng)) - values = values_proposed + evaluation_env = proposed_evaluation_env + logp = logp_proposed end - return JuliaBUGS.getparams( - BangBang.setproperty!!(cond_model.base_model, :evaluation_env, values) - ) -end - -function AbstractMCMC.bundle_samples( - ts, - logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, - sampler::Gibbs, - state, - ::Type{T}; - discard_initial=0, - kwargs..., -) where {T} - return JuliaBUGS.gen_chains( - logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs... - ) + return evaluation_env, MHState(evaluation_env, logp) end diff --git a/src/graphs.jl b/src/graphs.jl index 71d8e7cd3..59a726433 100644 --- a/src/graphs.jl +++ b/src/graphs.jl @@ -7,137 +7,150 @@ struct NodeInfo{F} loop_vars::NamedTuple end -""" - BUGSGraph +const BUGSGraph = MetaGraph{ + Int,Graphs.SimpleDiGraph{Int},<:VarName,<:NodeInfo,Nothing,Nothing,<:Any,Float64 +} -The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for -`MetaGraphsNext.MetaGraph`. -""" -const BUGSGraph = MetaGraph +is_stochastic(g::BUGSGraph, v::VarName) = g[v].is_stochastic +is_model_parameter(g::BUGSGraph, v::VarName) = g[v].is_stochastic && !g[v].is_observed +is_observation(g::BUGSGraph, v::VarName) = g[v].is_stochastic && g[v].is_observed +is_deterministic(g::BUGSGraph, v::VarName) = !g[v].is_stochastic """ - find_generated_vars(g::BUGSGraph) + find_generated_quantities_variables(g::BUGSGraph) -Return all the logical variables without stochastic descendants. The values of these variables -do not affect sampling process. These variables are called "generated quantities" traditionally. +Find all the generated quantities variables in the graph. + +Generated quantities variables are variables that do not affect the sampling process. +They are variables that do not have any descendant variables that are observed. """ -function find_generated_vars(g) - graph_roots = VarName[] # root nodes of the graph +function find_generated_quantities_variables( + g::MetaGraph{Int,<:SimpleDiGraph,Label,VertexData} +) where {Label,VertexData} + generated_quantities_variables = Set{Label}() + can_reach_observations = Dict{Label,Bool}() + for n in labels(g) - if isempty(outneighbor_labels(g, n)) - push!(graph_roots, n) + if !is_observation(g, n) + if !dfs_can_reach_observations(g, n, can_reach_observations) + push!(generated_quantities_variables, n) + end end end + return generated_quantities_variables +end - generated_vars = VarName[] - for n in graph_roots - if !g[n].is_stochastic - push!(generated_vars, n) # graph roots that are Logical nodes are generated variables - find_generated_vars_recursive_helper(g, n, generated_vars) - end +function dfs_can_reach_observations(g, n, can_reach_observations) + if haskey(can_reach_observations, n) + return can_reach_observations[n] end - return generated_vars -end -function find_generated_vars_recursive_helper(g, n, generated_vars) - if n in generated_vars # already visited - return nothing + if is_observation(g, n) + can_reach_observations[n] = true + return true end - for p in inneighbor_labels(g, n) # parents - if p in generated_vars # already visited - continue - end - if g[p].node_type == Stochastic - continue - end # p is a Logical Node - if !any(x -> g[x].node_type == Stochastic, outneighbor_labels(g, p)) # if the node has stochastic children, it is not a root - push!(generated_vars, p) + + can_reach = false + for child in MetaGraphsNext.outneighbor_labels(g, n) + if dfs_can_reach_observations(g, child, can_reach_observations) + can_reach = true + break end - find_generated_vars_recursive_helper(g, p, generated_vars) end + + can_reach_observations[n] = can_reach + return can_reach end """ - markov_blanket(g::BUGSModel, v) + markov_blanket(g::BUGSGraph, v) Find the Markov blanket of variable(s) `v` in graph `g`. `v` can be a single `VarName` or a vector/tuple of `VarName`. + The Markov Blanket of a variable is the set of variables that shield the variable from the rest of the -network. Effectively, the Markov blanket of a variable is the set of its parents, its children, and +network. Effectively, the Markov blanket of a variable is the set of its parents, its children, and its children's other parents (reference: https://en.wikipedia.org/wiki/Markov_blanket). -In the case of vector, the Markov Blanket is the union of the Markov Blankets of each variable -minus the variables themselves (reference: Liu, X.-Q., & Liu, X.-S. (2018). Markov Blanket and Markov -Boundary of Multiple Variables. Journal of Machine Learning Research, 19(43), 1–50.) +In the case of a vector of variables, the Markov Blanket is the union of the Markov Blankets of each variable +minus the variables themselves[1]. -In the case of M-H acceptance ratio evaluation, only the logps of the children are needed, because the logp of the parents -and co-parents are not changed (their values are still needed to compute the distributions). +This function returns a `Set` of `VarName`s, containing the Markov blanket of the variable(s) `v` in graph `g`, deterministic +variables that are on the path from `v` to their Markov blankets, and the variables `v` itself(themselves). + +[1] Liu, X.-Q., & Liu, X.-S. (2018). Markov Blanket and Markov +Boundary of Multiple Variables. Journal of Machine Learning Research, 19(43), 1–50. """ -function markov_blanket(g::BUGSGraph, v::VarName; children_only=false) - if !children_only - parents = stochastic_inneighbors(g, v) - children = stochastic_outneighbors(g, v) - co_parents = VarName[] - for p in children - co_parents = vcat(co_parents, stochastic_inneighbors(g, p)) - end - blanket = unique(vcat(parents, children, co_parents...)) - return [x for x in blanket if x != v] - else - return stochastic_outneighbors(g, v) +function markov_blanket(g::MetaGraph{Int,<:SimpleDiGraph,L,VD}, v::L) where {L,VD} + if !is_stochastic(g, v) + throw(ArgumentError("Variable $v is logical, so it has no Markov blanket.")) end -end -function markov_blanket(g::BUGSGraph, v::Vector{<:VarName}; children_only=false) - blanket = VarName[] - for vn in v - blanket = vcat(blanket, markov_blanket(g, vn; children_only=children_only)) + parents, deterministic_variables_en_route_parents = dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, v, MetaGraphsNext.inneighbor_labels + ) + children, deterministic_variables_en_route_children = dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, v, MetaGraphsNext.outneighbor_labels + ) + + co_parents, deterministic_variables_en_route_co_parents = Set{L}(), Set{L}() + for child in children + co_parents_child, deterministic_variables_en_route_co_parents_child = dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, child, MetaGraphsNext.inneighbor_labels + ) + union!(co_parents, co_parents_child) + union!( + deterministic_variables_en_route_co_parents, + deterministic_variables_en_route_co_parents_child, + ) end - return [x for x in unique(blanket) if x ∉ v] + + blanket = union!( + parents, + children, + co_parents, + deterministic_variables_en_route_parents, + deterministic_variables_en_route_children, + deterministic_variables_en_route_co_parents, + ) + push!(blanket, v) + return blanket end -""" - stochastic_neighbors(g::BUGSGraph, c::VarName, f) - -Internal function to find all the stochastic neighbors (parents or children), returns a vector of -`VarName` containing the stochastic neighbors and the logical variables along the paths. -""" -function stochastic_neighbors( - g::BUGSGraph, - v::VarName, - f::Union{ - typeof(MetaGraphsNext.inneighbor_labels),typeof(MetaGraphsNext.outneighbor_labels) - }, -) - stochastic_neighbors_vec = VarName[] - logical_en_route = VarName[] # logical variables - for u in f(g, v) - if g[u].is_stochastic - push!(stochastic_neighbors_vec, u) - else - push!(logical_en_route, u) - ns = stochastic_neighbors(g, u, f) - for n in ns - push!(stochastic_neighbors_vec, n) - end - end - end - return [stochastic_neighbors_vec..., logical_en_route...] +function markov_blanket( + g::MetaGraph{Int,<:SimpleDiGraph,L,VD}, vs::AbstractVector{<:L} +) where {L,VD} + return reduce((acc, vn) -> union!(acc, markov_blanket(g, vn)), vs; init=Set{L}()) end -""" - stochastic_inneighbors(g::BUGSGraph, v::VarName) +function dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g::MetaGraph{Int,<:SimpleDiGraph,L,VD}, v::L, neighbor_func::F +) where {L,VD,F} + stochastic_neighbors = Set{L}() + deterministic_variables_en_route = Set{L}() + stack = VarName[v] + visited = Set{L}() -Find all the stochastic inneighbors (parents) of `v`. -""" -function stochastic_inneighbors(g, v) - return stochastic_neighbors(g, v, MetaGraphsNext.inneighbor_labels) -end + while !isempty(stack) + current = pop!(stack) + current in visited && continue -""" - stochastic_outneighbors(g::BUGSGraph, v::VarName) + if is_deterministic(g, current) + push!(deterministic_variables_en_route, current) + end -Find all the stochastic outneighbors (children) of `v`. -""" -function stochastic_outneighbors(g, v) - return stochastic_neighbors(g, v, MetaGraphsNext.outneighbor_labels) + push!(visited, current) + + for neighbor in neighbor_func(g, current) + if !(neighbor in visited) + if is_stochastic(g, neighbor) + push!(stochastic_neighbors, neighbor) + # and stop (not pushing to stack) + else + push!(stack, neighbor) + end + end + end + end + + return stochastic_neighbors, deterministic_variables_en_route end diff --git a/src/model.jl b/src/model.jl index 2eed82371..f3380e1f5 100644 --- a/src/model.jl +++ b/src/model.jl @@ -19,9 +19,9 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple} "The length of the parameters vector in the transformed (unconstrained) space." transformed_param_length::Int "A dictionary mapping the names of the variables to their lengths in the original (constrained) space." - untransformed_var_lengths::Dict{<:VarName,Int} + untransformed_var_lengths::OrderedDict{<:VarName,Int} "A dictionary mapping the names of the variables to their lengths in the transformed (unconstrained) space." - transformed_var_lengths::Dict{<:VarName,Int} + transformed_var_lengths::OrderedDict{<:VarName,Int} "A `NamedTuple` containing the values of the variables in the model, all the values are in the constrained space." evaluation_env::T @@ -94,8 +94,8 @@ function BUGSModel( sorted_nodes = VarName[label_for(g, node) for node in topological_sort(g)] parameters = VarName[] untransformed_param_length, transformed_param_length = 0, 0 - untransformed_var_lengths, transformed_var_lengths = Dict{VarName,Int}(), - Dict{VarName,Int}() + untransformed_var_lengths, transformed_var_lengths = OrderedDict{VarName,Int}(), + OrderedDict{VarName,Int}() for vn in sorted_nodes (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] @@ -129,7 +129,11 @@ function BUGSModel( rand(dist) catch e error( - "Failed to sample from the prior distribution of $vn, consider providing initialization values for $vn or it's parents: $(collect(MetaGraphsNext.inneighbor_labels(g, vn))...).", + """ + Failed to sample from the prior distribution of $vn, consider providing + initialization values for $vn or it's parents: + $(collect(MetaGraphsNext.inneighbor_labels(g, vn))...). + """, ) end evaluation_env = BangBang.setindex!!(evaluation_env, init_value, vn) @@ -151,7 +155,7 @@ function BUGSModel( end function BUGSModel( - model::BUGSModel, + model::BUGSModel; parameters::Vector{<:VarName}, sorted_nodes::Vector{<:VarName}, evaluation_env::NamedTuple=model.evaluation_env, @@ -177,40 +181,36 @@ Initialize the model with a NamedTuple of initial values, the values are expecte """ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) + evaluation_env = model.evaluation_env for vn in model.sorted_nodes - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn] + (; node_function, node_args, loop_vars) = model.g[vn] args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) - if !is_stochastic + if is_deterministic(model.g, vn) value = Base.invokelatest(node_function; args...) - BangBang.@set!! model.evaluation_env = setindex!!( - model.evaluation_env, value, vn - ) - elseif !is_observed + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) + elseif is_model_parameter(model.g, vn) initialization = try AbstractPPL.get(initial_params, vn) catch _ missing end if !ismissing(initialization) - BangBang.@set!! model.evaluation_env = setindex!!( - model.evaluation_env, initialization, vn - ) + evaluation_env = BangBang.setindex!!(evaluation_env, initialization, vn) else - BangBang.@set!! model.evaluation_env = setindex!!( - model.evaluation_env, - rand(Base.invokelatest(node_function; args...)), - vn, + evaluation_env = BangBang.setindex!!( + evaluation_env, rand(Base.invokelatest(node_function; args...)), vn ) end end end - return model + return BangBang.setproperty!!(model, :evaluation_env, evaluation_env) end """ initialize!(model::BUGSModel, initial_params::AbstractVector) -Initialize the model with a vector of initial values, the values can be in transformed space if `model.transformed` is set to true. +Initialize the model with a vector of initial values, the values can be in transformed +space if `model.transformed` is set to true. """ function initialize!(model::BUGSModel, initial_params::AbstractVector) evaluation_env, _ = AbstractPPL.evaluate!!(model, LogDensityContext(), initial_params) @@ -272,70 +272,66 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed)) return BangBang.setproperty!!(model, :transformed, bool) end -function AbstractPPL.condition( +function create_sub_model( model::BUGSModel, - d::Dict{<:VarName,<:Any}, - sorted_nodes=Nothing, # support cached sorted Markov blanket nodes + model_parameters_in_submodel::Vector{<:VarName}, + all_variables_in_submodel::Vector{<:VarName}, +) + return BUGSModel(model, model_parameters_in_submodel, all_variables_in_submodel) +end + +function AbstractPPL.condition( + model::BUGSModel, variables_to_condition_on_and_values::Dict{<:VarName,<:Any} ) - new_evaluation_env = deepcopy(model.evaluation_env) - for (p, value) in d - new_evaluation_env = setindex!!(new_evaluation_env, value, p) + evaluation_env = model.evaluation_env + for (variable, value) in pairs(variables_to_condition_on_and_values) + evaluation_env = BangBang.setindex!!(evaluation_env, value, variable) end return AbstractPPL.condition( - model, collect(keys(d)), new_evaluation_env; sorted_nodes=sorted_nodes + model, collect(keys(variables_to_condition_on_and_values)), evaluation_env ) end - function AbstractPPL.condition( model::BUGSModel, - var_group::Vector{<:VarName}, + variables_to_condition_on::Vector{<:VarName}, evaluation_env::NamedTuple=model.evaluation_env, - sorted_nodes=Nothing, ) - check_var_group(var_group, model) - new_parameters = setdiff(model.parameters, var_group) - - sorted_blanket_with_vars = if sorted_nodes isa Nothing - sorted_nodes - else - filter( - vn -> vn in union(markov_blanket(model.g, new_parameters), new_parameters), - model.sorted_nodes, - ) + BangBang.setproperty!!(model, :evaluation_env, evaluation_env) + for vn in variables_to_condition_on + if !model.g[vn].is_stochastic + throw( + ArgumentError( + "$vn is not a stochastic variable, conditioning on it is not supported" + ), + ) + elseif model.g[vn].is_observed + @warn "$vn is already an observed variable, conditioning on it won't have any effect" + else + new_g = copy(model.g) + new_g[vn] = BangBang.setproperty!!(model.g[vn], :is_observed, true) + model = BangBang.setproperty!!(model, :g, new_g) + end end - - return BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env) + return model end function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) - check_var_group(var_group, model) - base_model = model.base_model isa Nothing ? model : model.base_model - - new_parameters = [ - v for v in base_model.sorted_nodes if v in union(model.parameters, var_group) - ] # keep the order - - markov_blanket_with_vars = union( - markov_blanket(base_model.g, new_parameters), new_parameters - ) - sorted_blanket_with_vars = filter( - vn -> vn in markov_blanket_with_vars, base_model.sorted_nodes - ) - - new_model = BUGSModel( - model, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env - ) - evaluate_env, _ = evaluate!!(new_model, DefaultContext()) - return BangBang.setproperty!!(new_model, :evaluation_env, evaluate_env) -end - -function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) - non_vars = filter(var -> var ∉ labels(model.g), var_group) - logical_vars = filter(var -> !model.g[var].is_stochastic, var_group) - isempty(non_vars) || error("Variables $(non_vars) are not in the model") - return isempty(logical_vars) || error( - "Variables $(logical_vars) are not stochastic variables, conditioning on them is not supported", - ) + for vn in var_group + if !model.g[vn].is_stochastic + throw( + ArgumentError( + "$vn is not a stochastic variable, deconditioning it is not supported" + ), + ) + elseif !model.g[vn].is_observed + @warn "$vn is already treated as model parameter, deconditioning it won't have any effect" + else + new_g = copy(model.g) + new_g[vn] = BangBang.setproperty!!(model.g[vn], :is_observed, false) + model = BangBang.setproperty!!(model, :g, new_g) + end + end + return model end """ @@ -365,20 +361,31 @@ function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG) return evaluate!!(model, SamplingContext(rng)) end function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) - (; evaluation_env, g, sorted_nodes) = model - vi = deepcopy(evaluation_env) + evaluation_env = deepcopy(model.evaluation_env) # TODO: a lot of the arrays are not modified logp = 0.0 - for vn in sorted_nodes - (; is_stochastic, node_function, node_args, loop_vars) = g[vn] + for vn in model.sorted_nodes + (; node_function, node_args, loop_vars) = model.g[vn] args = prepare_arg_values(node_args, evaluation_env, loop_vars) - if !is_stochastic + if !is_stochastic(model.g, vn) value = node_function(; args...) - evaluation_env = setindex!!(evaluation_env, value, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) else dist = node_function(; args...) - value = rand(ctx.rng, dist) # just sample from the prior - logp += logpdf(dist, value) - evaluation_env = setindex!!(evaluation_env, value, vn) + if is_observation(model.g, vn) + value = AbstractPPL.get(evaluation_env, vn) + else + value = rand(ctx.rng, dist) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) + end + if model.transformed + value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) + logp += + Distributions.logpdf(dist, value) + Bijectors.logabsdetjac( + Bijectors.inverse(Bijectors.bijector(dist)), value_transformed + ) + else + logp += Distributions.logpdf(dist, value) + end end end return evaluation_env, logp @@ -389,7 +396,6 @@ function AbstractPPL.evaluate!!(model::BUGSModel) end function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) (; sorted_nodes, g, evaluation_env) = model - vi = deepcopy(evaluation_env) logp = 0.0 for vn in sorted_nodes (; is_stochastic, node_function, node_args, loop_vars) = g[vn] diff --git a/test/gibbs.jl b/test/gibbs.jl index 9a6adff58..c99e4bc73 100644 --- a/test/gibbs.jl +++ b/test/gibbs.jl @@ -1,6 +1,49 @@ -using JuliaBUGS: MHFromPrior, Gibbs +using JuliaBUGS: MHFromPrior, Gibbs, OrderedDict @testset "Simple gibbs" begin + model_def = @bugs begin + μ ~ Normal(0, 4) + σ ~ Gamma(1, 1) + for i in 1:100 + y[i] ~ Normal(μ, σ) + end + end + + μ_true = 2 + σ_true = 4 + + y = rand(Normal(μ_true, σ_true), 100) + + model = compile(model_def, (; y=y)) + model = initialize!(model, (μ=4.0, σ=6.0)) + + splr_map = OrderedDict(@varname(μ) => MHFromPrior(), @varname(σ) => MHFromPrior()) + splr = Gibbs(splr_map) + + p_s, st_init = AbstractMCMC.step( + Random.default_rng(), AbstractMCMC.LogDensityModel(model), splr + ) + + p_s, st = AbstractMCMC.step( + Random.default_rng(), AbstractMCMC.LogDensityModel(model), splr, st_init + ) + + chn = AbstractMCMC.sample( + Random.default_rng(), + model, + splr, + 10000; + # chain_type=MCMCChains.Chains, + ) + + σ_samples = [v[1] for v in chn[300:end]] + μ_samples = [v[2] for v in chn[300:end]] + + @test mean(μ_samples) ≈ μ_true rtol = 0.2 + @test mean(σ_samples) ≈ σ_true rtol = 0.2 +end + +@testset "Linear regression" begin model_def = @bugs begin # Likelihood for i in 1:N @@ -29,32 +72,41 @@ using JuliaBUGS: MHFromPrior, Gibbs model = compile(model_def, data, (;)) + sampler = Gibbs( + OrderedDict( + @varname(alpha) => MHFromPrior(), + @varname(beta) => MHFromPrior(), + @varname(sigma) => MHFromPrior(), + ), + ) + # single step p_s, st_init = AbstractMCMC.step( - Random.default_rng(), - AbstractMCMC.LogDensityModel(model), - Gibbs(model, MHFromPrior()), + Random.default_rng(), AbstractMCMC.LogDensityModel(model), sampler ) # following step p_s, st = AbstractMCMC.step( - Random.default_rng(), - AbstractMCMC.LogDensityModel(model), - Gibbs(model, MHFromPrior()), - st_init, + Random.default_rng(), AbstractMCMC.LogDensityModel(model), sampler, st_init ) # following step with sampler_map sampler_map = OrderedDict( [@varname(alpha), @varname(beta)] => HMC(0.1, 10), [@varname(sigma)] => RWMH(1) ) + p_s, st_init = AbstractMCMC.step( + Random.default_rng(), AbstractMCMC.LogDensityModel(model), Gibbs(sampler_map) + ) p_s, st = AbstractMCMC.step( - Random.default_rng(), AbstractMCMC.LogDensityModel(model), Gibbs(sampler_map), st + Random.default_rng(), + AbstractMCMC.LogDensityModel(model), + Gibbs(sampler_map), + st_init, ) # TODO: result checking is disabled because of speed and stability, revive this after improvement # sample_size = 10000 - sample_size = 10 + sample_size = 100000 chn = AbstractMCMC.sample( Random.default_rng(), model, @@ -65,8 +117,19 @@ using JuliaBUGS: MHFromPrior, Gibbs ), ), sample_size; + # chain_type=MCMCChains.Chains, discard_initial=Int(sample_size / 2), ) + + num_to_discard = Int(sample_size / 2) + alpha_samples = [v[1] for v in chn[num_to_discard:end]] + beta_samples = [v[2] for v in chn[num_to_discard:end]] + sigma_samples = [v[3] for v in chn[num_to_discard:end]] + + alpha_mean = mean(alpha_samples) + beta_mean = mean(beta_samples) + sigma_mean = mean(sigma_samples) + @test chn.name_map[:parameters] == [ :sigma :beta @@ -79,14 +142,30 @@ using JuliaBUGS: MHFromPrior, Gibbs # @test means[:sigma].nt.mean[1] ≈ 0.95 rtol = 0.2 # @test means[:gen_quant].nt.mean[1] ≈ 4.2 rtol = 0.2 - sample_size = 2000 + sample_size = 10000 hmc_chn = AbstractMCMC.sample( Random.default_rng(), model, - Gibbs(model, HMC(0.1, 10)), + Gibbs( + OrderedDict( + @varname(alpha) => HMC(0.1, 10), + @varname(beta) => HMC(0.1, 10), + @varname(sigma) => HMC(0.1, 10), + ), + ), sample_size; discard_initial=Int(sample_size / 2), ) + + num_to_discard = Int(sample_size / 2) + alpha_samples = [v[1] for v in hmc_chn[num_to_discard:end]] + beta_samples = [v[2] for v in hmc_chn[num_to_discard:end]] + sigma_samples = [v[3] for v in hmc_chn[num_to_discard:end]] + + alpha_mean = mean(alpha_samples) + beta_mean = mean(beta_samples) + sigma_mean = mean(sigma_samples) + means = mean(hmc_chn) @test means[:alpha].nt.mean[1] ≈ 2.2 rtol = 0.2 @test means[:beta].nt.mean[1] ≈ 2.1 rtol = 0.2 diff --git a/test/graphs.jl b/test/graphs.jl index 124515be4..3687d0f55 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -1,104 +1,155 @@ +using Graphs, MetaGraphsNext +using JuliaBUGS using JuliaBUGS: - stochastic_inneighbors, stochastic_neighbors, stochastic_outneighbors, markov_blanket - -test_model = @bugs begin - a ~ dnorm(f, c) - f = b - 1 - b ~ dnorm(0, 1) - c ~ dnorm(l, 1) - g = a * 2 - d ~ dnorm(g, 1) - h = g + 2 - e ~ dnorm(h, i) - i ~ dnorm(0, 1) - l ~ dnorm(0, 1) -end - -inits = ( - a=1.0, - b=2.0, - c=3.0, - d=4.0, - e=5.0, - - # f = 1.0, - # g = 2.0, - # h = 4.0, - - i=4.0, - l=-2.0, -) + markov_blanket, dfs_find_stochastic_boundary_and_deterministic_variables_en_route -model = compile(test_model, NamedTuple(), inits) +module GraphsTest +using JuliaBUGS: JuliaBUGS +using Graphs, MetaGraphsNext -g = model.g +export TestNode -a = @varname a -l = @varname l -@test Set(Symbol.(stochastic_inneighbors(g, a))) == Set([:b, :c, :f]) -@test Set(Symbol.(stochastic_outneighbors(g, a))) == Set([:d, :e, :h, :g]) - -@test Set(Symbol.(markov_blanket(g, a))) == Set([:f, :b, :d, :e, :c, :h, :g, :i]) -@test Set(Symbol.(markov_blanket(g, [a, l]))) == Set([:f, :b, :d, :e, :c, :h, :g, :i]) +struct TestNode + node_type::Int +end -c = @varname c -@test Set(Symbol.(markov_blanket(model.g, c))) == Set([:l, :a, :b, :f]) +# overload the functions for testing purposes +function JuliaBUGS.is_model_parameter( + g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int +) + return g[v].node_type == 1 +end +function JuliaBUGS.is_observation(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) + return g[v].node_type == 2 +end +function JuliaBUGS.is_deterministic(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) + return g[v].node_type == 3 +end +function JuliaBUGS.is_stochastic(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) + return g[v].node_type != 3 +end +end -cond_model = AbstractPPL.condition(model, setdiff(model.parameters, [c])) -# tests for MarkovBlanketBUGSModel constructor -@test cond_model.parameters == [c] -@test Set(Symbol.(cond_model.sorted_nodes)) == Set([:l, :a, :b, :f, :c]) +@testset "find_generated_quantities_variables" begin + using .GraphsTest + + function generate_random_dag(num_nodes::Int, p::Float64=0.3) + graph = SimpleGraph(num_nodes) + for i in 1:num_nodes + for j in 1:num_nodes + if i != j && rand() < p + add_edge!(graph, i, j) + end + end + end + + graph = Graphs.random_orientation_dag(graph) # ensure the random graph is a DAG + vertices_description = [i => TestNode(rand(1:3)) for i in 1:nv(graph)] + edges_description = [Tuple(e) => nothing for e in Graphs.edges(graph)] + return MetaGraph(graph, vertices_description, edges_description) + end -decond_model = AbstractPPL.decondition(cond_model, [a, l]) -@test Set(Symbol.(decond_model.parameters)) == Set([:a, :c, :l]) -@test Set(Symbol.(decond_model.sorted_nodes)) == - Set([:l, :b, :f, :a, :d, :e, :c, :h, :g, :i]) + # `transitiveclosure` has time complexity O(|E|⋅|V|), not fit for large graphs + # but easy to implement and understand, here we use it for reference + function find_generated_quantities_variables_with_transitive_closure( + g::MetaGraph{Int,<:SimpleDiGraph,Label,VertexData} + ) where {Label,VertexData} + _transitive_closure = Graphs.transitiveclosure(g.graph) + generated_quantities_variables = Set{Label}() + for v_id in vertices(g.graph) + if !JuliaBUGS.is_observation(g, v_id) + if all( + !Base.Fix1(JuliaBUGS.is_observation, g), + outneighbors(_transitive_closure, v_id), + ) + push!(generated_quantities_variables, MetaGraphsNext.label_for(g, v_id)) + end + end + end + + return generated_quantities_variables + end -c_value = 4.0 -mb_logp = begin - logp = 0 - logp += logpdf(dnorm(1.0, c_value), 1.0) # a - logp += logpdf(dnorm(0.0, 1.0), 2.0) # b - logp += logpdf(dnorm(0.0, 1.0), -2.0) # l - logp += logpdf(dnorm(-2.0, 1.0), c_value) # c - logp -end + @testset "random DAG with $num_nodes nodes and $p probability of edge" for num_nodes in + [ + 10, 20, 100, 500, 1000 + ], + p in [0.1, 0.3, 0.5] -# order: b, l, c, a -@test mb_logp ≈ evaluate!!(cond_model, JuliaBUGS.LogDensityContext(), [c_value])[2] rtol = - 1e-8 - -# test LogDensityContext -@test begin - logp = 0 - logp += logpdf(dnorm(1.0, 3.0), 1.0) # a, where f = 1.0 - logp += logpdf(dnorm(0.0, 1.0), 2.0) # b - logp += logpdf(dnorm(0.0, 1.0), -2.0) # l - logp += logpdf(dnorm(-2.0, 1.0), 3.0) # c - logp += logpdf(dnorm(0.0, 1.0), 4.0) # i - logp += logpdf(dnorm(2.0, 1.0), 4.0) # d, where g = 2.0 - logp += logpdf(dnorm(4.0, 4.0), 5.0) # e, where h = 4.0 - logp -end ≈ evaluate!!( - model, JuliaBUGS.LogDensityContext(), [-2.0, 4.0, 3.0, 2.0, 1.0, 4.0, 5.0] -)[2] atol = 1e-8 - -# AuxiliaryNodeInfo -test_model = @bugs begin - x[1:2] ~ dmnorm(mu[:], sigma[:, :]) - for i in 1:2 - mu[i] ~ dnorm(0, 1) + g = generate_random_dag(num_nodes, p) + @test find_generated_quantities_variables(g) == + find_generated_quantities_variables_with_transitive_closure(g) end - z[1:2, 1:2] ~ dwish(R[:, :], 2) - y ~ dnorm(x[1], x[2] + 1 + z[1, 1]) end -model = compile( - test_model, - (R=[200 0; 0 0.2], sigma=[1.0E-6 0; 0 1.0E-6]), - (x=[1.0, 2.0], z=zeros(2, 2)), -) - -# z[1,1], x[1], x[2] are auxiliary nodes created, and removed at the end -@test Set(Symbol.(labels(model.g))) == - Set([Symbol("mu[1]"), Symbol("x[1:2]"), Symbol("z[1:2, 1:2]"), Symbol("mu[2]"), :y]) +@testset "markov_blanket" begin + using .GraphsTest + + """ Mermaid code for visualizing the test graph + ```mermaid + graph TD + 1((1: Parameter)) --> 2((2: Deterministic)) + 1 --> 3((3: Parameter)) + 2 --> 4((4: Deterministic)) + 3 --> 5((5: Observation)) + 4 --> 6((6: Deterministic)) + 5 --> 6 + 5 --> 7((7: Observation)) + 6 --> 8((8: Parameter)) + 7 --> 8 + + classDef parameter fill:#f9f,stroke:#333,stroke-width:2px; + classDef deterministic fill:#bfb,stroke:#333,stroke-width:2px; + classDef observation fill:#bbf,stroke:#333,stroke-width:2px; + + class 1,3,8 parameter; + class 2,4,6 deterministic; + class 5,7 observation; + ``` + """ + + g = MetaGraph(SimpleDiGraph(); label_type=Int, vertex_data_type=TestNode) + + g[1] = TestNode(1) + g[2] = TestNode(3) + g[3] = TestNode(1) + g[4] = TestNode(3) + g[5] = TestNode(2) + g[6] = TestNode(3) + g[7] = TestNode(2) + g[8] = TestNode(1) + + add_edge!(g, 1, 2) + add_edge!(g, 1, 3) + add_edge!(g, 2, 4) + add_edge!(g, 3, 5) + add_edge!(g, 4, 6) + add_edge!(g, 5, 6) + add_edge!(g, 5, 7) + add_edge!(g, 6, 8) + add_edge!(g, 7, 8) + + # Test single node Markov blanket + @test dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, 1, MetaGraphsNext.outneighbor_labels + ) == (Set([3, 8]), Set([4, 6, 2])) + @test dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, 1, MetaGraphsNext.inneighbor_labels + ) == (Set(), Set()) + @test markov_blanket(g, 1) == Set(collect(1:8)) # should contains all the nodes + + @test dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, 5, MetaGraphsNext.outneighbor_labels + ) == (Set([7, 8]), Set([6])) + @test dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + g, 5, MetaGraphsNext.inneighbor_labels + ) == (Set([3]), Set()) + @test markov_blanket(g, 5) == Set([1, 2, 3, 4, 5, 6, 7, 8]) + + @test markov_blanket(g, 3) == Set([1, 3, 5]) + @test markov_blanket(g, 7) == Set([1, 2, 4, 5, 6, 7, 8]) + @test markov_blanket(g, 8) == Set([1, 2, 4, 5, 6, 7, 8]) + + @test markov_blanket(g, [1, 3]) == Set([1, 2, 3, 4, 5, 6, 7, 8]) + @test markov_blanket(g, (3, 7)) == Set([1, 2, 3, 4, 5, 6, 7, 8]) +end diff --git a/test/runtests.jl b/test/runtests.jl index 39a8a0709..84bea8464 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,11 +9,10 @@ using AbstractPPL using AbstractMCMC using AdvancedHMC using AdvancedMH +using ADTypes using Bijectors using Distributions -using DynamicPPL # TODO: for `gen_chains` function only, to be removed -using Graphs -using MetaGraphsNext +using Graphs, MetaGraphsNext using LinearAlgebra using LogDensityProblems using LogDensityProblemsAD diff --git a/test/utils.jl b/test/utils.jl index 2f9b55681..eb554570b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -9,7 +9,7 @@ using JuliaBUGS: CompilerUtils end loop_var, lb, ub, body = JuliaBUGS.decompose_for_expr(ex) - + @test loop_var == :i @test lb == 1 @test ub == 3 @@ -20,3 +20,17 @@ using JuliaBUGS: CompilerUtils end end end + +@testset "BangBang.setindex!!" begin + nt = (a=1, b=[1, 2, 3], c=[1, 2, 3]) + nt1 = BangBang.setindex!!(nt, 2, @varname(a)) + @test nt1.a == 2 + + nt2 = BangBang.setindex!!(nt, 5, @varname(b[1])) + @test nt2.b == [5, 2, 3] + @test nt2.b === nt.b # mutation + + nt3 = BangBang.setindex!!(nt, 2, @varname(c[1]); prefer_mutation=false) + @test nt3.c == [2, 2, 3] + @test nt3.c !== nt.c # no mutation +end