Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JuliaBUGS support #303

Merged
merged 34 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f453abd
automatically build prior
miguelbiron Dec 2, 2024
51f70cb
fix prior extraction + test
miguelbiron Dec 3, 2024
81d6612
fix missing dep
miguelbiron Dec 3, 2024
f508bf4
init + sample_iid + logpotential eval
miguelbiron Dec 3, 2024
308554c
logpotential eval test
miguelbiron Dec 3, 2024
c6a9b41
rewrite JuliaBUGSExt using tempered_evaluate
miguelbiron Dec 7, 2024
2882141
Domain error try-catch and sample_iid update
serenlee Dec 10, 2024
f25d5ff
finer catch + more tests
miguelbiron Dec 10, 2024
a6d03cf
reproducible initialization
miguelbiron Dec 10, 2024
2908fea
invariance test
miguelbiron Dec 11, 2024
536930d
fix bug
miguelbiron Dec 11, 2024
d1f29ce
fix
miguelbiron Dec 11, 2024
b427da7
back to state as namedtuple
miguelbiron Dec 12, 2024
cfe5491
bump version
miguelbiron Dec 11, 2024
d7ebbcb
Merge branch 'main' into juliabugs
miguelbiron Dec 12, 2024
7e4aade
Revert "fix"
miguelbiron Dec 12, 2024
ece6290
Reapply "fix"
miguelbiron Dec 12, 2024
ae3c821
Revert "bump version"
miguelbiron Dec 12, 2024
0973ba4
Revert "back to state as namedtuple"
miguelbiron Dec 12, 2024
7db1709
mixed eltype vector
miguelbiron Dec 12, 2024
4c6c6e5
test incomplete count data model
miguelbiron Dec 12, 2024
6b4bf0a
bump JuliaBUGS compat
miguelbiron Dec 12, 2024
4d73dc3
match new version number
miguelbiron Dec 12, 2024
2a19095
Bijectors 0.13 version compat
serenlee Dec 14, 2024
e922481
AbstractPPL ver
serenlee Dec 14, 2024
c4218fa
Merge branch 'main' into juliabugs
miguelbiron Dec 16, 2024
722dde8
finer mix type
miguelbiron Dec 16, 2024
0e4bb72
fix private model builder
miguelbiron Dec 17, 2024
6e0c0cb
Add test
serenlee Dec 18, 2024
4991e63
Fix Float64 -> Real
serenlee Dec 19, 2024
d4a6071
fix NaN
miguelbiron Dec 19, 2024
0987925
sample names
miguelbiron Dec 19, 2024
0a22c4a
simple check for names
miguelbiron Dec 20, 2024
1ba7742
parallelism invariance
miguelbiron Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"

[weakdeps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

Expand All @@ -51,10 +54,13 @@ PigeonsDynamicPPLExt = "DynamicPPL"
PigeonsEnzymeExt = "Enzyme"
PigeonsForwardDiffExt = "ForwardDiff"
PigeonsHypothesisTestsExt = "HypothesisTests"
PigeonsJuliaBUGSExt = ["JuliaBUGS", "AbstractPPL", "Bijectors"]
PigeonsMCMCChainsExt = "MCMCChains"
PigeonsReverseDiffExt = "ReverseDiff"

[compat]
AbstractPPL = "0.8.4, 0.9"
Bijectors = "0.13, 0.14"
BridgeStan = "2"
DataFrames = "1"
Distributions = "0.25"
Expand All @@ -68,6 +74,7 @@ Graphs = "1"
HypothesisTests = "0.11"
Interpolations = "0.14, 0.15"
JSON = "0.21"
JuliaBUGS = "0.8"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
LogExpFunctions = "0.3"
Expand All @@ -90,10 +97,13 @@ ZipFile = "0.10"
julia = "1.8"

[extras]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
24 changes: 24 additions & 0 deletions examples/JuliaBUGS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function incomplete_count_data_model(;tau::Real=4)
model_def = @bugs("model{
for (i in 1:n) {
r[i] ~ dbern(pr[i])
pr[i] <- ilogit(y[i] * alpha1 + alpha0)
y[i] ~ dpois(mu)
}
mu ~ dgamma(1,1)
alpha0 ~ dnorm(0, 0.1)
alpha1 ~ dnorm(0, tau)
}",false,false
)
data = (
y = [
6,missing,missing,missing,missing,missing,missing,5,1,missing,1,missing,
missing,missing,2,missing,missing,0,missing,1,2,1,7,4,6,missing,missing,
missing,5,missing
],
r = [1,0,0,0,0,0,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,1,1,1,1,0,0,0,1,0],
n = 30, tau = tau
)
return compile(model_def, data)
end
incomplete_count_data(;kwargs...) = JuliaBUGSPath(incomplete_count_data_model(;kwargs...))
5 changes: 2 additions & 3 deletions ext/PigeonsDynamicPPLExt/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ initial and final states.
"""
function Pigeons.forward_sample_condition_and_explore(
model::DynamicPPL.Model,
explorer,
rng::SplittableRandom;
run_explorer::Bool = true,
explorer = nothing,
condition_on::NTuple{N,Symbol}
) where {N}
# forward simulation
Expand All @@ -37,7 +36,7 @@ function Pigeons.forward_sample_condition_and_explore(
DynamicPPL.link!!(state, DynamicPPL.SampleFromPrior(), conditioned_model)

# maybe take a step with explorer
if run_explorer
if !isnothing(explorer)
state = Pigeons.explorer_step(rng, TuringLogPotential(conditioned_model), explorer, state)
end

Expand Down
4 changes: 2 additions & 2 deletions ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ function Pigeons.invariance_test(
# iterate iid samples
for n in eachindex(initial_values)
initial_values[n] = Pigeons.forward_sample_condition_and_explore(
target, explorer, rng; run_explorer=false, simulator_kwargs...)
target, rng; simulator_kwargs...)
final_values[n] = Pigeons.forward_sample_condition_and_explore(
target, explorer, rng; simulator_kwargs...)
target, rng; explorer, simulator_kwargs...)
end

# transform vector of vectors to matrices so that iterating dimensions == iterating columns => faster
Expand Down
24 changes: 24 additions & 0 deletions ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module PigeonsJuliaBUGSExt

using Pigeons
if isdefined(Base, :get_extension)
import JuliaBUGS
using AbstractPPL # only need because we rewrite JuliaBUGS.getparams
using Bijectors # only need because we rewrite JuliaBUGS.getparams
using DocStringExtensions
using SplittableRandoms: SplittableRandom, split
using Random
else
import ..JuliaBUGS
using ..AbstractPPL # only need because we rewrite JuliaBUGS.getparams
using ..Bijectors # only need because we rewrite JuliaBUGS.getparams
using ..DocStringExtensions
using ..SplittableRandoms: SplittableRandom, split
using ..Random
end

include(joinpath(@__DIR__, "utils.jl"))
include(joinpath(@__DIR__, "interface.jl"))
include(joinpath(@__DIR__, "invariance_test.jl"))

end
97 changes: 97 additions & 0 deletions ext/PigeonsJuliaBUGSExt/interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#######################################
# Path interface
#######################################

# Initialization and iid sampling
function evaluate_and_initialize(model::JuliaBUGS.BUGSModel, rng::AbstractRNG)
new_env = first(JuliaBUGS.evaluate!!(rng, model)) # sample a new evaluation environment
return JuliaBUGS.initialize!(model, new_env) # set the private_model's environment to the newly created one
end

# used for both initializing and iid sampling
# Note: state is a flattened vector of the parameters
# Also, the vector is **concretely typed**. This means that if the evaluation
# environment contains floats and integers, the latter will be cast to float.
_sample_iid(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) =
getparams(evaluate_and_initialize(model, rng)) # flatten the unobserved parameters in the model's eval environment and return

# Note: JuliaBUGS.getparams creates a new vector on each call, so it is safe
# to call _sample_iid during initialization (**sequentially**, as done as of time
# of writing) for different Replicas (i.e., they won't share the same state).
Pigeons.initialization(target::JuliaBUGSPath, rng::AbstractRNG, _::Int64) =
_sample_iid(target.model, rng)

# target is already a Path
Pigeons.create_path(target::JuliaBUGSPath, ::Inputs) = target

#######################################
# Log-potential interface
#######################################

"""
$SIGNATURES

A log-potential built from a [`JuliaBUGSPath`](@ref) for a specific inverse
temperature parameter.

$FIELDS
"""
struct JuliaBUGSLogPotential{TMod<:JuliaBUGS.BUGSModel, TF<:AbstractFloat}
"""
A deep-enough copy of the original model that allows evaluation while
avoiding race conditions between different Replicas.
"""
private_model::TMod

"""
Tempering parameter.
"""
beta::TF
end

# make a log-potential by creating a new model with independent graph and
# evaluation environment. Both of these could be modified during density
# evaluations and/or during Gibbs sampling
function Pigeons.interpolate(path::JuliaBUGSPath, beta)
model = path.model
private_model = make_private_model_copy(model)
JuliaBUGSLogPotential(private_model, beta)
end

# log_potential evaluation
(log_potential::JuliaBUGSLogPotential)(flattened_values) =
try
log_prior, _, tempered_log_joint = last(
JuliaBUGS._tempered_evaluate!!(
log_potential.private_model,
flattened_values;
temperature=log_potential.beta
)
)
# avoid potential 0*Inf (= NaN)
return iszero(log_potential.beta) ? log_prior : tempered_log_joint
catch e
(isa(e, DomainError) || isa(e, BoundsError)) && return -Inf
rethrow(e)
end

# iid sampling
function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, shared)
replica.state = _sample_iid(log_potential.private_model, replica.rng)
end

# parameter names
Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) =
[(Symbol(string(vn)) for vn in log_potential.private_model.parameters)...,:log_density]

# Parallelism invariance
Pigeons.recursive_equal(a::Union{JuliaBUGSPath,JuliaBUGSLogPotential}, b) =
Pigeons._recursive_equal(a,b)
function Pigeons.recursive_equal(a::T, b) where T <: JuliaBUGS.BUGSModel
included = (:transformed, :model_def, :data)
excluded = Tuple(setdiff(fieldnames(T), included))
Pigeons._recursive_equal(a,b,excluded)
end
# just check the betas match, the model is already checked within path
Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) =
all(lp1.beta == lp2.beta for (lp1,lp2) in zip(a,b))
35 changes: 35 additions & 0 deletions ext/PigeonsJuliaBUGSExt/invariance_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
$SIGNATURES

Implements `Pigeons.forward_sample_condition_and_explore` for running invariance
tests using a [`JuliaBUGSPath`](@ref) as target.
"""
function Pigeons.forward_sample_condition_and_explore(
model::JuliaBUGS.BUGSModel,
rng::SplittableRandom;
explorer = nothing,
condition_on = ()
)
# forward simulation (new values stored in model.evaluation_env)
model = evaluate_and_initialize(model, rng)

# maybe condition the model using the sampled observations
conditioned_model = if length(condition_on) > 0
var_group = [JuliaBUGS.VarName{sym}() for sym in condition_on] # transform Symbols into VarNames
JuliaBUGS.condition(model, var_group)
else
model
end

# maybe take a step with explorer
state = getparams(conditioned_model)
return if !isnothing(explorer)
Pigeons.explorer_step(rng, JuliaBUGSPath(conditioned_model), explorer, state)
else
state
end
end

Pigeons.forward_sample_condition_and_explore(target::JuliaBUGSPath, args...; kwargs...) =
Pigeons.forward_sample_condition_and_explore(target.model, args...; kwargs...)

63 changes: 63 additions & 0 deletions ext/PigeonsJuliaBUGSExt/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#=
Tweak of JuliaBUGS.getparams to allow for flattened vectors of mixed type
=#
type_join_eval_env(env) = typejoin(Set(eltype(v) for v in env)...)
function getparams(model::JuliaBUGS.BUGSModel)
param_length = if model.transformed
model.transformed_param_length
else
model.untransformed_param_length
end

# search for an umbrella type for all parameters in the model to avoid
# promotion of e.g. ints to floats. For models with a unique parameter
# type T, it holds that TMix=T.
TMix = type_join_eval_env(model.evaluation_env)
param_vals = Vector{TMix}(undef, param_length)
pos = 1
for v in model.parameters
if !model.transformed
val = AbstractPPL.get(model.evaluation_env, v)
len = model.untransformed_var_lengths[v]
if val isa AbstractArray
param_vals[pos:(pos + len - 1)] .= vec(val)
else
param_vals[pos] = val
end
else
(; node_function, loop_vars) = model.g[v]
dist = node_function(model.evaluation_env, loop_vars)
transformed_value = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
)
len = model.transformed_var_lengths[v]
if transformed_value isa AbstractArray
param_vals[pos:(pos + len - 1)] .= vec(transformed_value)
else
param_vals[pos] = transformed_value
end
end
pos += len
end
return param_vals
end

function make_private_model_copy(model::JuliaBUGS.BUGSModel)
g = deepcopy(model.g)
parameters = model.parameters
sorted_nodes = model.flattened_graph_node_data.sorted_nodes
return JuliaBUGS.BUGSModel(
model.transformed,
sum(model.untransformed_var_lengths[v] for v in parameters),
sum(model.transformed_var_lengths[v] for v in parameters),
model.untransformed_var_lengths,
model.transformed_var_lengths,
deepcopy(model.evaluation_env),
parameters,
JuliaBUGS.FlattenedGraphNodeData(g, sorted_nodes),
g,
nothing,
model.model_def,
model.data
)
end
5 changes: 2 additions & 3 deletions src/Pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ include("includes.jl")
export pigeons, Inputs, PT,
# for running jobs:
ChildProcess, MPIProcesses,
# references:
DistributionLogPotential,
# targets:
TuringLogPotential, StanLogPotential,
TuringLogPotential, StanLogPotential, DistributionLogPotential, JuliaBUGSPath,
# some examples
toy_mvn_target, toy_stan_target,
# post-processing helpers
Expand Down Expand Up @@ -89,6 +87,7 @@ end
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(joinpath(@__DIR__, "../ext/PigeonsEnzymeExt/PigeonsEnzymeExt.jl"))
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(joinpath(@__DIR__, "../ext/PigeonsForwardDiffExt/PigeonsForwardDiffExt.jl"))
@require HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" include(joinpath(@__DIR__, "../ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl"))
@require JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" include(joinpath(@__DIR__, "../ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl"))
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(joinpath(@__DIR__, "../ext/PigeonsMCMCChainsExt/PigeonsMCMCChainsExt.jl"))
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(joinpath(@__DIR__, "../ext/PigeonsReverseDiffExt/PigeonsReverseDiffExt.jl"))
end
Expand Down
5 changes: 2 additions & 3 deletions src/explorers/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ allows direct iid sampling from the target, conditioning is not necessary.
"""
function forward_sample_condition_and_explore(
target::ScaledPrecisionNormalPath,
explorer,
rng::SplittableRandom;
run_explorer::Bool = true
explorer = nothing
)
state = initialization(target, rng, 1) # forward simulation
if run_explorer
if !isnothing(explorer)
state = explorer_step(rng, target, explorer, state)
end
return state
Expand Down
1 change: 1 addition & 0 deletions src/includes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ include("explorers/AAPS.jl")
include("explorers/GradientBasedSampler.jl")
include("evidence/stepping_stone.jl")
include("api.jl")
include("paths/JuliaBUGSPath.jl")
11 changes: 9 additions & 2 deletions src/log_potentials/log_potentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,12 @@ Assumes the input `log_potentials` is a vector where each element is a [`log_pot
This default implementation is sufficient in most cases, but in less standard scenarios,
e.g. where the state space is infinite dimensional, this can be overridden.
"""
log_unnormalized_ratio(log_potentials::AbstractVector, numerator::Int, denominator::Int, state) =
log_potentials[numerator](state) - log_potentials[denominator](state)
function log_unnormalized_ratio(log_potentials::AbstractVector, numerator::Int, denominator::Int, state)
lp_num = log_potentials[numerator](state)
lp_den = log_potentials[denominator](state)
ans = lp_num-lp_den
if isnan(ans)
error("Got NaN log-unnormalized ratio; Dumping information:\n\tlp_num=$lp_num\n\tlp_den=$lp_den\n\tState=$state")
end
return ans
end
Loading
Loading