Skip to content

Commit

Permalink
Add control_bayes() and update code and documentation (#477)
Browse files Browse the repository at this point in the history
This implements
https://github.com/insightsengineering/rbmi/blob/main/misc/design_mcmc_improve.qmd.

I would like to note a few minor changes relative to the design:
- The `control` list is just a simple flat list, without an `additional`
element.
- Additional helper function to complete the control list based on
additional arguments inside the draws method
- Print method has been adapted
  • Loading branch information
danielinteractive authored Feb 7, 2025
1 parent 393c50c commit 9a75cd2
Show file tree
Hide file tree
Showing 23 changed files with 545 additions and 125 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export(analyse)
export(ancova)
export(as_class)
export(as_vcov)
export(control_bayes)
export(delta_template)
export(draws)
export(expand)
Expand Down
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
* Fixed bug where `lsmeans(.weights = "proportional_em")` would error if there was only a single categorical variable in the dataset. (#412)
* Removed native pipes `|>` and lambda functions `\(x)` from code base to ensure package is backwards compatible with older versions of R. (#474)

## Breaking Changes

* Deprecated the `burn_in` and `burn_between` arguments in `method_bayes()` in favour of using the `warmup` and `thin` arguments, respectively, in the new `control` list produced by `control_bayes`. This is to align with the `rstan` package.

## New Features

* Added `control_bayes()` function to allow expert users to specify additional control arguments for the MCMC computations using `rstan`.

# rbmi 1.3.1

* Fixed bug where stale caches of the `rstan` model were not being correctly cleared (#459)
Expand Down
120 changes: 120 additions & 0 deletions R/controls.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#' Control the computational details of the imputation methods
#'
#' @description
#'
#' These functions control lower level computational details of the imputation methods.
#'
#' @name control
#'
#' @param warmup a numeric, the number of warmup iterations for the MCMC sampler.
#'
#' @param thin a numeric, the thinning rate of the MCMC sampler.
#'
#' @param chains a numeric, the number of chains to run in parallel.
#'
#' @param init a character string, the method used to initialise the MCMC sampler, see the details.
#'
#' @param seed a numeric, the seed used to initialise the MCMC sampler.
#'
#' @param ... additional arguments to be passed to [rstan::sampling()].
#'
#' @details
#'
#' Currently only the Bayesian imputation via [method_bayes()] uses a control function:
#'
#' - The `init` argument can be set to `"random"` to randomly initialise the sampler with `rstan`
#' default values or to `"mmrm"` to initialise the sampler with the maximum likelihood estimate
#' values of the MMRM.
#' - The `seed` argument is used to set the seed for the MCMC sampler. By default, a random seed
#' is generated, such that outside invocation of the `set.seed()` call can effectively set the
#' seed.
#' - The samples are split across the chains, such that each chain produces `n_samples / chains`
#' (rounded up) samples. The total number of samples that will be returned across all chains is `n_samples`
#' as specified in [method_bayes()].
#' - Therefore, the additional parameters passed to [rstan::sampling()] must not contain
#' `n_samples` or `iter`. Instead, the number of samples must only be provided directly via the
#' `n_samples` argument of [method_bayes()]. Similarly, the `refresh` argument is also not allowed
#' here, instead use the `quiet` argument directly in [draws()].
#'
#' @note For full reproducibility of the imputation results, it is required to use a `set.seed()` call
#' before defining the `control` list, and calling the `draws()` function. It is not sufficient to
#' merely set the `seed` argument in the `control` list.
#'
#' @export
control_bayes <- function(
warmup = 200,
thin = 50,
chains = 1,
init = ifelse(chains > 1, "random", "mmrm"),
seed = sample.int(.Machine$integer.max, 1),
...
) {
additional_pars <- names(list(...))

if (any(c("n_samples", "iter") %in% additional_pars)) {
stop(
"Instead of providing `n_samples` or `iter` here, please specify the",
" number of samples directly via the `n_samples`",
" argument of `method_bayes()`"
)
}
if ("refresh" %in% additional_pars) {
stop(
"Instead of the `refresh` argument here, please provide the `quiet` argument",
" directly to `draws()`"
)
}
list(
warmup = warmup,
thin = thin,
chains = chains,
init = init,
seed = seed,
...
)
}

complete_control_bayes <- function(
control,
n_samples,
quiet,
stan_data,
mmrm_initial
) {
assertthat::assert_that(is.list(control))
control_pars <- names(control)
if ("iter" %in% control_pars) {
stop("`method$control$iter` must not be specified directly, please use `method$n_samples`")
}
assertthat::assert_that(
assertthat::is.number(control$warmup),
assertthat::is.number(control$thin),
assertthat::is.number(control$chains),
assertthat::is.number(n_samples)
)
n_samples_per_chain <- ceiling(n_samples / control$chains)
control$iter <- control$warmup + control$thin * n_samples_per_chain
if ("refresh" %in% control_pars) {
stop("`method$control$refresh` must not be specified directly, please use `quiet`")
}
control$refresh <- ife(
quiet,
0,
ceiling(control$iter / 10)
)
control$init <- ife(
identical(control$init, "mmrm"),
list(list(
theta = as.vector(stan_data$R %*% mmrm_initial$beta),
sigma = mmrm_initial$sigma
)),
control$init
)
if (any(c("object", "data", "pars") %in% control_pars)) {
stop(
"The `object`, `data` and `pars` arguments must not be specified",
" in `method$control`"
)
}
control
}
27 changes: 16 additions & 11 deletions R/draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -586,17 +586,14 @@ print.draws <- function(x, ...) {

method <- x$method

meth_args <- vapply(
mapply(
function(x, y) sprintf(" %s: %s", y, x),
method,
names(method),
USE.NAMES = FALSE,
SIMPLIFY = FALSE
),
identity,
character(1)
)
control_args <- if ("control" %in% names(method)) {
control <- method$control
method <- method[!(names(method) == "control")]
format_method_descriptions(control)
} else {
character()
}
meth_args <- format_method_descriptions(method)

n_samp <- length(x$samples)
n_samp_string <- ife(
Expand All @@ -616,6 +613,14 @@ print.draws <- function(x, ...) {
"Method:",
sprintf(" name: %s", meth),
meth_args,
ife(
length(control_args),
c(
"Controls:",
control_args
),
NULL
),
""
)

Expand Down
69 changes: 35 additions & 34 deletions R/mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,7 @@ fit_mcmc <- function(
method,
quiet = FALSE
) {

n_imputations <- method$n_samples
burn_in <- method$burn_in
burn_between <- method$burn_between
same_cov <- method$same_cov

# fit MMRM (needed for initial values)
# Fit MMRM (needed for Sigma prior parameter and possibly initial values).
mmrm_initial <- fit_mmrm(
designmat = designmat,
outcome = outcome,
Expand All @@ -70,7 +64,7 @@ fit_mcmc <- function(
group = group,
cov_struct = "us",
REML = TRUE,
same_cov = same_cov
same_cov = method$same_cov
)

if (mmrm_initial$failed) {
Expand All @@ -82,35 +76,35 @@ fit_mcmc <- function(
subjid = subjid,
visit = visit,
outcome = outcome,
group = ife(same_cov == TRUE, rep(1, length(group)), group)
group = ife(
isTRUE(method$same_cov),
rep(1, length(group)),
group
)
)

stan_data$Sigma_init <- ife(
same_cov == TRUE,
isTRUE(method$same_cov),
list(mmrm_initial$sigma[[1]]),
mmrm_initial$sigma
)

sampling_args <- list(
object = get_stan_model(),
data = stan_data,
pars = c("beta", "Sigma"),
chains = 1,
warmup = burn_in,
thin = burn_between,
iter = burn_in + burn_between * n_imputations,
init = list(list(
theta = as.vector(stan_data$R %*% mmrm_initial$beta),
sigma = mmrm_initial$sigma
)),
refresh = ife(
quiet,
0,
(burn_in + burn_between * n_imputations) / 10
)
control <- complete_control_bayes(
control = method$control,
n_samples = method$n_samples,
quiet = quiet,
stan_data = stan_data,
mmrm_initial = mmrm_initial
)

sampling_args <- c(
list(
object = get_stan_model(),
data = stan_data,
pars = c("beta", "Sigma")
),
control
)

sampling_args$seed <- sample.int(.Machine$integer.max, 1)

stan_fit <- record({
do.call(rstan::sampling, sampling_args)
Expand All @@ -132,9 +126,9 @@ fit_mcmc <- function(
for (i in warnings_not_allowed) warning(warnings_not_allowed)

fit <- stan_fit$results
check_mcmc(fit, n_imputations)
check_mcmc(fit, method$n_samples)

draws <- extract_draws(fit)
draws <- extract_draws(fit, method$n_samples)

ret_obj <- list(
"samples" = draws,
Expand Down Expand Up @@ -220,17 +214,20 @@ split_dim <- function(a, n) {
#' and then convert the arrays into lists.
#'
#' @param stan_fit A `stanfit` object.
#'
#' @param n_samples Number of MCMC draws.
#'
#' @return
#' A named list of length 2 containing:
#' - `beta`: a list of length equal to the number of draws containing
#' - `beta`: a list of length equal to `n_samples` containing
#' the draws from the posterior distribution of the regression coefficients.
#' - `sigma`: a list of length equal to the number of draws containing
#' - `sigma`: a list of length equal to `n_samples` containing
#' the draws from the posterior distribution of the covariance matrices. Each element
#' of the list is a list with length equal to 1 if `same_cov = TRUE` or equal to the
#' number of groups if `same_cov = FALSE`.
#'
extract_draws <- function(stan_fit) {
extract_draws <- function(stan_fit, n_samples) {
assertthat::assert_that(assertthat::is.number(n_samples))

pars <- rstan::extract(stan_fit, pars = c("beta", "Sigma"))
names(pars) <- c("beta", "sigma")
Expand All @@ -242,9 +239,13 @@ extract_draws <- function(stan_fit) {
pars$sigma,
function(x) split_dim(x, 1)
)
assertthat::assert_that(length(pars$sigma) >= n_samples)
pars$sigma <- pars$sigma[seq_len(n_samples)]

pars$beta <- split_dim(pars$beta, 1)
pars$beta <- lapply(pars$beta, as.vector)
assertthat::assert_that(length(pars$beta) >= n_samples)
pars$beta <- pars$beta[seq_len(n_samples)]

return(pars)
}
Expand Down
37 changes: 18 additions & 19 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,6 @@
#'
#' @name method
#'
#' @param burn_in a numeric that specifies how many observations should be discarded
#' prior to extracting actual samples. Note that the sampler
#' is initialized at the maximum likelihood estimates and a weakly informative
#' prior is used thus in theory this value should not need to be that high.
#'
#' @param burn_between a numeric that specifies the "thinning" rate i.e. how many
#' observations should be discarded between each sample. This is used to prevent
#' issues associated with autocorrelation between the samples.
#'
#' @param same_cov a logical, if `TRUE` the imputation model will be fitted using a single
#' shared covariance matrix for all observations. If `FALSE` a separate covariance
#' matrix will be fit for each group as determined by the `group` argument of
Expand All @@ -24,6 +15,10 @@
#' @param n_samples a numeric that determines how many imputed datasets are generated.
#' In the case of `method_condmean(type = "jackknife")` this argument
#' must be set to `NULL`. See details.
#'
#' @param control a list which specifies further lower level details of the computations.
#' Currently only used by `method_bayes()`, please see [control_bayes()] for details and
#' default settings.
#'
#' @param B a numeric that determines the number of bootstrap samples for `method_bmlmi`.
#'
Expand All @@ -45,7 +40,9 @@
#' when a conditional mean imputation approach (set via `method_condmean()`) is used.
#' Must be one of `"bootstrap"` or `"jackknife"`.
#'
#' @param seed deprecated. Please use `set.seed()` instead.
#' @param burn_in deprecated. Please use the `warmup` argument in [control_bayes()] instead.
#'
#' @param burn_between deprecated. Please use the `thin` argument in [control_bayes()] instead.
#'
#' @details
#'
Expand All @@ -55,6 +52,9 @@
#' bootstrapped datasets. Likewise, for `method_condmean(type = "jackknife")` there will
#' be `length(unique(data$subjid)) + 1` imputation models and datasets generated. In both cases this is
#' represented by `n + 1` being displayed in the print message.
#' In the case that `method_bayes()` is used, and with the `control` argument the number of chains
#' is set to more than 1, then the `n_samples` samples will be distributed across the chains.
#' The total number of returned samples will still be `n_samples`.
#'
#' The user is able to specify different covariance structures using the the `covariance`
#' argument. Currently supported structures include:
Expand Down Expand Up @@ -94,26 +94,25 @@
#'
#' @export
method_bayes <- function(
burn_in = 200,
burn_between = 50,
same_cov = TRUE,
n_samples = 20,
seed = NULL
control = control_bayes(),
burn_in = NULL,
burn_between = NULL
) {
assertthat::assert_that(
is.null(seed),
is.null(burn_in) && is.null(burn_between),
msg = paste(
"The `seed` argument to `method_bayes()` has been deprecated;",
"please use `set.seed()` instead.",
"The `burn_in` and `burn_between` arguments to `method_bayes()` have been deprecated;",
"please use the `warmup` and `thin` arguments inside `control_bayes()` instead.",
collapse = " "
)
)

x <- list(
burn_in = burn_in,
burn_between = burn_between,
same_cov = same_cov,
n_samples = n_samples
n_samples = n_samples,
control = control
)
return(as_class(x, c("method", "bayes")))
}
Expand Down
Loading

0 comments on commit 9a75cd2

Please sign in to comment.