Skip to content

Commit

Permalink
Implement make_chain_from_prior([rng,] model, n_iters)
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 29, 2024
1 parent 310490a commit 696bb3d
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,32 @@ function modify_value_representation(nt::NamedTuple)
end
return modified_nt
end

"""
make_chain_from_prior([rng,] model, n_iters)
Construct an MCMCChains.Chains object by sampling from the prior of `model` for
`n_iters` iterations.
"""
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
# Sample from the prior
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
# Convert each varinfo into an OrderedDict of vns => params.
# We have to use varname_and_value_leaves so that each parameter is a scalar
dicts = map(varinfos) do t
vals = DynamicPPL.values_as(t, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
tuples = mapreduce(collect, vcat, iters)
OrderedDict(tuples)
end
# Extract all varnames found in any dictionary. Doing it this way guards
# against the possibility of having different varnames in different
# dictionaries, e.g. for models that have dynamic variables / array sizes
all_varnames = collect(union(map(keys, dicts)...))
vals = [get(dict, vn, missing) for dict in dicts, vn in all_varnames]
# Construct and return the Chains object
return Chains(vals, all_varnames)
end
function make_chain_from_prior(model::Model, n_iters::Int)
return make_chain_from_prior(Random.default_rng(), model, n_iters)
end

0 comments on commit 696bb3d

Please sign in to comment.