-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Issue 559: Diagnostic analysis over all inference runs (#560)
* Create make_mcmc_diagnostic_dataframe.jl * reorg scripts and add more success/fail analysis * Add function to get run info to avoid DRY * Add function to do diagnostics * export new func * update SI * Issue 561: Soft min transformation (#562) Also removed unnecessary call to `fetch` * base values on pipeline types * breakdown mcmc convergence test function Adds more stats and a unit test
- Loading branch information
1 parent
9b467f2
commit 7dde0e9
Showing
14 changed files
with
265 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
## Analysis of the prediction dataframes for mcmc diagnostics | ||
diagnostic_df = mapreduce(vcat, scenarios) do scenario | ||
mapreduce(vcat, true_gi_means) do true_gi_mean | ||
target_str = "truth_gi_mean_" * string(true_gi_mean) * "_" | ||
files = readdir(datadir("epiaware_observables/" * scenario)) |> | ||
strs -> filter(s -> occursin("jld2", s), strs) |> | ||
strs -> filter(s -> occursin(target_str, s), strs) | ||
|
||
mapreduce(vcat, files) do filename | ||
output = load(joinpath(datadir("epiaware_observables"), scenario, filename)) | ||
try | ||
make_mcmc_diagnostic_dataframe(output, true_gi_mean, scenario) | ||
catch e | ||
end | ||
end | ||
end | ||
end | ||
|
||
## Save the mcmc diagnostics | ||
CSV.write("manuscript/inference_diagnostics_rnd2.csv", diagnostic_df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
include("config_mappings.jl") | ||
include("make_truthdata_dataframe.jl") | ||
include("make_prediction_dataframe_from_output.jl") | ||
include("make_scoring_dataframe_from_output.jl") | ||
include("make_mcmc_diagnostic_dataframe.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
Extracts and returns relevant information from the given inference configuration dictionary. | ||
# Returns | ||
- `NamedTuple`: A named tuple containing the following fields: | ||
- `igp_model::String`: The IGP model name extracted from the configuration. | ||
- `latent_model::String`: The latent model name from the configuration. | ||
- `used_gi_mean::Float64`: The mean generation interval (GI) used in the configuration. | ||
- `used_gi_std::Float64`: The standard deviation of the generation interval (GI) used in the configuration. | ||
- `start_time::Int`: The start time parsed from the configuration's time span. | ||
- `reference_time::Int`: The reference time parsed from the configuration's time span. | ||
- `used_gi_means::Vector{Float64}`: A vector of GI means, either a single value if the IGP model is "Renewal" or a list of values generated by `make_gi_params` otherwise. | ||
""" | ||
function _get_info_from_config(inference_config) | ||
#Get the scenario, IGP model, latent model and true mean GI | ||
igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end] | ||
latent_model = inference_config["latent_model"] | ||
used_gi_mean = inference_config["gi_mean"] | ||
used_gi_std = inference_config["gi_std"] | ||
(start_time, reference_time) = inference_config["tspan"] |> | ||
tspan -> split(tspan, "_") |> | ||
tspan -> ( | ||
parse(Int, tspan[1]), parse(Int, tspan[2])) | ||
|
||
#Get the quantiles for the targets across the gi mean scenarios | ||
#if Renewal model, then we use the underlying epi model | ||
#otherwise we use the epi datas to loop over different gi mean implications | ||
used_gi_means = igp_model == "Renewal" ? | ||
[used_gi_mean] : | ||
make_gi_params(EpiAwareExamplePipeline())["gi_means"] | ||
return (; igp_model, latent_model, used_gi_mean, used_gi_std, | ||
start_time, reference_time, used_gi_means) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
""" | ||
Collects the statistics of a vector `x` that are relevant for MCMC diagnostics. | ||
""" | ||
function _get_stats(x, threshold; pass_above = true) | ||
if pass_above | ||
return (; x_mean = mean(x), prop_pass = mean(x .>= threshold), | ||
x_min = minimum(x), x_max = maximum(x)) | ||
else | ||
return (; x_mean = mean(x), prop_pass = mean(x .<= threshold), | ||
x_min = minimum(x), x_max = maximum(x)) | ||
end | ||
end | ||
|
||
""" | ||
Collects the convergence statistics over the parameters that are not cluster factor. | ||
""" | ||
function _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold, | ||
tail_ess_threshold, rhat_diff_threshold) | ||
ess_bulk = chn_nt.ess_bulk[not_cluster_factor] |> x -> _get_stats(x, bulk_ess_threshold) | ||
ess_tail = chn_nt.ess_tail[not_cluster_factor] |> x -> _get_stats(x, tail_ess_threshold) | ||
rhat_diff = abs.(chn_nt.rhat[not_cluster_factor] .- 1) |> | ||
x -> _get_stats(x, rhat_diff_threshold; pass_above = false) | ||
return (; ess_bulk, ess_tail, rhat_diff) | ||
end | ||
|
||
""" | ||
Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of | ||
parameters that pass the bulk effective sample size (ESS) threshold, the proportion of | ||
parameters that pass the tail ESS threshold, the proportion of parameters that pass the R-hat | ||
absolute difference from 1 threshold, whether the model has a cluster factor parameter, and the tail ESS | ||
of the cluster factor parameter. | ||
# Arguments | ||
- `output::Dict`: A dictionary containing the inference results. | ||
- `bulk_ess_threshold::Int`: The threshold for bulk effective sample size (ESS). Default is 500. | ||
- `tail_ess_threshold::Int`: The threshold for tail effective sample size (ESS). Default is 100. | ||
- `rhat_diff_threshold::Float64`: The threshold for the difference of R-hat from 1. Default is 0.02. | ||
""" | ||
function make_mcmc_diagnostic_dataframe( | ||
output, true_mean_gi, scenario; bulk_ess_threshold = 500, | ||
tail_ess_threshold = 100, rhat_diff_threshold = 0.02) | ||
#Get the scenario, IGP model, latent model and true mean GI | ||
inference_config = output["inference_config"] | ||
info = _get_info_from_config(inference_config) | ||
#Get the convergence diagnostics | ||
chn_nt = output["inference_results"].samples |> summarize |> summary -> summary.nt | ||
cluster_factor_idxs = chn_nt.parameters .== Symbol("obs.cluster_factor") | ||
has_cluster_factor = any(cluster_factor_idxs) | ||
not_cluster_factor = .~cluster_factor_idxs | ||
cluster_factor_tail = chn_nt.ess_tail[cluster_factor_idxs][1] | ||
#Collect the statistics | ||
stats_for_targets = _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold, | ||
tail_ess_threshold, rhat_diff_threshold) | ||
|
||
#Create the dataframe | ||
df = mapreduce(vcat, info.used_gi_means) do used_gi_mean | ||
DataFrame( | ||
Scenario = scenario, | ||
igp_model = info.igp_model, | ||
latent_model = info.latent_model, | ||
True_GI_Mean = true_mean_gi, | ||
used_gi_mean = used_gi_mean, | ||
reference_time = info.reference_time, | ||
has_cluster_factor = has_cluster_factor, | ||
cluster_factor_tail = has_cluster_factor ? cluster_factor_tail : missing) | ||
end | ||
#Add stats columns | ||
for key in keys(stats_for_targets) | ||
stats = getfield(stats_for_targets, key) | ||
df[!, string(key) * "_" * "mean"] .= stats.x_mean | ||
df[!, string(key) * "_" * "prop_pass"] .= stats.prop_pass | ||
df[!, string(key) * "_" * "min"] .= stats.x_min | ||
df[!, string(key) * "_" * "max"] .= stats.x_max | ||
end | ||
return df | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.