diff --git a/NAMESPACE b/NAMESPACE index 6efe6d01..7fe559c6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,16 +2,31 @@ S3method("[",neff_ratio) S3method("[",rhat) +S3method(apply_transformations,array) +S3method(apply_transformations,matrix) +S3method(diagnostic_factor,neff_ratio) +S3method(diagnostic_factor,rhat) S3method(log_posterior,CmdStanMCMC) S3method(log_posterior,stanfit) S3method(log_posterior,stanreg) +S3method(melt_mcmc,matrix) +S3method(melt_mcmc,mcmc_array) S3method(neff_ratio,CmdStanMCMC) S3method(neff_ratio,stanfit) S3method(neff_ratio,stanreg) +S3method(num_chains,data.frame) +S3method(num_chains,mcmc_array) +S3method(num_iters,data.frame) +S3method(num_iters,mcmc_array) +S3method(num_params,data.frame) +S3method(num_params,mcmc_array) S3method(nuts_params,CmdStanMCMC) S3method(nuts_params,list) S3method(nuts_params,stanfit) S3method(nuts_params,stanreg) +S3method(parameter_names,array) +S3method(parameter_names,default) +S3method(parameter_names,matrix) S3method(plot,bayesplot_grid) S3method(plot,bayesplot_scheme) S3method(pp_check,default) diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index 0111c2f5..41e2c4ee 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -124,6 +124,8 @@ select_parameters <- #' @return A molten data frame. #' melt_mcmc <- function(x, ...) UseMethod("melt_mcmc") + +#' @export melt_mcmc.mcmc_array <- function(x, varnames = c("Iteration", "Chain", "Parameter"), @@ -144,6 +146,7 @@ melt_mcmc.mcmc_array <- function(x, } # If all chains are already merged +#' @export melt_mcmc.matrix <- function(x, varnames = c("Draw", "Parameter"), value.name = "Value", @@ -305,13 +308,17 @@ chain_list2array <- function(x) { # Get parameter names from a 3-D array parameter_names <- function(x) UseMethod("parameter_names") + +#' @export parameter_names.array <- function(x) { stopifnot(is_3d_array(x)) dimnames(x)[[3]] %||% abort("No parameter names found.") } +#' @export parameter_names.default <- function(x) { colnames(x) %||% abort("No parameter names found.") } +#' @export parameter_names.matrix <- function(x) { colnames(x) %||% abort("No parameter names found.") } @@ -391,6 +398,8 @@ validate_transformations <- apply_transformations <- function(x, ...) { UseMethod("apply_transformations") } + +#' @export apply_transformations.matrix <- function(x, ..., transformations = list()) { pars <- colnames(x) x_transforms <- validate_transformations(transformations, pars) @@ -400,6 +409,8 @@ apply_transformations.matrix <- function(x, ..., transformations = list()) { x } + +#' @export apply_transformations.array <- function(x, ..., transformations = list()) { stopifnot(length(dim(x)) == 3) pars <- dimnames(x)[[3]] @@ -437,17 +448,23 @@ num_chains <- function(x, ...) UseMethod("num_chains") num_iters <- function(x, ...) UseMethod("num_iters") num_params <- function(x, ...) UseMethod("num_params") +#' @export num_params.mcmc_array <- function(x, ...) dim(x)[3] +#' @export num_chains.mcmc_array <- function(x, ...) dim(x)[2] +#' @export num_iters.mcmc_array <- function(x, ...) dim(x)[1] +#' @export num_params.data.frame <- function(x, ...) { stopifnot("Parameter" %in% colnames(x)) length(unique(x$Parameter)) } +#' @export num_chains.data.frame <- function(x, ...) { stopifnot("Chain" %in% colnames(x)) length(unique(x$Chain)) } +#' @export num_iters.data.frame <- function(x, ...) { cols <- colnames(x) stopifnot("Iteration" %in% cols || "Draws" %in% cols) diff --git a/R/mcmc-diagnostics.R b/R/mcmc-diagnostics.R index 4e5ec767..a56be0bd 100644 --- a/R/mcmc-diagnostics.R +++ b/R/mcmc-diagnostics.R @@ -364,12 +364,14 @@ diagnostic_factor <- function(x, ...) { UseMethod("diagnostic_factor") } +#' @export diagnostic_factor.rhat <- function(x, ..., breaks = c(1.05, 1.1)) { cut(x, breaks = c(-Inf, breaks, Inf), labels = c("low", "ok", "high"), ordered_result = FALSE) } +#' @export diagnostic_factor.neff_ratio <- function(x, ..., breaks = c(0.1, 0.5)) { cut(x, breaks = c(-Inf, breaks, Inf), labels = c("low", "ok", "high"), diff --git a/R/ppc-discrete.R b/R/ppc-discrete.R index aff81251..89f6ee00 100644 --- a/R/ppc-discrete.R +++ b/R/ppc-discrete.R @@ -367,7 +367,7 @@ ppc_bars_data <- #' @param y,yrep,group User's already validated `y`, `yrep`, and (if applicable) #' `group` arguments. #' @param prob,freq User's `prob` and `freq` arguments. -#' @importFrom dplyr "%>%" ungroup count arrange mutate summarise across full_join rename all_of +#' @importFrom dplyr %>% ungroup count arrange mutate summarise across full_join rename all_of .ppc_bars_data <- function(y, yrep, group = NULL, prob = 0.9, freq = TRUE) { alpha <- (1 - prob) / 2 probs <- sort(c(alpha, 0.5, 1 - alpha))