diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R
index 571a6ce2..2e5f32e1 100644
--- a/R/mcmc-traces.R
+++ b/R/mcmc-traces.R
@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
#' of rank-normalized MCMC samples. Defaults to `20`.
#' @param ref_line For the rank plots, whether to draw a horizontal line at the
#' average number of ranks per bin. Defaults to `FALSE`.
+#' @param split_chains Logical indicating whether to split each chain into two
+#' halves. If `TRUE`, the first and second half of every chain are labeled with
+#' the suffixes `"_1"` and `"_2"`. Defaults to `FALSE`.
#' @export
mcmc_rank_overlay <- function(x,
pars = character(),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
facet_args = list(),
...,
n_bins = 20,
- ref_line = FALSE) {
+ ref_line = FALSE,
+ split_chains = FALSE) {
check_ignored_arguments(...)
data <- mcmc_trace_data(
x,
@@ -294,7 +298,28 @@ mcmc_rank_overlay <- function(x,
transformations = transformations
)
- n_chains <- unique(data$n_chains)
+ # Split chains if requested
+ if (split_chains) {
+ data$n_chains = data$n_chains/2
+ data$n_iterations = data$n_iterations/2
+ # Midpoint of the iteration index (a single value shared by all chains)
+ n_samples <- length(unique(data$iteration))
+ midpoint <- n_samples/2
+
+ # Create new data frame with split chains
+ data <- data %>%
+ group_by(.data$chain) %>%
+ mutate(
+ chain = ifelse(
+ .data$iteration <= midpoint,
+ paste0(.data$chain, "_1"),
+ paste0(.data$chain, "_2")
+ )
+ ) %>%
+ ungroup()
+ }
+
+ n_chains <- length(unique(data$chain))
n_param <- unique(data$n_parameters)
# We have to bin and count the data ourselves because
@@ -319,6 +344,7 @@ mcmc_rank_overlay <- function(x,
bin_start = unique(histobins$bin_start),
stringsAsFactors = FALSE
))
+
d_bin_counts <- all_combos %>%
left_join(d_bin_counts, by = c("parameter", "chain", "bin_start")) %>%
mutate(n = dplyr::if_else(is.na(n), 0L, n))
@@ -331,7 +357,9 @@ mcmc_rank_overlay <- function(x,
mutate(bin_start = right_edge) %>%
dplyr::bind_rows(d_bin_counts)
- scale_color <- scale_color_manual("Chain", values = chain_colors(n_chains))
+ # Update legend title based on split_chains
+ legend_title <- if (split_chains) "Split Chains" else "Chain"
+ scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chains))
layer_ref_line <- if (ref_line) {
geom_hline(
@@ -352,7 +380,7 @@ mcmc_rank_overlay <- function(x,
}
ggplot(d_bin_counts) +
- aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
+ aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
geom_step() +
layer_ref_line +
facet_call +
@@ -457,6 +485,9 @@ mcmc_rank_hist <- function(x,
#' @param plot_diff For `mcmc_rank_ecdf()`, a boolean specifying if the
#' difference between the observed rank ECDFs and the theoretical expectation
#' should be drawn instead of the unmodified rank ECDF plots.
+#' @param split_chains Logical indicating whether to split each chain into two
+#' halves. If `TRUE`, the first and second half of every chain are labeled with
+#' the suffixes `"_1"` and `"_2"`. Defaults to `FALSE`.
#' @export
mcmc_rank_ecdf <-
function(x,
@@ -468,7 +499,8 @@ mcmc_rank_ecdf <-
facet_args = list(),
prob = 0.99,
plot_diff = FALSE,
- interpolate_adj = NULL) {
+ interpolate_adj = NULL,
+ split_chains = FALSE) {
check_ignored_arguments(...,
ok_args = c("K", "pit", "prob", "plot_diff", "interpolate_adj", "M")
)
@@ -479,8 +511,28 @@ mcmc_rank_ecdf <-
transformations = transformations,
highlight = 1
)
+
+ # Split chains if requested
+ if (split_chains) {
+ data$n_chains = data$n_chains/2
+ data$n_iterations = data$n_iterations/2
+ n_samples <- length(unique(data$iteration))
+ midpoint <- n_samples/2
+
+ data <- data %>%
+ group_by(.data$chain) %>%
+ mutate(
+ chain = ifelse(
+ .data$iteration <= midpoint,
+ paste0(.data$chain, "_1"),
+ paste0(.data$chain, "_2")
+ )
+ ) %>%
+ ungroup()
+ }
+
n_iter <- unique(data$n_iterations)
- n_chain <- unique(data$n_chains)
+ n_chain <- length(unique(data$chain))
n_param <- unique(data$n_parameters)
x <- if (is.null(K)) {
@@ -533,7 +585,9 @@ mcmc_rank_ecdf <-
group = .data$chain
)
- scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain))
+ # Update legend title based on split_chains
+ legend_title <- if (split_chains) "Split Chains" else "Chain"
+ scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chain))
facet_call <- NULL
if (n_param == 1) {
diff --git a/man/MCMC-traces.Rd b/man/MCMC-traces.Rd
index 1054591b..4f631067 100644
--- a/man/MCMC-traces.Rd
+++ b/man/MCMC-traces.Rd
@@ -51,7 +51,8 @@ mcmc_rank_overlay(
facet_args = list(),
...,
n_bins = 20,
- ref_line = FALSE
+ ref_line = FALSE,
+ split_chains = FALSE
)
mcmc_rank_hist(
@@ -75,7 +76,8 @@ mcmc_rank_ecdf(
facet_args = list(),
prob = 0.99,
plot_diff = FALSE,
- interpolate_adj = NULL
+ interpolate_adj = NULL,
+ split_chains = FALSE
)
mcmc_trace_data(
@@ -193,6 +195,10 @@ of rank-normalized MCMC samples. Defaults to \code{20}.}
\item{ref_line}{For the rank plots, whether to draw a horizontal line at the
average number of ranks per bin. Defaults to \code{FALSE}.}
+\item{split_chains}{Logical indicating whether to split each chain into two
+halves. If \code{TRUE}, the first and second half of every chain are labeled
+with the suffixes \code{"_1"} and \code{"_2"}. Defaults to \code{FALSE}.}
+
\item{K}{An optional integer defining the number of equally spaced evaluation
points for the PIT-ECDF. Reducing K when using \code{interpolate_adj = FALSE}
makes computing the confidence bands faster. For \code{ppc_pit_ecdf} and
diff --git a/tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg
new file mode 100644
index 00000000..d16c9ca3
--- /dev/null
+++ b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg
@@ -0,0 +1,70 @@
+
+
diff --git a/tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg
new file mode 100644
index 00000000..0079d95e
--- /dev/null
+++ b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg
@@ -0,0 +1,67 @@
+
+
diff --git a/tests/testthat/data-for-mcmc-tests.R b/tests/testthat/data-for-mcmc-tests.R
index 1136cefe..fe892579 100644
--- a/tests/testthat/data-for-mcmc-tests.R
+++ b/tests/testthat/data-for-mcmc-tests.R
@@ -80,4 +80,11 @@ vdiff_dframe_rank_overlay_bins_test <- posterior::as_draws_df(
)
)
+vdiff_dframe_rank_split_chain_test <- posterior::as_draws_df(
+ list(
+ list(theta = -2 + 0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)),
+ list(theta = 1 + -0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5))
+ )
+)
+
set.seed(seed = NULL)
diff --git a/tests/testthat/test-mcmc-traces.R b/tests/testthat/test-mcmc-traces.R
index 62d46c88..f79b4c2e 100644
--- a/tests/testthat/test-mcmc-traces.R
+++ b/tests/testthat/test-mcmc-traces.R
@@ -157,6 +157,10 @@ test_that("mcmc_rank_overlay renders correctly", {
# https://github.com/stan-dev/bayesplot/issues/331
p_not_all_bins_exist <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_bins_test)
+ # https://github.com/stan-dev/bayesplot/issues/333
+ p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_split_chain_test,
+ split_chains = TRUE)
+
vdiffr::expect_doppelganger("mcmc_rank_overlay (default)", p_base)
vdiffr::expect_doppelganger(
"mcmc_rank_overlay (reference line)",
@@ -170,6 +174,9 @@ test_that("mcmc_rank_overlay renders correctly", {
# https://github.com/stan-dev/bayesplot/issues/331
vdiffr::expect_doppelganger("mcmc_rank_overlay (not all bins)", p_not_all_bins_exist)
+
+ # https://github.com/stan-dev/bayesplot/issues/333
+ vdiffr::expect_doppelganger("mcmc_rank_overlay (split chains)", p_split_chains)
})
test_that("mcmc_rank_hist renders correctly", {
@@ -254,6 +261,11 @@ test_that("mcmc_rank_ecdf renders correctly", {
plot_diff = TRUE
)
+ # https://github.com/stan-dev/bayesplot/issues/333
+ p_split_chains <- mcmc_rank_ecdf(vdiff_dframe_rank_split_chain_test,
+ plot_diff = TRUE,
+ split_chains = TRUE)
+
vdiffr::expect_doppelganger("mcmc_rank_ecdf (default)", p_base)
vdiffr::expect_doppelganger("mcmc_rank_ecdf (one parameter)", p_one_param)
vdiffr::expect_doppelganger("mcmc_rank_ecdf (diff)", p_diff)
@@ -261,6 +273,9 @@ test_that("mcmc_rank_ecdf renders correctly", {
"mcmc_rank_ecdf (one param, diff)",
p_diff_one_param
)
+
+ # https://github.com/stan-dev/bayesplot/issues/333
+ vdiffr::expect_doppelganger("mcmc_rank_ecdf (split chain)", p_split_chains)
})
test_that("mcmc_trace with 'np' renders correctly", {