Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variable elimination reimplementation #240

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export BayesianNetwork,
decondition,
decondition!,
ancestral_sampling,
is_conditionally_independent
is_conditionally_independent,
marginal_distribution,
eliminate_variables

end
83 changes: 79 additions & 4 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

A structure representing a Bayesian Network.
"""

struct BayesianNetwork{V,T,F}
graph::SimpleDiGraph{T}
"names of the variables in the network"
Expand All @@ -12,7 +13,10 @@ struct BayesianNetwork{V,T,F}
"values of each variable in the network"
values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future
"distributions of the stochastic variables"
distributions::Vector{Distribution}
# A distribution can be either:
# - A fixed distribution (like Uniform(0,1))
# - A function that takes parent values and returns a distribution
distributions::Vector{Union{Distribution,Function}}
"deterministic functions of the deterministic variables"
deterministic_functions::Vector{F}
"ids of the stochastic variables"
Expand Down Expand Up @@ -114,16 +118,17 @@ Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesia
if successful, 0 otherwise.
"""
function add_stochastic_vertex!(
    bn::BayesianNetwork{V,T}, name::V, dist::Union{Distribution,Function}, is_observed::Bool
)::T where {V,T}
    # Bail out with the sentinel id 0 when the graph does not accept the new vertex.
    Graphs.add_vertex!(bn.graph) || return 0
    # The vertex just added is the last one, so its id equals the vertex count.
    id = nv(bn.graph)
    # `dist` is stored as given: either a fixed distribution or a function that takes
    # parent values and returns a distribution.
    push!(bn.distributions, dist)
    push!(bn.is_stochastic, true)
    push!(bn.is_observed, is_observed)
    # Register the name in both directions: id -> name and name -> id.
    push!(bn.names, name)
    bn.names_to_ids[name] = id
    push!(bn.stochastic_ids, id)
    return id
end

Expand Down Expand Up @@ -329,3 +334,73 @@ function is_conditionally_independent(
) where {V}
return is_conditionally_independent(bn, [X], [Y], Z)
end

"""
    is_discrete_distribution(d::Distribution)

Return `true` when `d` is a `DiscreteDistribution` and `false` for any other
`Distribution`. The answer is selected by dispatch on the argument type.
"""
is_discrete_distribution(::DiscreteDistribution) = true
is_discrete_distribution(::Distribution) = false

"""
    get_support(d::DiscreteDistribution)

Return `support(d)` for the discrete distribution `d`.
"""
get_support(d::DiscreteDistribution) = support(d)

"""
marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V}

Compute the marginal distribution of a query variable using variable elimination.
"""
function marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V}
    # Translate the variable name into its vertex id (throws if the name is unknown).
    target_id = bn.names_to_ids[query_var]

    # Variables are eliminated in a topological order of the network graph.
    elimination_order = Graphs.topological_sort_by_dfs(bn.graph)

    # The recursion starts with no variable assigned yet.
    no_assignments = Dict{V,Any}()
    return eliminate_variables(bn, elimination_order, target_id, no_assignments)
end

"""
eliminate_variables(bn, ordered_vertices, query_id, assignments)

Helper function for variable elimination algorithm.
"""
function eliminate_variables(
    bn::BayesianNetwork{V},
    ordered_vertices::Vector{Int},
    query_id::Int,
    assignments::Dict{V,Any},
) where {V}
    # Arguments:
    # - `ordered_vertices`: vertex ids still to process, in elimination order.
    # - `query_id`: vertex id whose distribution is wanted.
    # - `assignments`: values fixed so far for already-processed variables, by name.
    #
    # Returns the query node's distribution once the query is reached (or the order is
    # exhausted); otherwise a `MixtureModel` over the values of the first vertex.
    # Throws `ArgumentError` when a visited vertex has no entry in `bn.stochastic_ids`.

    # Base case: reached query variable
    if isempty(ordered_vertices) || first(ordered_vertices) == query_id
        return _node_distribution(bn, query_id, assignments)
    end

    current_id = first(ordered_vertices)
    remaining_vertices = ordered_vertices[2:end]
    name = bn.names[current_id]

    # The current node's distribution does not depend on the value being tried, so it
    # is resolved once instead of once per loop iteration.
    current_dist = _node_distribution(bn, current_id, assignments)

    # Try both values (0 and 1) # TODO: generalize for other values
    candidate_values = (0, 1)

    # One mixture component per candidate value: the result of the remaining
    # elimination with the current variable fixed to that value.
    # NOTE(review): a comprehension keeps the narrowest element type; `MixtureModel`
    # presumably needs more than a bare `Vector{Distribution}` to infer the variate
    # form and value support from the component type - confirm against Distributions.jl.
    components = [
        eliminate_variables(
            bn, remaining_vertices, query_id, _with_assignment(assignments, name, value)
        ) for value in candidate_values
    ]

    # Weight each component by the probability of the value under the current node.
    weights = Float64[pdf(current_dist, value) for value in candidate_values]

    # Normalize weights
    weights ./= sum(weights)

    return MixtureModel(components, weights)
end

# Return a copy of `assignments` in which `name` is additionally bound to `value`.
function _with_assignment(assignments::Dict{V,Any}, name::V, value) where {V}
    extended = copy(assignments)
    extended[name] = value
    return extended
end

# Return the distribution of vertex `id`, evaluating a function-valued entry on the
# values already assigned to the vertex's parents.
function _node_distribution(
    bn::BayesianNetwork{V}, id::Int, assignments::Dict{V,Any}
) where {V}
    dist_idx = findfirst(==(id), bn.stochastic_ids)
    if dist_idx === nothing
        throw(
            ArgumentError("vertex id $id has no distribution (not in `bn.stochastic_ids`)")
        )
    end
    return _conditional_distribution(bn.distributions[dist_idx], bn, id, assignments)
end

# A fixed distribution does not depend on the parents and is returned unchanged.
_conditional_distribution(dist::Distribution, bn, id, assignments) = dist

# A function-valued entry is called with the values assigned to the parents of `id`.
# NOTE(review): parent values are passed in the order given by `Graphs.inneighbors`;
# confirm this matches the intended argument order for multi-parent nodes.
function _conditional_distribution(make_dist::Function, bn, id, assignments)
    parent_values = [assignments[bn.names[p]] for p in Graphs.inneighbors(bn.graph, id)]
    return make_dist(parent_values...)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

12 changes: 11 additions & 1 deletion test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
using Test
using Distributions
using Graphs

using Pkg
# The active environment is chosen by whoever runs the tests (`Pkg.test` or
# `julia --project`); activating "." here would switch environments depending on the
# current working directory.
using JuliaBUGS
using JuliaBUGS.ProbabilisticGraphicalModels

using JuliaBUGS.ProbabilisticGraphicalModels:
BayesianNetwork,
add_stochastic_vertex!,
Expand All @@ -10,7 +18,9 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
condition!,
decondition,
ancestral_sampling,
is_conditionally_independent
is_conditionally_independent,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
is_conditionally_independent,
is_conditionally_independent,

marginal_distribution,
eliminate_variables

@testset "BayesianNetwork" begin
@testset "Adding vertices" begin
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# 1. First, activate the package environment
using Pkg
Pkg.activate(".")

# 2. Load required packages
using Test
using Distributions
using Graphs
using BangBang # This is needed based on the imports

# 3. Load JuliaBUGS and its submodule
using JuliaBUGS
using JuliaBUGS.ProbabilisticGraphicalModels
using JuliaBUGS.ProbabilisticGraphicalModels:
BayesianNetwork,
add_stochastic_vertex!,
add_deterministic_vertex!,
add_edge!,
condition,
condition!,
decondition,
ancestral_sampling,
is_conditionally_independent,
marginal_distribution,
eliminate_variables
# 4. Run the specific test

@testset "Simple Discrete Chain" begin
    bn = BayesianNetwork{Symbol}()

    # Simple chain A -> B -> C
    add_stochastic_vertex!(bn, :A, Bernoulli(0.7), false)
    add_stochastic_vertex!(bn, :B, Bernoulli(0.8), false)
    add_stochastic_vertex!(bn, :C, Bernoulli(0.9), false)

    add_edge!(bn, :A, :B)
    add_edge!(bn, :B, :C)

    # A chain admits exactly one topological order.
    ordered_vertices = topological_sort_by_dfs(bn.graph)
    @test ordered_vertices == [1, 2, 3]

    # Every node carries a fixed (parent-independent) distribution, so summing out
    # A and B must leave C distributed like Bernoulli(0.9).
    marginal_C = marginal_distribution(bn, :C)
    @test marginal_C isa Distribution
    @test pdf(marginal_C, 1) ≈ 0.9
    @test pdf(marginal_C, 0) ≈ 0.1
end

# @testset "Mixed Graph - Variable Elimination" begin
# bn = BayesianNetwork{Symbol}()

# # X1 ~ Uniform(0,1)
# add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)

# # X2 ~ Bernoulli(X1)
# # We need a function that creates a new Bernoulli distribution based on X1's value
# add_deterministic_vertex!(bn, :X2_dist, x1 -> Bernoulli(x1))
# add_stochastic_vertex!(bn, :X2, Bernoulli(0.5), false) # Initial dist doesn't matter
# add_edge!(bn, :X1, :X2_dist)
# add_edge!(bn, :X2_dist, :X2)

# # X3 ~ Normal(μ(X2), 1)
# # Function that creates a new Normal distribution based on X2's value
# add_deterministic_vertex!(bn, :X3_dist, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0))
# add_stochastic_vertex!(bn, :X3, Normal(0, 1), false) # Initial dist doesn't matter
# add_edge!(bn, :X2, :X3_dist)
# add_edge!(bn, :X3_dist, :X3)
# end

@testset "Mixed Graph - Variable Elimination" begin
    bn = BayesianNetwork{Symbol}()

    # X1 ~ Uniform(0,1)
    x1_id = add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)
    @test x1_id != 0
    @test bn.names_to_ids[:X1] == x1_id

    # X2 ~ Bernoulli(X1)
    # The distribution constructor takes the parent value and returns the appropriate distribution
    conditional_dist_X2 = x1 -> Bernoulli(x1)
    x2_id = add_stochastic_vertex!(bn, :X2, conditional_dist_X2, false)
    @test x2_id != 0
    @test bn.names_to_ids[:X2] == x2_id
    # Function-valued distributions are stored as given, not evaluated on insertion.
    @test bn.distributions[end] === conditional_dist_X2
    add_edge!(bn, :X1, :X2)
    @test Graphs.has_edge(bn.graph, x1_id, x2_id)

    # X3 ~ Normal(μ(X2), 1)
    conditional_dist_X3 = x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0)
    x3_id = add_stochastic_vertex!(bn, :X3, conditional_dist_X3, false)
    @test x3_id != 0
    @test bn.names_to_ids[:X3] == x3_id
    @test bn.distributions[end] === conditional_dist_X3
    add_edge!(bn, :X2, :X3)
    @test Graphs.has_edge(bn.graph, x2_id, x3_id)

    # The stored constructors produce the intended conditional distributions.
    @test conditional_dist_X2(0.25) == Bernoulli(0.25)
    @test conditional_dist_X3(1) == Normal(10.0, 1.0)
    @test conditional_dist_X3(0) == Normal(2.0, 1.0)
end
Loading