Skip to content

Commit

Permalink
Merge pull request #303 from Julia-Tempering/juliabugs
Browse files Browse the repository at this point in the history
JuliaBUGS support
  • Loading branch information
miguelbiron authored Dec 20, 2024
2 parents 07806d3 + 1ba7742 commit 72a523b
Show file tree
Hide file tree
Showing 15 changed files with 377 additions and 13 deletions.
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"

[weakdeps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

Expand All @@ -51,10 +54,13 @@ PigeonsDynamicPPLExt = "DynamicPPL"
PigeonsEnzymeExt = "Enzyme"
PigeonsForwardDiffExt = "ForwardDiff"
PigeonsHypothesisTestsExt = "HypothesisTests"
PigeonsJuliaBUGSExt = ["JuliaBUGS", "AbstractPPL", "Bijectors"]
PigeonsMCMCChainsExt = "MCMCChains"
PigeonsReverseDiffExt = "ReverseDiff"

[compat]
AbstractPPL = "0.8.4, 0.9"
Bijectors = "0.13, 0.14"
BridgeStan = "2"
DataFrames = "1"
Distributions = "0.25"
Expand All @@ -68,6 +74,7 @@ Graphs = "1"
HypothesisTests = "0.11"
Interpolations = "0.14, 0.15"
JSON = "0.21"
JuliaBUGS = "0.8"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
LogExpFunctions = "0.3"
Expand All @@ -90,10 +97,13 @@ ZipFile = "0.10"
julia = "1.8"

[extras]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
24 changes: 24 additions & 0 deletions examples/JuliaBUGS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function incomplete_count_data_model(;tau::Real=4)
model_def = @bugs("model{
for (i in 1:n) {
r[i] ~ dbern(pr[i])
pr[i] <- ilogit(y[i] * alpha1 + alpha0)
y[i] ~ dpois(mu)
}
mu ~ dgamma(1,1)
alpha0 ~ dnorm(0, 0.1)
alpha1 ~ dnorm(0, tau)
}",false,false
)
data = (
y = [
6,missing,missing,missing,missing,missing,missing,5,1,missing,1,missing,
missing,missing,2,missing,missing,0,missing,1,2,1,7,4,6,missing,missing,
missing,5,missing
],
r = [1,0,0,0,0,0,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,1,1,1,1,0,0,0,1,0],
n = 30, tau = tau
)
return compile(model_def, data)
end
incomplete_count_data(;kwargs...) = JuliaBUGSPath(incomplete_count_data_model(;kwargs...))
5 changes: 2 additions & 3 deletions ext/PigeonsDynamicPPLExt/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ initial and final states.
"""
function Pigeons.forward_sample_condition_and_explore(
model::DynamicPPL.Model,
explorer,
rng::SplittableRandom;
run_explorer::Bool = true,
explorer = nothing,
condition_on::NTuple{N,Symbol}
) where {N}
# forward simulation
Expand All @@ -37,7 +36,7 @@ function Pigeons.forward_sample_condition_and_explore(
DynamicPPL.link!!(state, DynamicPPL.SampleFromPrior(), conditioned_model)

# maybe take a step with explorer
if run_explorer
if !isnothing(explorer)
state = Pigeons.explorer_step(rng, TuringLogPotential(conditioned_model), explorer, state)
end

Expand Down
4 changes: 2 additions & 2 deletions ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ function Pigeons.invariance_test(
# iterate iid samples
for n in eachindex(initial_values)
initial_values[n] = Pigeons.forward_sample_condition_and_explore(
target, explorer, rng; run_explorer=false, simulator_kwargs...)
target, rng; simulator_kwargs...)
final_values[n] = Pigeons.forward_sample_condition_and_explore(
target, explorer, rng; simulator_kwargs...)
target, rng; explorer, simulator_kwargs...)
end

# transform vector of vectors to matrices so that iterating dimensions == iterating columns => faster
Expand Down
24 changes: 24 additions & 0 deletions ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module PigeonsJuliaBUGSExt

using Pigeons
if isdefined(Base, :get_extension)
import JuliaBUGS
using AbstractPPL # only need because we rewrite JuliaBUGS.getparams
using Bijectors # only need because we rewrite JuliaBUGS.getparams
using DocStringExtensions
using SplittableRandoms: SplittableRandom, split
using Random
else
import ..JuliaBUGS
using ..AbstractPPL # only need because we rewrite JuliaBUGS.getparams
using ..Bijectors # only need because we rewrite JuliaBUGS.getparams
using ..DocStringExtensions
using ..SplittableRandoms: SplittableRandom, split
using ..Random
end

include(joinpath(@__DIR__, "utils.jl"))
include(joinpath(@__DIR__, "interface.jl"))
include(joinpath(@__DIR__, "invariance_test.jl"))

end
97 changes: 97 additions & 0 deletions ext/PigeonsJuliaBUGSExt/interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#######################################
# Path interface
#######################################

# Initialization and iid sampling
function evaluate_and_initialize(model::JuliaBUGS.BUGSModel, rng::AbstractRNG)
new_env = first(JuliaBUGS.evaluate!!(rng, model)) # sample a new evaluation environment
return JuliaBUGS.initialize!(model, new_env) # set the private_model's environment to the newly created one
end

# used for both initializing and iid sampling
# Note: state is a flattened vector of the parameters
# Also, the vector is **concretely typed**. This means that if the evaluation
# environment contains floats and integers, the latter will be cast to float.
_sample_iid(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) =
getparams(evaluate_and_initialize(model, rng)) # flatten the unobserved parameters in the model's eval environment and return

# Note: JuliaBUGS.getparams creates a new vector on each call, so it is safe
# to call _sample_iid during initialization (**sequentially**, as done as of time
# of writing) for different Replicas (i.e., they won't share the same state).
Pigeons.initialization(target::JuliaBUGSPath, rng::AbstractRNG, _::Int64) =
_sample_iid(target.model, rng)

# target is already a Path
Pigeons.create_path(target::JuliaBUGSPath, ::Inputs) = target

#######################################
# Log-potential interface
#######################################

"""
$SIGNATURES
A log-potential built from a [`JuliaBUGSPath`](@ref) for a specific inverse
temperature parameter.
$FIELDS
"""
struct JuliaBUGSLogPotential{TMod<:JuliaBUGS.BUGSModel, TF<:AbstractFloat}
"""
A deep-enough copy of the original model that allows evaluation while
avoiding race conditions between different Replicas.
"""
private_model::TMod

"""
Tempering parameter.
"""
beta::TF
end

# make a log-potential by creating a new model with independent graph and
# evaluation environment. Both of these could be modified during density
# evaluations and/or during Gibbs sampling
function Pigeons.interpolate(path::JuliaBUGSPath, beta)
model = path.model
private_model = make_private_model_copy(model)
JuliaBUGSLogPotential(private_model, beta)
end

# log_potential evaluation
(log_potential::JuliaBUGSLogPotential)(flattened_values) =
try
log_prior, _, tempered_log_joint = last(
JuliaBUGS._tempered_evaluate!!(
log_potential.private_model,
flattened_values;
temperature=log_potential.beta
)
)
# avoid potential 0*Inf (= NaN)
return iszero(log_potential.beta) ? log_prior : tempered_log_joint
catch e
(isa(e, DomainError) || isa(e, BoundsError)) && return -Inf
rethrow(e)
end

# iid sampling
function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, shared)
replica.state = _sample_iid(log_potential.private_model, replica.rng)
end

# parameter names
Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) =
[(Symbol(string(vn)) for vn in log_potential.private_model.parameters)...,:log_density]

# Parallelism invariance
Pigeons.recursive_equal(a::Union{JuliaBUGSPath,JuliaBUGSLogPotential}, b) =
Pigeons._recursive_equal(a,b)
function Pigeons.recursive_equal(a::T, b) where T <: JuliaBUGS.BUGSModel
included = (:transformed, :model_def, :data)
excluded = Tuple(setdiff(fieldnames(T), included))
Pigeons._recursive_equal(a,b,excluded)
end
# just check the betas match, the model is already checked within path
Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) =
all(lp1.beta == lp2.beta for (lp1,lp2) in zip(a,b))
35 changes: 35 additions & 0 deletions ext/PigeonsJuliaBUGSExt/invariance_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
$SIGNATURES
Implements `Pigeons.forward_sample_condition_and_explore` for running invariance
tests using a [`JuliaBUGSPath`](@ref) as target.
"""
function Pigeons.forward_sample_condition_and_explore(
model::JuliaBUGS.BUGSModel,
rng::SplittableRandom;
explorer = nothing,
condition_on = ()
)
# forward simulation (new values stored in model.evaluation_env)
model = evaluate_and_initialize(model, rng)

# maybe condition the model using the sampled observations
conditioned_model = if length(condition_on) > 0
var_group = [JuliaBUGS.VarName{sym}() for sym in condition_on] # transform Symbols into VarNames
JuliaBUGS.condition(model, var_group)
else
model
end

# maybe take a step with explorer
state = getparams(conditioned_model)
return if !isnothing(explorer)
Pigeons.explorer_step(rng, JuliaBUGSPath(conditioned_model), explorer, state)
else
state
end
end

Pigeons.forward_sample_condition_and_explore(target::JuliaBUGSPath, args...; kwargs...) =
Pigeons.forward_sample_condition_and_explore(target.model, args...; kwargs...)

63 changes: 63 additions & 0 deletions ext/PigeonsJuliaBUGSExt/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#=
Tweak of JuliaBUGS.getparams to allow for flattened vectors of mixed type
=#
type_join_eval_env(env) = typejoin(Set(eltype(v) for v in env)...)
function getparams(model::JuliaBUGS.BUGSModel)
param_length = if model.transformed
model.transformed_param_length
else
model.untransformed_param_length
end

# search for an umbrella type for all parameters in the model to avoid
# promotion of e.g. ints to floats. For models with a unique parameter
# type T, it holds that TMix=T.
TMix = type_join_eval_env(model.evaluation_env)
param_vals = Vector{TMix}(undef, param_length)
pos = 1
for v in model.parameters
if !model.transformed
val = AbstractPPL.get(model.evaluation_env, v)
len = model.untransformed_var_lengths[v]
if val isa AbstractArray
param_vals[pos:(pos + len - 1)] .= vec(val)
else
param_vals[pos] = val
end
else
(; node_function, loop_vars) = model.g[v]
dist = node_function(model.evaluation_env, loop_vars)
transformed_value = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
)
len = model.transformed_var_lengths[v]
if transformed_value isa AbstractArray
param_vals[pos:(pos + len - 1)] .= vec(transformed_value)
else
param_vals[pos] = transformed_value
end
end
pos += len
end
return param_vals
end

function make_private_model_copy(model::JuliaBUGS.BUGSModel)
g = deepcopy(model.g)
parameters = model.parameters
sorted_nodes = model.flattened_graph_node_data.sorted_nodes
return JuliaBUGS.BUGSModel(
model.transformed,
sum(model.untransformed_var_lengths[v] for v in parameters),
sum(model.transformed_var_lengths[v] for v in parameters),
model.untransformed_var_lengths,
model.transformed_var_lengths,
deepcopy(model.evaluation_env),
parameters,
JuliaBUGS.FlattenedGraphNodeData(g, sorted_nodes),
g,
nothing,
model.model_def,
model.data
)
end
5 changes: 2 additions & 3 deletions src/Pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ include("includes.jl")
export pigeons, Inputs, PT,
# for running jobs:
ChildProcess, MPIProcesses,
# references:
DistributionLogPotential,
# targets:
TuringLogPotential, StanLogPotential,
TuringLogPotential, StanLogPotential, DistributionLogPotential, JuliaBUGSPath,
# some examples
toy_mvn_target, toy_stan_target,
# post-processing helpers
Expand Down Expand Up @@ -89,6 +87,7 @@ end
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(joinpath(@__DIR__, "../ext/PigeonsEnzymeExt/PigeonsEnzymeExt.jl"))
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(joinpath(@__DIR__, "../ext/PigeonsForwardDiffExt/PigeonsForwardDiffExt.jl"))
@require HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" include(joinpath(@__DIR__, "../ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl"))
@require JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" include(joinpath(@__DIR__, "../ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl"))
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(joinpath(@__DIR__, "../ext/PigeonsMCMCChainsExt/PigeonsMCMCChainsExt.jl"))
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(joinpath(@__DIR__, "../ext/PigeonsReverseDiffExt/PigeonsReverseDiffExt.jl"))
end
Expand Down
5 changes: 2 additions & 3 deletions src/explorers/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ allows direct iid sampling from the target, conditioning is not necessary.
"""
function forward_sample_condition_and_explore(
target::ScaledPrecisionNormalPath,
explorer,
rng::SplittableRandom;
run_explorer::Bool = true
explorer = nothing
)
state = initialization(target, rng, 1) # forward simulation
if run_explorer
if !isnothing(explorer)
state = explorer_step(rng, target, explorer, state)
end
return state
Expand Down
1 change: 1 addition & 0 deletions src/includes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ include("explorers/AAPS.jl")
include("explorers/GradientBasedSampler.jl")
include("evidence/stepping_stone.jl")
include("api.jl")
include("paths/JuliaBUGSPath.jl")
11 changes: 9 additions & 2 deletions src/log_potentials/log_potentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,12 @@ Assumes the input `log_potentials` is a vector where each element is a [`log_pot
This default implementation is sufficient in most cases, but in less standard scenarios,
e.g. where the state space is infinite dimensional, this can be overridden.
"""
log_unnormalized_ratio(log_potentials::AbstractVector, numerator::Int, denominator::Int, state) =
log_potentials[numerator](state) - log_potentials[denominator](state)
function log_unnormalized_ratio(log_potentials::AbstractVector, numerator::Int, denominator::Int, state)
lp_num = log_potentials[numerator](state)
lp_den = log_potentials[denominator](state)
ans = lp_num-lp_den
if isnan(ans)
error("Got NaN log-unnormalized ratio; Dumping information:\n\tlp_num=$lp_num\n\tlp_den=$lp_den\n\tState=$state")
end
return ans
end
Loading

0 comments on commit 72a523b

Please sign in to comment.