Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to inferencedata method #67

Merged
merged 10 commits into from
Dec 9, 2022
118 changes: 64 additions & 54 deletions src/utils/inferencedata.jl
Original file line number Diff line number Diff line change
@@ -1,73 +1,82 @@
using .InferenceObjects

const SymbolOrSymbols = Union{Symbol, AbstractVector{Symbol}, NTuple{N, Symbol} where N}

# Define the "proper" ArviZ names for the sample statistics group.
const SAMPLE_STATS_KEY_MAP = (
n_leapfrog__=:n_steps,
treedepth__=:tree_depth,
energy__=:energy,
lp__=:lp,
stepsize__=:step_size,
divergent__=:diverging,
accept_stat__=:acceptance_rate,
)

function split_nt(nt::NamedTuple, ks::NTuple{N, Symbol}) where {N}
keys1 = filter(∉(ks), keys(nt))
keys2 = filter(∈(ks), keys(nt))
return NamedTuple{keys1}(nt), NamedTuple{keys2}(nt)
end
split_nt(nt::NamedTuple, key::Symbol) = split_nt(nt, (key,))
split_nt(nt::NamedTuple, ::Nothing) = (nt, nothing)
split_nt(nt::NamedTuple, keys) = split_nt(nt, Tuple(keys))

function split_nt_all(nt::NamedTuple, pair::Pair{Symbol}, others::Pair{Symbol}...)
group_name, keys = pair
nt_main, nt_group = split_nt(nt, keys)
post_nt, groups_nt_others = split_nt_all(nt_main, others...)
groups_nt = NamedTuple{(group_name,)}((nt_group,))
return post_nt, merge(groups_nt, groups_nt_others)
end
split_nt_all(nt::NamedTuple) = (nt, NamedTuple())

function rekey(d::NamedTuple, keymap)
new_keys = map(k -> get(keymap, k, k), keys(d))
return NamedTuple{new_keys}(values(d))
end

function inferencedata1(m::SampleModel;
include_warmup = m.save_warmup,
log_likelihood_symbol::Union{Nothing, Symbol} = :log_lik,
posterior_predictive_symbol::Union{Nothing, Symbol} = :y_hat,
kwargs...)
log_likelihood_var::Union{SymbolOrSymbols,Nothing} = nothing,
posterior_predictive_var::Union{SymbolOrSymbols,Nothing} = nothing,
predictions_var::Union{SymbolOrSymbols,Nothing} = nothing,
kwargs...,
)

# Read in the draws as a NamedTuple with sample_stats included
stan_nts = read_samples(m, :namedtuples; include_internals=true)

# Define the "proper" ArviZ names for the sample statistics group.
sample_stats_key_map = (
n_leapfrog__=:n_steps,
treedepth__=:tree_depth,
energy__=:energy,
lp__=:lp,
stepsize__=:step_size,
divergent__=:diverging,
accept_stat__=:acceptance_rate,
);

# If a log_likelihood_symbol is defined (!= nothing), remove it from the future posterior group
if !isnothing(log_likelihood_symbol)
sample_nts = NamedTuple{filter(∉([log_likelihood_symbol]), keys(stan_nts))}(stan_nts)
else
sample_mts = stan_nts
end

# If a posterior_predictive_symbol is defined (!= nothing), remove it from the future posterior group
if !isnothing(posterior_predictive_symbol)
sample_nts = NamedTuple{filter(∉([posterior_predictive_symbol]), keys(sample_nts))}(sample_nts)
end

# `sample_nts` now holds remaining parameters and the sample statistics
# Split in 2 separate NamedTuples: posterior_nts and sample_stats_nts
posterior_nts = NamedTuple{filter(∉(keys(sample_stats_key_map)), keys(sample_nts))}(sample_nts)
sample_stats_nts = NamedTuple{filter(∈(keys(sample_stats_key_map)), keys(sample_nts))}(sample_nts)

# Remap the names according to above sample_stats_key_map
sample_stats_nts_rekey =
NamedTuple{map(Base.Fix1(getproperty, sample_stats_key_map), keys(sample_stats_nts))}(
values(sample_stats_nts))
# split stan_nts into separate groups based on keyword arguments
posterior_nts, group_nts = split_nt_all(
stan_nts,
:sample_stats => keys(SAMPLE_STATS_KEY_MAP),
:log_likelihood => log_likelihood_var,
:posterior_predictive => posterior_predictive_var,
:predictions => predictions_var,
)
# Remap the names according to above SAMPLE_STATS_KEY_MAP
sample_stats = rekey(group_nts.sample_stats, SAMPLE_STATS_KEY_MAP)
group_nts_stats_rename = merge(group_nts, (; sample_stats=sample_stats))

# Create initial inferencedata object with 2 groups
idata = from_namedtuple(posterior_nts; sample_stats=sample_stats_nts_rekey, kwargs...)

# Merge both log_likelihood and posterior_predictive groups into idata if present
if !isnothing(posterior_predictive_symbol) && posterior_predictive_symbol in keys(stan_nts)
nt = (y = stan_nts[posterior_predictive_symbol],)
idata = merge(idata, from_namedtuple(nt; posterior_predictive = (:y,)))
end

if !isnothing(log_likelihood_symbol) log_likelihood_symbol in keys(stan_nts)
nt = (y = stan_nts[log_likelihood_symbol],)
idata = merge(idata, from_namedtuple(nt; log_likelihood = (:y,)))
end
idata = from_namedtuple(posterior_nts; group_nts_stats_rename..., kwargs...)

# Extract warmup values in separate groups
if include_warmup
warmup_indices = 1:m.num_warmups
sample_indices = (1:m.num_samples) .+ m.num_warmups
idata = let
idata_warmup = idata[draw=1:1000]
idata_postwarmup = idata[draw=1001:2000]
idata_warmup = idata[draw=warmup_indices]
idata_postwarmup = idata[draw=sample_indices]
idata_warmup_rename = InferenceData(NamedTuple(Symbol("warmup_$k") => idata_warmup[k] for k in
keys(idata_warmup)))
merge(idata_postwarmup, idata_warmup_rename)
end
end

# TO DO: update the indexing
# TO DO: import observed and constant data

return idata
end
Expand Down Expand Up @@ -322,10 +331,11 @@ $(SIGNATURES)

### Optional positional argument
```julia
* `include_warmup` # Directory where output files are stored
* `log_likelihood_symbol` # Symbol used for log_likelihood (or nothing, default: :log_lik)
* `posterior_predictive_symbol` # Symbol used for posterior_predictive (or nothing, default: :y_hat)
* `kwargs...` # Arguments to pass on to `from_namedtuple`
* `include_warmup` # Directory where output files are stored
* `log_likelihood_var` # Symbol(s) used for log_likelihood (or nothing)
* `posterior_predictive_var` # Symbol(s) used for posterior_predictive (or nothing)
* `predictions_var` # Symbol(s) used for predictions (or nothing)
* `kwargs...` # Arguments to pass on to `from_namedtuple`
```

### Returns
Expand All @@ -338,7 +348,7 @@ See the example in ./test/test_inferencedata.jl.
Note that this function is currently under development.

"""
inferencedata = inferencedata3
inferencedata = inferencedata1

export
inferencedata