Skip to content

Commit

Permalink
Final improvements to NUTS figures
Browse files Browse the repository at this point in the history
  • Loading branch information
athowes committed Dec 5, 2023
1 parent 7ad17cc commit 2b40a41
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/naomi-simple_mcmc/mcmc-convergence.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ mcmc_summary <- summary(mcmc)$summary
(rhat_upper <- quantile(mcmc_summary[, "Rhat"], 0.975))
(rhat_max <- max(mcmc_summary[, "Rhat"]))
(rhat_mean <- mean(mcmc_summary[, "Rhat"]))
(rhat_above_threshold <- mean(mcmc_summary[, "Rhat"] < 1.01))
```

# ESS
Expand All @@ -105,7 +106,7 @@ ess_fig <- bayesplot::mcmc_neff_data(ratios) %>%
geom_vline(xintercept = 0.1, linetype = "dashed", col = "grey40", alpha = 0.5) +
geom_vline(xintercept = 0.5, linetype = "dashed", col = "grey40", alpha = 0.5) +
geom_vline(xintercept = 1, linetype = "dashed", col = "grey40", alpha = 0.5) +
labs(x = "ESS ratio", y = "Parameter", col = "", tag = "A") +
labs(x = "Effective sample size ratio", y = "Parameter", col = "", tag = "A") +
theme_minimal() +
theme(
axis.text.y = element_blank(),
Expand All @@ -122,8 +123,8 @@ What are the total effective sample sizes?
ess_histogram <- data.frame(mcmc_summary) %>%
tibble::rownames_to_column("param") %>%
ggplot(aes(x = n_eff)) +
geom_histogram(alpha = 0.8) +
labs(x = "ESS", y = "Count", tag = "B")
geom_histogram(col = "grey60", fill = "grey80") +
labs(x = "Effective sample size", y = "Count", tag = "B")
ess_fig / ess_histogram
Expand All @@ -147,6 +148,7 @@ out <- list(
"rhat_upper" = rhat_upper,
"rhat_max" = rhat_max,
"rhat_mean" = rhat_mean,
"rhat_above_threshold" = rhat_above_threshold,
"ess_min" = ess_min,
"ess_lower" = ess_lower,
"ess_median" = ess_median,
Expand Down Expand Up @@ -180,9 +182,12 @@ bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("log_sigma")))
The parameters with the worst ESS and the worst $\hat R$:

```{r}
(plot <- bayesplot::mcmc_trace(mcmc, pars = c(names(which.min(mcmc_summary[, "n_eff"])), names(which.max(rhats)))))
worst_eff <- bayesplot::mcmc_trace(mcmc, pars = names(which.min(mcmc_summary[, "n_eff"]))) + labs(tag = "A") + guides(col = "none")
worst_rhat <- bayesplot::mcmc_trace(mcmc, pars = names(which.max(rhats))) + labs(tag = "B")
ggsave("worst-trace.png", plot, h = 3, w = 6.25)
worst_eff + worst_rhat
ggsave("worst-trace.png", h = 3, w = 6.25)
```

## Prevalence model
Expand Down

0 comments on commit 2b40a41

Please sign in to comment.