Skip to content

Commit

Permalink
Break down the MCMC convergence test function
Browse files Browse the repository at this point in the history
Adds more stats and a unit test
  • Loading branch information
SamuelBrand1 committed Dec 19, 2024
1 parent 73e80c1 commit 727a827
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
44 changes: 36 additions & 8 deletions pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
"""
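Collect the summary statistics of a vector `x` that are relevant for MCMC diagnostics:
its mean, minimum and maximum, and the proportion of its entries that pass `threshold`
(at or above it when `pass_above = true`, at or below it otherwise).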
"""
function _get_stats(x, threshold; pass_above = true)
    # Only the direction of the comparison depends on `pass_above`: an entry passes when
    # it is at least `threshold` (the default) or at most `threshold`.
    meets_threshold = pass_above ? (>=) : (<=)
    return (;
        x_mean = mean(x),
        prop_pass = mean(meets_threshold.(x, threshold)),
        x_min = minimum(x),
        x_max = maximum(x)
    )
end

"""
Collect the convergence statistics (bulk ESS, tail ESS and `|rhat - 1|`) over the
parameters that are not the cluster factor.
"""
function _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold,
        tail_ess_threshold, rhat_diff_threshold)
    # `not_cluster_factor` is used to index each diagnostic, so only the selected
    # parameters contribute to the summaries below.

    # ESS diagnostics pass when they are at or above their threshold.
    bulk_values = chn_nt.ess_bulk[not_cluster_factor]
    ess_bulk = _get_stats(bulk_values, bulk_ess_threshold)

    tail_values = chn_nt.ess_tail[not_cluster_factor]
    ess_tail = _get_stats(tail_values, tail_ess_threshold)

    # The rhat diagnostic is measured as its absolute distance from 1 and passes when
    # that distance is at or below the threshold.
    rhat_gaps = abs.(chn_nt.rhat[not_cluster_factor] .- 1)
    rhat_diff = _get_stats(rhat_gaps, rhat_diff_threshold; pass_above = false)

    return (; ess_bulk, ess_tail, rhat_diff)
end

"""
Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of
parameters that pass the bulk effective sample size (ESS) threshold, the proportion of
Expand All @@ -23,6 +48,9 @@ function make_mcmc_diagnostic_dataframe(
has_cluster_factor = any(cluster_factor_idxs)
not_cluster_factor = .~cluster_factor_idxs
cluster_factor_tail = chn_nt.ess_tail[cluster_factor_idxs][1]
#Collect the statistics
stats_for_targets = _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold,
tail_ess_threshold, rhat_diff_threshold)

#Create the dataframe
df = mapreduce(vcat, info.used_gi_means) do used_gi_mean
Expand All @@ -33,16 +61,16 @@ function make_mcmc_diagnostic_dataframe(
True_GI_Mean = true_mean_gi,
used_gi_mean = used_gi_mean,
reference_time = info.reference_time,
bulk_ess_threshold = (chn_nt.ess_bulk[not_cluster_factor] .>
bulk_ess_threshold) |>
mean,
tail_ess_threshold = (chn_nt.ess_tail[not_cluster_factor] .>
tail_ess_threshold) |>
mean,
rhat_diff_threshold = (abs.(chn_nt.rhat[not_cluster_factor] .- 1) .<
rhat_diff_threshold) |> mean,
has_cluster_factor = has_cluster_factor,
cluster_factor_tail = has_cluster_factor ? cluster_factor_tail : missing)
end
#Add stats columns
for key in keys(stats_for_targets)
stats = getfield(stats_for_targets, key)
df[!, string(key) * "_" * "mean"] .= stats.x_mean
df[!, string(key) * "_" * "prop_pass"] .= stats.prop_pass
df[!, string(key) * "_" * "min"] .= stats.x_min
df[!, string(key) * "_" * "max"] .= stats.x_max
end
return df
end
37 changes: 37 additions & 0 deletions pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
@testset "test MCMC convergence analysis on toy obs model" begin
    using JLD2, DataFramesMeta, Turing, EpiAware, Random
    # Reuse the local config
    _output = load(joinpath(@__DIR__(), "test_data.jld2"))
    inference_config = _output["inference_config"]
    # Create a simple test model to test mcmc diagnostics via prior sampling
    obs = make_observation_model(SmoothEndemicPipeline())
    @model function test_model()
        x ~ filldist(Normal(0, 1), 20)
        @submodel prefix="obs" y_t=generate_observations(obs, missing, exp.(x))
    end
    n = 1000
    # Seed the RNG so the convergence assertions below are reproducible rather than
    # depending on a fresh random draw each run.
    Random.seed!(1234)
    samples = sample(test_model(), Prior(), n)

    # Create a simple output to test the function
    output = Dict(
        "inference_config" => inference_config,
        "inference_results" => (; samples,)
    )

    true_mean_gi = 10.0
    scenario = "rough_endemic"
    # Pass the same `scenario` value that is asserted against below.
    df = make_mcmc_diagnostic_dataframe(output, true_mean_gi, scenario)
    # Check pass throughs
    @test df isa DataFrame
    @test size(df, 1) == 3 # Number of rows should match the length of used_gi_means
    @test df[1, :Scenario] == scenario
    @test df[1, :latent_model] == inference_config["latent_model"]
    @test df[1, :True_GI_Mean] == true_mean_gi
    # Prior sampling should be uncorrelated and meet all the convergence criteria
    @test all(df[:, :ess_bulk_prop_pass] .== 1.0)
    @test all(df[:, :ess_tail_prop_pass] .== 1.0)
    @test all(df[:, :rhat_diff_prop_pass] .== 1.0)
    @test all(df[:, :has_cluster_factor] .== true)
    # Check the cluster factor tail ESS on every row, not just the first.
    @test all(df[:, :cluster_factor_tail] .> n / 2)
end
1 change: 1 addition & 0 deletions pipeline/test/analysis/test_analysis.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include("make_prediction_dataframe_from_output.jl")
include("make_truthdata_dataframe.jl")
include("make_mcmc_diagnostic_dataframe.jl")

0 comments on commit 727a827

Please sign in to comment.