-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #303 from Julia-Tempering/juliabugs
JuliaBUGS support
- Loading branch information
Showing
15 changed files
with
377 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
function incomplete_count_data_model(;tau::Real=4) | ||
model_def = @bugs("model{ | ||
for (i in 1:n) { | ||
r[i] ~ dbern(pr[i]) | ||
pr[i] <- ilogit(y[i] * alpha1 + alpha0) | ||
y[i] ~ dpois(mu) | ||
} | ||
mu ~ dgamma(1,1) | ||
alpha0 ~ dnorm(0, 0.1) | ||
alpha1 ~ dnorm(0, tau) | ||
}",false,false | ||
) | ||
data = ( | ||
y = [ | ||
6,missing,missing,missing,missing,missing,missing,5,1,missing,1,missing, | ||
missing,missing,2,missing,missing,0,missing,1,2,1,7,4,6,missing,missing, | ||
missing,5,missing | ||
], | ||
r = [1,0,0,0,0,0,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,1,1,1,1,0,0,0,1,0], | ||
n = 30, tau = tau | ||
) | ||
return compile(model_def, data) | ||
end | ||
incomplete_count_data(;kwargs...) = JuliaBUGSPath(incomplete_count_data_model(;kwargs...)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
module PigeonsJuliaBUGSExt | ||
|
||
using Pigeons | ||
if isdefined(Base, :get_extension) | ||
import JuliaBUGS | ||
using AbstractPPL # only need because we rewrite JuliaBUGS.getparams | ||
using Bijectors # only need because we rewrite JuliaBUGS.getparams | ||
using DocStringExtensions | ||
using SplittableRandoms: SplittableRandom, split | ||
using Random | ||
else | ||
import ..JuliaBUGS | ||
using ..AbstractPPL # only need because we rewrite JuliaBUGS.getparams | ||
using ..Bijectors # only need because we rewrite JuliaBUGS.getparams | ||
using ..DocStringExtensions | ||
using ..SplittableRandoms: SplittableRandom, split | ||
using ..Random | ||
end | ||
|
||
include(joinpath(@__DIR__, "utils.jl")) | ||
include(joinpath(@__DIR__, "interface.jl")) | ||
include(joinpath(@__DIR__, "invariance_test.jl")) | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
####################################### | ||
# Path interface | ||
####################################### | ||
|
||
# Initialization and iid sampling | ||
function evaluate_and_initialize(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) | ||
new_env = first(JuliaBUGS.evaluate!!(rng, model)) # sample a new evaluation environment | ||
return JuliaBUGS.initialize!(model, new_env) # set the private_model's environment to the newly created one | ||
end | ||
|
||
# used for both initializing and iid sampling | ||
# Note: state is a flattened vector of the parameters | ||
# Also, the vector is **concretely typed**. This means that if the evaluation | ||
# environment contains floats and integers, the latter will be cast to float. | ||
_sample_iid(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) = | ||
getparams(evaluate_and_initialize(model, rng)) # flatten the unobserved parameters in the model's eval environment and return | ||
|
||
# Note: JuliaBUGS.getparams creates a new vector on each call, so it is safe | ||
# to call _sample_iid during initialization (**sequentially**, as done as of time | ||
# of writing) for different Replicas (i.e., they won't share the same state). | ||
Pigeons.initialization(target::JuliaBUGSPath, rng::AbstractRNG, _::Int64) = | ||
_sample_iid(target.model, rng) | ||
|
||
# target is already a Path | ||
Pigeons.create_path(target::JuliaBUGSPath, ::Inputs) = target | ||
|
||
####################################### | ||
# Log-potential interface | ||
####################################### | ||
|
||
""" | ||
$SIGNATURES | ||
A log-potential built from a [`JuliaBUGSPath`](@ref) for a specific inverse | ||
temperature parameter. | ||
$FIELDS | ||
""" | ||
struct JuliaBUGSLogPotential{TMod<:JuliaBUGS.BUGSModel, TF<:AbstractFloat} | ||
""" | ||
A deep-enough copy of the original model that allows evaluation while | ||
avoiding race conditions between different Replicas. | ||
""" | ||
private_model::TMod | ||
|
||
""" | ||
Tempering parameter. | ||
""" | ||
beta::TF | ||
end | ||
|
||
# make a log-potential by creating a new model with independent graph and | ||
# evaluation environment. Both of these could be modified during density | ||
# evaluations and/or during Gibbs sampling | ||
function Pigeons.interpolate(path::JuliaBUGSPath, beta) | ||
model = path.model | ||
private_model = make_private_model_copy(model) | ||
JuliaBUGSLogPotential(private_model, beta) | ||
end | ||
|
||
# log_potential evaluation | ||
(log_potential::JuliaBUGSLogPotential)(flattened_values) = | ||
try | ||
log_prior, _, tempered_log_joint = last( | ||
JuliaBUGS._tempered_evaluate!!( | ||
log_potential.private_model, | ||
flattened_values; | ||
temperature=log_potential.beta | ||
) | ||
) | ||
# avoid potential 0*Inf (= NaN) | ||
return iszero(log_potential.beta) ? log_prior : tempered_log_joint | ||
catch e | ||
(isa(e, DomainError) || isa(e, BoundsError)) && return -Inf | ||
rethrow(e) | ||
end | ||
|
||
# iid sampling | ||
function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, shared) | ||
replica.state = _sample_iid(log_potential.private_model, replica.rng) | ||
end | ||
|
||
# parameter names | ||
Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) = | ||
[(Symbol(string(vn)) for vn in log_potential.private_model.parameters)...,:log_density] | ||
|
||
# Parallelism invariance | ||
Pigeons.recursive_equal(a::Union{JuliaBUGSPath,JuliaBUGSLogPotential}, b) = | ||
Pigeons._recursive_equal(a,b) | ||
function Pigeons.recursive_equal(a::T, b) where T <: JuliaBUGS.BUGSModel | ||
included = (:transformed, :model_def, :data) | ||
excluded = Tuple(setdiff(fieldnames(T), included)) | ||
Pigeons._recursive_equal(a,b,excluded) | ||
end | ||
# just check the betas match, the model is already checked within path | ||
Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) = | ||
all(lp1.beta == lp2.beta for (lp1,lp2) in zip(a,b)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
$SIGNATURES | ||
Implements `Pigeons.forward_sample_condition_and_explore` for running invariance | ||
tests using a [`JuliaBUGSPath`](@ref) as target. | ||
""" | ||
function Pigeons.forward_sample_condition_and_explore( | ||
model::JuliaBUGS.BUGSModel, | ||
rng::SplittableRandom; | ||
explorer = nothing, | ||
condition_on = () | ||
) | ||
# forward simulation (new values stored in model.evaluation_env) | ||
model = evaluate_and_initialize(model, rng) | ||
|
||
# maybe condition the model using the sampled observations | ||
conditioned_model = if length(condition_on) > 0 | ||
var_group = [JuliaBUGS.VarName{sym}() for sym in condition_on] # transform Symbols into VarNames | ||
JuliaBUGS.condition(model, var_group) | ||
else | ||
model | ||
end | ||
|
||
# maybe take a step with explorer | ||
state = getparams(conditioned_model) | ||
return if !isnothing(explorer) | ||
Pigeons.explorer_step(rng, JuliaBUGSPath(conditioned_model), explorer, state) | ||
else | ||
state | ||
end | ||
end | ||
|
||
Pigeons.forward_sample_condition_and_explore(target::JuliaBUGSPath, args...; kwargs...) = | ||
Pigeons.forward_sample_condition_and_explore(target.model, args...; kwargs...) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#= | ||
Tweak of JuliaBUGS.getparams to allow for flattened vectors of mixed type | ||
=# | ||
type_join_eval_env(env) = typejoin(Set(eltype(v) for v in env)...) | ||
function getparams(model::JuliaBUGS.BUGSModel) | ||
param_length = if model.transformed | ||
model.transformed_param_length | ||
else | ||
model.untransformed_param_length | ||
end | ||
|
||
# search for an umbrella type for all parameters in the model to avoid | ||
# promotion of e.g. ints to floats. For models with a unique parameter | ||
# type T, it holds that TMix=T. | ||
TMix = type_join_eval_env(model.evaluation_env) | ||
param_vals = Vector{TMix}(undef, param_length) | ||
pos = 1 | ||
for v in model.parameters | ||
if !model.transformed | ||
val = AbstractPPL.get(model.evaluation_env, v) | ||
len = model.untransformed_var_lengths[v] | ||
if val isa AbstractArray | ||
param_vals[pos:(pos + len - 1)] .= vec(val) | ||
else | ||
param_vals[pos] = val | ||
end | ||
else | ||
(; node_function, loop_vars) = model.g[v] | ||
dist = node_function(model.evaluation_env, loop_vars) | ||
transformed_value = Bijectors.transform( | ||
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v) | ||
) | ||
len = model.transformed_var_lengths[v] | ||
if transformed_value isa AbstractArray | ||
param_vals[pos:(pos + len - 1)] .= vec(transformed_value) | ||
else | ||
param_vals[pos] = transformed_value | ||
end | ||
end | ||
pos += len | ||
end | ||
return param_vals | ||
end | ||
|
||
function make_private_model_copy(model::JuliaBUGS.BUGSModel) | ||
g = deepcopy(model.g) | ||
parameters = model.parameters | ||
sorted_nodes = model.flattened_graph_node_data.sorted_nodes | ||
return JuliaBUGS.BUGSModel( | ||
model.transformed, | ||
sum(model.untransformed_var_lengths[v] for v in parameters), | ||
sum(model.transformed_var_lengths[v] for v in parameters), | ||
model.untransformed_var_lengths, | ||
model.transformed_var_lengths, | ||
deepcopy(model.evaluation_env), | ||
parameters, | ||
JuliaBUGS.FlattenedGraphNodeData(g, sorted_nodes), | ||
g, | ||
nothing, | ||
model.model_def, | ||
model.data | ||
) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.