diff --git a/src/utils/inferencedata.jl b/src/utils/inferencedata.jl index 3601ebb..a37e777 100644 --- a/src/utils/inferencedata.jl +++ b/src/utils/inferencedata.jl @@ -1,66 +1,74 @@ using .InferenceObjects +const SymbolOrSymbols = Union{Symbol, AbstractVector{Symbol}, NTuple{N, Symbol} where N} + +# Define the "proper" ArviZ names for the sample statistics group. +const SAMPLE_STATS_KEY_MAP = ( + n_leapfrog__=:n_steps, + treedepth__=:tree_depth, + energy__=:energy, + lp__=:lp, + stepsize__=:step_size, + divergent__=:diverging, + accept_stat__=:acceptance_rate, +) + +function split_nt(nt::NamedTuple, ks::NTuple{N, Symbol}) where {N} + keys1 = filter(∉(ks), keys(nt)) + keys2 = filter(∈(ks), keys(nt)) + return NamedTuple{keys1}(nt), NamedTuple{keys2}(nt) +end +split_nt(nt::NamedTuple, key::Symbol) = split_nt(nt, (key,)) +split_nt(nt::NamedTuple, ::Nothing) = (nt, nothing) +split_nt(nt::NamedTuple, keys) = split_nt(nt, Tuple(keys)) + +function split_nt_all(nt::NamedTuple, pair::Pair{Symbol}, others::Pair{Symbol}...) + group_name, keys = pair + nt_main, nt_group = split_nt(nt, keys) + post_nt, groups_nt_others = split_nt_all(nt_main, others...) + groups_nt = NamedTuple{(group_name,)}((nt_group,)) + return post_nt, merge(groups_nt, groups_nt_others) +end +split_nt_all(nt::NamedTuple) = (nt, NamedTuple()) + +function rekey(d::NamedTuple, keymap) + new_keys = map(k -> get(keymap, k, k), keys(d)) + return NamedTuple{new_keys}(values(d)) +end + function inferencedata1(m::SampleModel; include_warmup = m.save_warmup, - log_likelihood_symbol::Union{Nothing, Symbol} = :log_lik, - posterior_predictive_symbol::Union{Nothing, Symbol} = :y_hat, - kwargs...) + log_likelihood_var::Union{SymbolOrSymbols,Nothing} = nothing, + posterior_predictive_var::Union{SymbolOrSymbols,Nothing} = nothing, + predictions_var::Union{SymbolOrSymbols,Nothing} = nothing, + kwargs..., +) # Read in the draws as a NamedTuple with sample_stats included stan_nts = read_samples(m, :namedtuples; include_internals=true) - # Define the "proper" ArviZ names for the sample statistics group. - sample_stats_key_map = ( - n_leapfrog__=:n_steps, - treedepth__=:tree_depth, - energy__=:energy, - lp__=:lp, - stepsize__=:step_size, - divergent__=:diverging, - accept_stat__=:acceptance_rate, - ); - - # If a log_likelihood_symbol is defined (!= nothing), remove it from the future posterior group - if !isnothing(log_likelihood_symbol) - sample_nts = NamedTuple{filter(∉([log_likelihood_symbol]), keys(stan_nts))}(stan_nts) - else - sample_mts = stan_nts - end - - # If a posterior_predictive_symbol is defined (!= nothing), remove it from the future posterior group - if !isnothing(posterior_predictive_symbol) - sample_nts = NamedTuple{filter(∉([posterior_predictive_symbol]), keys(sample_nts))}(sample_nts) - end - - # `sample_nts` now holds remaining parameters and the sample statistics - # Split in 2 separate NamedTuples: posterior_nts and sample_stats_nts - posterior_nts = NamedTuple{filter(∉(keys(sample_stats_key_map)), keys(sample_nts))}(sample_nts) - sample_stats_nts = NamedTuple{filter(∈(keys(sample_stats_key_map)), keys(sample_nts))}(sample_nts) - - # Remap the names according to above sample_stats_key_map - sample_stats_nts_rekey = - NamedTuple{map(Base.Fix1(getproperty, sample_stats_key_map), keys(sample_stats_nts))}( - values(sample_stats_nts)) + # split stan_nts into separate groups based on keyword arguments + posterior_nts, group_nts = split_nt_all( + stan_nts, + :sample_stats => keys(SAMPLE_STATS_KEY_MAP), + :log_likelihood => log_likelihood_var, + :posterior_predictive => posterior_predictive_var, + :predictions => predictions_var, + ) + # Remap the names according to above SAMPLE_STATS_KEY_MAP + sample_stats = rekey(group_nts.sample_stats, SAMPLE_STATS_KEY_MAP) + group_nts_stats_rename = merge(group_nts, (; sample_stats=sample_stats)) # Create initial inferencedata object with 2 groups - idata = from_namedtuple(posterior_nts; sample_stats=sample_stats_nts_rekey, kwargs...) - - # Merge both log_likelihood and posterior_predictive groups into idata if present - if !isnothing(posterior_predictive_symbol) && posterior_predictive_symbol in keys(stan_nts) - nt = (y = stan_nts[posterior_predictive_symbol],) - idata = merge(idata, from_namedtuple(nt; posterior_predictive = (:y,))) - end - - if !isnothing(log_likelihood_symbol) log_likelihood_symbol in keys(stan_nts) - nt = (y = stan_nts[log_likelihood_symbol],) - idata = merge(idata, from_namedtuple(nt; log_likelihood = (:y,))) - end + idata = from_namedtuple(posterior_nts; group_nts_stats_rename..., kwargs...) # Extract warmup values in separate groups if include_warmup + warmup_indices = 1:m.num_warmups + sample_indices = (1:m.num_samples) .+ m.num_warmups idata = let - idata_warmup = idata[draw=1:1000] - idata_postwarmup = idata[draw=1001:2000] + idata_warmup = idata[draw=warmup_indices] + idata_postwarmup = idata[draw=sample_indices] idata_warmup_rename = InferenceData(NamedTuple(Symbol("warmup_$k") => idata_warmup[k] for k in keys(idata_warmup))) merge(idata_postwarmup, idata_warmup_rename) @@ -68,6 +76,7 @@ function inferencedata1(m::SampleModel; end # TO DO: update the indexing + # TO DO: import observed and constant data return idata end @@ -322,10 +331,11 @@ $(SIGNATURES) ### Optional positional argument ```julia -* `include_warmup` # Directory where output files are stored -* `log_likelihood_symbol` # Symbol used for log_likelihood (or nothing, default: :log_lik) -* `posterior_predictive_symbol` # Symbol used for posterior_predictive (or nothing, default: :y_hat) -* `kwargs...` # Arguments to pass on to `from_namedtuple` +* `include_warmup` # Directory where output files are stored +* `log_likelihood_var` # Symbol(s) used for log_likelihood (or nothing) +* `posterior_predictive_var` # Symbol(s) used for posterior_predictive (or nothing) +* `predictions_var` # Symbol(s) used for predictions (or nothing) +* `kwargs...` # Arguments to pass on to `from_namedtuple` ``` ### Returns @@ -338,7 +348,7 @@ See the example in ./test/test_inferencedata.jl. Note that this function is currently under development. """ -inferencedata = inferencedata3 +inferencedata = inferencedata1 export inferencedata