-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Variable elimination reimplementation #240
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||||||||||||||||
|
||||||||||||||||||
A structure representing a Bayesian Network. | ||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
struct BayesianNetwork{V,T,F} | ||||||||||||||||||
graph::SimpleDiGraph{T} | ||||||||||||||||||
"names of the variables in the network" | ||||||||||||||||||
|
@@ -12,7 +13,10 @@ struct BayesianNetwork{V,T,F} | |||||||||||||||||
"values of each variable in the network" | ||||||||||||||||||
values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future | ||||||||||||||||||
"distributions of the stochastic variables" | ||||||||||||||||||
distributions::Vector{Distribution} | ||||||||||||||||||
# A distribution can be either: | ||||||||||||||||||
# - A fixed distribution (like Uniform(0,1)) | ||||||||||||||||||
# - A function that takes parent values and returns a distribution | ||||||||||||||||||
distributions::Vector{Union{Distribution,Function}} | ||||||||||||||||||
"deterministic functions of the deterministic variables" | ||||||||||||||||||
deterministic_functions::Vector{F} | ||||||||||||||||||
"ids of the stochastic variables" | ||||||||||||||||||
|
@@ -114,16 +118,17 @@ Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesia | |||||||||||||||||
if successful, 0 otherwise. | ||||||||||||||||||
""" | ||||||||||||||||||
function add_stochastic_vertex!( | ||||||||||||||||||
bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool | ||||||||||||||||||
bn::BayesianNetwork{V,T}, | ||||||||||||||||||
name::V, | ||||||||||||||||||
dist::Union{Distribution,Function}, | ||||||||||||||||||
is_observed::Bool | ||||||||||||||||||
)::T where {V,T} | ||||||||||||||||||
Graphs.add_vertex!(bn.graph) || return 0 | ||||||||||||||||||
id = nv(bn.graph) | ||||||||||||||||||
push!(bn.distributions, dist) | ||||||||||||||||||
push!(bn.is_stochastic, true) | ||||||||||||||||||
push!(bn.is_observed, is_observed) | ||||||||||||||||||
push!(bn.names, name) | ||||||||||||||||||
bn.names_to_ids[name] = id | ||||||||||||||||||
push!(bn.stochastic_ids, id) | ||||||||||||||||||
return id | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -329,3 +334,73 @@ function is_conditionally_independent( | |||||||||||||||||
) where {V} | ||||||||||||||||||
return is_conditionally_independent(bn, [X], [Y], Z) | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
function is_discrete_distribution(d::Distribution) | ||||||||||||||||||
return d isa DiscreteDistribution | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
function get_support(d::DiscreteDistribution) | ||||||||||||||||||
return support(d) | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
""" | ||||||||||||||||||
marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V} | ||||||||||||||||||
|
||||||||||||||||||
Compute the marginal distribution of a query variable using variable elimination. | ||||||||||||||||||
""" | ||||||||||||||||||
function marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V} | ||||||||||||||||||
# Get query variable id | ||||||||||||||||||
query_id = bn.names_to_ids[query_var] | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# Get topological ordering | ||||||||||||||||||
ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph) | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# Start recursive elimination | ||||||||||||||||||
return eliminate_variables(bn, ordered_vertices, query_id, Dict{V,Any}()) | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
""" | ||||||||||||||||||
eliminate_variables(bn, ordered_vertices, query_id, assignments) | ||||||||||||||||||
|
||||||||||||||||||
Helper function for variable elimination algorithm. | ||||||||||||||||||
""" | ||||||||||||||||||
function eliminate_variables( | ||||||||||||||||||
bn::BayesianNetwork{V}, | ||||||||||||||||||
ordered_vertices::Vector{Int}, | ||||||||||||||||||
query_id::Int, | ||||||||||||||||||
assignments::Dict{V,Any} | ||||||||||||||||||
Comment on lines
+375
to
+378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
) where {V} | ||||||||||||||||||
# Base case: reached query variable | ||||||||||||||||||
if isempty(ordered_vertices) || ordered_vertices[1] == query_id | ||||||||||||||||||
dist_idx = findfirst(id -> id == query_id, bn.stochastic_ids) | ||||||||||||||||||
return bn.distributions[dist_idx] | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
current_id = ordered_vertices[1] | ||||||||||||||||||
remaining_vertices = ordered_vertices[2:end] | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# For current variable, create mixture over its values | ||||||||||||||||||
components = Distribution[] | ||||||||||||||||||
weights = Float64[] | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# Try both values (0 and 1) # TODO: generalize for other values | ||||||||||||||||||
for value in [0, 1] | ||||||||||||||||||
new_assignments = copy(assignments) | ||||||||||||||||||
new_assignments[bn.names[current_id]] = value | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# Get distribution for remaining variables | ||||||||||||||||||
component = eliminate_variables(bn, remaining_vertices, query_id, new_assignments) | ||||||||||||||||||
println("Components so far: ", components) | ||||||||||||||||||
println("Current component: ", component) | ||||||||||||||||||
push!(components, component) | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# Get weight from current node's distribution | ||||||||||||||||||
dist_idx = findfirst(id -> id == current_id, bn.stochastic_ids) | ||||||||||||||||||
push!(weights, pdf(bn.distributions[dist_idx], value)) | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# Normalize weights | ||||||||||||||||||
weights ./= sum(weights) | ||||||||||||||||||
|
||||||||||||||||||
return MixtureModel(components, weights) | ||||||||||||||||||
end | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,14 @@ | ||||||
using Test | ||||||
using Distributions | ||||||
using Graphs | ||||||
|
||||||
using Pkg | ||||||
Pkg.activate(".") | ||||||
using JuliaBUGS | ||||||
using JuliaBUGS.ProbabilisticGraphicalModels | ||||||
|
||||||
names(JuliaBUGS.ProbabilisticGraphicalModels) | ||||||
|
||||||
using JuliaBUGS.ProbabilisticGraphicalModels: | ||||||
BayesianNetwork, | ||||||
add_stochastic_vertex!, | ||||||
|
@@ -10,7 +18,9 @@ using JuliaBUGS.ProbabilisticGraphicalModels: | |||||
condition!, | ||||||
decondition, | ||||||
ancestral_sampling, | ||||||
is_conditionally_independent | ||||||
is_conditionally_independent, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
marginal_distribution, | ||||||
eliminate_variables | ||||||
|
||||||
@testset "BayesianNetwork" begin | ||||||
@testset "Adding vertices" begin | ||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,82 @@ | ||||
# 1. First, activate the package environment | ||||
using Pkg | ||||
Pkg.activate(".") | ||||
|
||||
# 2. Load required packages | ||||
using Test | ||||
using Distributions | ||||
using Graphs | ||||
using BangBang # This is needed based on the imports | ||||
|
||||
# 3. Load JuliaBUGS and its submodule | ||||
using JuliaBUGS | ||||
using JuliaBUGS.ProbabilisticGraphicalModels | ||||
using JuliaBUGS.ProbabilisticGraphicalModels: | ||||
BayesianNetwork, | ||||
add_stochastic_vertex!, | ||||
add_deterministic_vertex!, | ||||
add_edge!, | ||||
condition, | ||||
condition!, | ||||
decondition, | ||||
ancestral_sampling, | ||||
is_conditionally_independent, | ||||
marginal_distribution, | ||||
eliminate_variables | ||||
# 4. Run the specific test | ||||
|
||||
@testset "Simple Discrete Chain" begin | ||||
bn = BayesianNetwork{Symbol}() | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
# Simple chain A -> B -> C | ||||
add_stochastic_vertex!(bn, :A, Bernoulli(0.7), false) | ||||
add_stochastic_vertex!(bn, :B, Bernoulli(0.8), false) | ||||
add_stochastic_vertex!(bn, :C, Bernoulli(0.9), false) | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
add_edge!(bn, :A, :B) | ||||
add_edge!(bn, :B, :C) | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
ordered_vertices = topological_sort_by_dfs(bn.graph) | ||||
println(ordered_vertices) | ||||
marginal_C = marginal_distribution(bn, :C) | ||||
println(marginal_C) | ||||
end | ||||
|
||||
# @testset "Mixed Graph - Variable Elimination" begin | ||||
# bn = BayesianNetwork{Symbol}() | ||||
|
||||
# # X1 ~ Uniform(0,1) | ||||
# add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false) | ||||
|
||||
# # X2 ~ Bernoulli(X1) | ||||
# # We need a function that creates a new Bernoulli distribution based on X1's value | ||||
# add_deterministic_vertex!(bn, :X2_dist, x1 -> Bernoulli(x1)) | ||||
# add_stochastic_vertex!(bn, :X2, Bernoulli(0.5), false) # Initial dist doesn't matter | ||||
# add_edge!(bn, :X1, :X2_dist) | ||||
# add_edge!(bn, :X2_dist, :X2) | ||||
|
||||
# # X3 ~ Normal(μ(X2), 1) | ||||
# # Function that creates a new Normal distribution based on X2's value | ||||
# add_deterministic_vertex!(bn, :X3_dist, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0)) | ||||
# add_stochastic_vertex!(bn, :X3, Normal(0, 1), false) # Initial dist doesn't matter | ||||
# add_edge!(bn, :X2, :X3_dist) | ||||
# add_edge!(bn, :X3_dist, :X3) | ||||
# end | ||||
|
||||
@testset "Mixed Graph - Variable Elimination" begin | ||||
bn = BayesianNetwork{Symbol}() | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
# X1 ~ Uniform(0,1) | ||||
add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false) | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
# X2 ~ Bernoulli(X1) | ||||
# The distribution constructor takes the parent value and returns the appropriate distribution | ||||
conditional_dist_X2 = x1 -> Bernoulli(x1) | ||||
add_stochastic_vertex!(bn, :X2, conditional_dist_X2, false) | ||||
add_edge!(bn, :X1, :X2) | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
# X3 ~ Normal(μ(X2), 1) | ||||
conditional_dist_X3 = x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0) | ||||
add_stochastic_vertex!(bn, :X3, conditional_dist_X3, false) | ||||
add_edge!(bn, :X2, :X3) | ||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶