diff --git a/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl b/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
index 29c261ab..8139b775 100644
--- a/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
+++ b/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
@@ -15,6 +15,8 @@ export BayesianNetwork,
     decondition,
     decondition!,
     ancestral_sampling,
-    is_conditionally_independent
+    is_conditionally_independent,
+    marginal_distribution,
+    eliminate_variables
 
 end
diff --git a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
index 3eca2da5..93df0c62 100644
--- a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
+++ b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
@@ -3,38 +3,38 @@
 
 A structure representing a Bayesian Network.
 """
-struct BayesianNetwork{V,T,F}
+struct BayesianNetwork{V,T}
     graph::SimpleDiGraph{T}
     "names of the variables in the network"
     names::Vector{V}
     "mapping from variable names to ids"
     names_to_ids::Dict{V,T}
     "values of each variable in the network"
     values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future
-    "distributions of the stochastic variables"
-    distributions::Vector{Distribution}
-    "deterministic functions of the deterministic variables"
-    deterministic_functions::Vector{F}
+    "distributions (or functions returning one) of the stochastic variables"
+    distributions::Vector{Union{Distribution,Function}}
+    is_observed::BitVector
+    is_stochastic::BitVector
     "ids of the stochastic variables"
     stochastic_ids::Vector{T}
     "ids of the deterministic variables"
     deterministic_ids::Vector{T}
-    is_stochastic::BitVector
-    is_observed::BitVector
+    "deterministic functions of the deterministic variables"
+    deterministic_functions::Vector{Function}
 end
 
 function BayesianNetwork{V}() where {V}
-    return BayesianNetwork(
-        SimpleDiGraph{Int}(), # by default, vertex ids are integers
-        V[],
-        Dict{V,Int}(),
-        Dict{V,Any}(),
-        Distribution[],
-        Any[],
-        Int[],
-        Int[],
-        BitVector(),
-        BitVector(),
+    return BayesianNetwork{V,Int}(
+        SimpleDiGraph{Int}(), # graph; by default, vertex ids are integers
+        V[], # names
+        Dict{V,Int}(), # names_to_ids
+        Dict{V,Any}(), # values
+        Union{Distribution,Function}[], # distributions
+        BitVector(), # is_observed
+        BitVector(), # is_stochastic
+        Int[], # stochastic_ids
+        Int[], # deterministic_ids
+        Function[], # deterministic_functions
     )
 end
 
@@ -114,13 +114,16 @@ Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesia
 if successful, 0 otherwise.
 """
 function add_stochastic_vertex!(
-    bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool
+    bn::BayesianNetwork{V,T},
+    name::V,
+    dist::Union{Distribution,Function},
+    is_observed::Bool
 )::T where {V,T}
     Graphs.add_vertex!(bn.graph) || return 0
     id = nv(bn.graph)
     push!(bn.distributions, dist)
-    push!(bn.is_stochastic, true)
     push!(bn.is_observed, is_observed)
+    push!(bn.is_stochastic, true)
     push!(bn.names, name)
     bn.names_to_ids[name] = id
     push!(bn.stochastic_ids, id)
@@ -329,3 +332,120 @@ function is_conditionally_independent(
 ) where {V}
     return is_conditionally_independent(bn, [X], [Y], Z)
 end
+
+"""
+    is_discrete_distribution(d::Distribution)
+
+Return `true` if `d` is a `DiscreteDistribution`.
+"""
+function is_discrete_distribution(d::Distribution)
+    return d isa DiscreteDistribution
+end
+
+"""
+    get_support(d::DiscreteDistribution)
+
+Return `support(d)`.
+"""
+function get_support(d::DiscreteDistribution)
+    return support(d)
+end
+
+"""
+    marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V}
+
+Compute the marginal distribution of `query_var` by recursive variable elimination.
+
+The vertices of `bn.graph` are taken in topological order and handed to
+[`eliminate_variables`](@ref) together with an empty set of assignments.
+"""
+function marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V}
+    query_id = bn.names_to_ids[query_var]
+    ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph)
+    # NOTE(review): the recursion starts from empty assignments, so values stored in
+    # `bn.values` for observed variables are not forwarded; confirm this is intended.
+    return eliminate_variables(bn, ordered_vertices, query_id, Dict{V,Any}())
+end
+
+"""
+    evaluate_distribution(dist, parent_values)
+
+Return the distribution of a node given the values of its parents.
+
+A `Distribution` is returned unchanged. A `Function` is called with the parent values
+as positional arguments, or `nothing` is returned when any parent value is `nothing`.
+"""
+evaluate_distribution(dist::Distribution, _) = dist
+
+function evaluate_distribution(dist_func::Function, parent_values)
+    any(isnothing, parent_values) && return nothing
+    return dist_func(parent_values...)
+end
+
+# Distribution of vertex `id` given `assignments`. Falls back to the stored entry when
+# it cannot be evaluated because a parent has no assigned value.
+function _node_distribution(bn::BayesianNetwork, id::Integer, assignments::Dict)
+    dist_idx = findfirst(==(id), bn.stochastic_ids)
+    if isnothing(dist_idx)
+        throw(ArgumentError("vertex $id is not a stochastic variable"))
+    end
+    node_dist = bn.distributions[dist_idx]
+    parent_ids = Graphs.inneighbors(bn.graph, id)
+    parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids]
+    evaluated = evaluate_distribution(node_dist, parent_values)
+    return isnothing(evaluated) ? node_dist : evaluated
+end
+
+"""
+    eliminate_variables(bn, ordered_vertices, query_id, assignments)
+
+Recursively eliminate the leading vertices of `ordered_vertices` and return the
+resulting distribution of the vertex `query_id`.
+
+When `ordered_vertices` is empty or starts with `query_id`, the query variable's
+distribution given `assignments` is returned. Otherwise the first vertex is assigned
+the values `0` and `1` in turn, the recursion runs on the remaining vertices for each
+value, and the two results are combined into a `MixtureModel` whose weights are the
+normalized densities of `0` and `1` under that vertex's own distribution.
+
+!!! note
+    Only the values `0` and `1` are enumerated, whatever the support of the eliminated
+    variable. `bn.values` and `bn.is_observed` are not read; known values have to be
+    passed in through `assignments`.
+"""
+function eliminate_variables(
+    bn::BayesianNetwork{V},
+    ordered_vertices::Vector{Int},
+    query_id::Int,
+    assignments::Dict{V,Any},
+) where {V}
+    query_dist = _node_distribution(bn, query_id, assignments)
+
+    # Base case: reached the query variable
+    if isempty(ordered_vertices) || first(ordered_vertices) == query_id
+        return query_dist
+    end
+
+    current_id = first(ordered_vertices)
+    remaining_vertices = ordered_vertices[2:end]
+    current_dist = _node_distribution(bn, current_id, assignments)
+
+    # The component vector is typed by whether the query variable is continuous
+    components = if query_dist isa ContinuousUnivariateDistribution
+        ContinuousUnivariateDistribution[]
+    else
+        DiscreteUnivariateDistribution[]
+    end
+    weights = Float64[]
+
+    for value in (0, 1)
+        new_assignments = copy(assignments)
+        new_assignments[bn.names[current_id]] = value
+        component = eliminate_variables(bn, remaining_vertices, query_id, new_assignments)
+        push!(components, component)
+        push!(weights, pdf(current_dist, value))
+    end
+
+    weights ./= sum(weights)
+    return MixtureModel(components, weights)
+end
diff --git a/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
index 947a1065..83473255 100644
--- a/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
+++ b/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
@@ -10,7 +10,9 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
     condition!,
     decondition,
     ancestral_sampling,
-    is_conditionally_independent
+    is_conditionally_independent,
+    marginal_distribution,
+    eliminate_variables
 
 @testset "BayesianNetwork" begin
     @testset "Adding vertices" begin
diff --git a/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl b/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl
new file mode 100644
index 00000000..e37c1497
--- /dev/null
+++ b/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl
@@ -0,0 +1,102 @@
+using Test
+using Distributions
+using Graphs
+using BangBang
+
+using JuliaBUGS
+using JuliaBUGS.ProbabilisticGraphicalModels
+using JuliaBUGS.ProbabilisticGraphicalModels:
+    BayesianNetwork,
+    add_stochastic_vertex!,
+    add_edge!,
+    condition,
+    marginal_distribution,
+    eliminate_variables
+
+@testset "Mixed Graph - Variable Elimination" begin
+    bn = BayesianNetwork{Symbol}()
+
+    # X1 ~ Uniform(0,1)
+    add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)
+
+    # X2 ~ Bernoulli(X1)
+    function x2_distribution(x1)
+        return Bernoulli(x1)
+    end
+    add_stochastic_vertex!(bn, :X2, x2_distribution, false)
+    add_edge!(bn, :X1, :X2)
+
+    # X3 ~ Normal(μ(X2), 1)
+    function x3_distribution(x2)
+        return Normal(x2 == 1 ? 10.0 : 2.0, 1.0)
+    end
+    add_stochastic_vertex!(bn, :X3, x3_distribution, false)
+    add_edge!(bn, :X2, :X3)
+
+    # Test graph structure
+    @test has_edge(bn.graph, 1, 2) # X1 -> X2
+    @test has_edge(bn.graph, 2, 3) # X2 -> X3
+
+    # Test conditional distributions
+    # Test X2's distribution given X1
+    bn_cond_x1 = condition(bn, Dict(:X1 => 0.7))
+    marginal_x2 = marginal_distribution(bn_cond_x1, :X2)
+    @test marginal_x2 isa Bernoulli
+    @test mean(marginal_x2) ≈ 0.7
+
+    # Test X3's distribution given X2
+    bn_cond_x2_0 = condition(bn, Dict(:X2 => 0))
+    marginal_x3_0 = marginal_distribution(bn_cond_x2_0, :X3)
+    @test marginal_x3_0 isa Normal
+    @test mean(marginal_x3_0) ≈ 2.0
+    @test std(marginal_x3_0) ≈ 1.0
+
+    bn_cond_x2_1 = condition(bn, Dict(:X2 => 1))
+    marginal_x3_1 = marginal_distribution(bn_cond_x2_1, :X3)
+    @test marginal_x3_1 isa Normal
+    @test mean(marginal_x3_1) ≈ 10.0
+    @test std(marginal_x3_1) ≈ 1.0
+
+    # Test full chain inference
+    ordered_vertices = [1, 2] # Eliminate X1, then X2
+    query_id = 3 # Query X3
+    result = eliminate_variables(bn, ordered_vertices, query_id, Dict{Symbol,Any}())
+
+    # The result should be a mixture of Normal distributions
+    @test result isa MixtureModel
+end
+
+@testset "Marginal Distribution P(X3|X1)" begin
+    bn = BayesianNetwork{Symbol}()
+
+    # X1 ~ Uniform(0,1)
+    add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)
+
+    # X2 ~ Bernoulli(X1)
+    add_stochastic_vertex!(bn, :X2, x1 -> Bernoulli(x1), false)
+    add_edge!(bn, :X1, :X2)
+
+    # X3 ~ Normal(μ(X2), 1)
+    add_stochastic_vertex!(bn, :X3, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0), false)
+    add_edge!(bn, :X2, :X3)
+
+    # Test P(X3|X1=0.7)
+    bn_cond = condition(bn, Dict(:X1 => 0.7))
+    marginal_x3 = marginal_distribution(bn_cond, :X3)
+
+    @test marginal_x3 isa MixtureModel
+    @test length(marginal_x3.components) == 2
+    @test marginal_x3.components[1] isa Normal
+    @test marginal_x3.components[2] isa Normal
+
+    # When X1 = 0.7:
+    # P(X2=0) = 0.3, P(X2=1) = 0.7
+    @test marginal_x3.prior.p ≈ [0.3, 0.7]
+
+    # Component means should be 2 and 10
+    @test mean(marginal_x3.components[1]) ≈ 2.0
+    @test mean(marginal_x3.components[2]) ≈ 10.0
+
+    # Overall mean should be weighted average
+    @test mean(marginal_x3) ≈ 2.0 * 0.3 + 10.0 * 0.7
+end