Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add design doc to align on MCMC options enhancements #472

Merged
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 267 additions & 0 deletions misc/design_mcmc_improve.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# Design doc for improving MCMC sampling

This design doc proposes several updates to the `rbmi` package to improve the options for MCMC sampling. The main goal is to provide more flexibility to the user in terms of the sampling behavior, while still keeping the package easy to use.

## Wish list

1. Allow more detailed Stan sampling arguments as input parameters (`rstan::sampling` control should be controllable)
1. Keep current option (single chain; MMRM output as initial value as default)
1. Add option to have multiple chains with random initial values
1. Always have default prior and don't make prior the input parameter.

## Sampling controls

Currently the `rstan` sampling details are controlled by a hard coded list in [`fit_mcmc`](https://github.com/insightsengineering/rbmi/blob/main/R/mcmc.R#L94). A few of the parameters are exposed in `method_bayes()` to the user.

The `sampling()` function is described [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html). In particular, the `control` argument allows to further specify the sampling behavior, which is documented [here](https://mc-stan.org/rstan/reference/stan.html).

Idea:

- Provide a `control_bayes()` function that returns a list of control parameters for `rstan::stan()`
- This function should have a default behavior that is the same as the current behavior
- The user can override the default behavior by providing a list of control parameters
- Takes additional arguments which are then included in the returned list
- In order to simplify the user interface, the previous `method` parts `burn_in` and `burn_between`, are moved to the `control_bayes()` function, and will properly be deprecated from `method_bayes()`.
- In order to keep the pattern with the other `method_*()` functions, the `n_samples` argument remains in the `method_bayes()` function. The `method_bayes()` function gains an additional argument `control` where the `control_bayes()` function result is passed.
- Arguments of the new control function will be briefly documented (not just copied from the `rstan` documentation, but rather explained in the context of the `rbmi` package). The user will be referred to the `rstan` documentation for more details for additionally passed arguments.
- The `draws()` method keeps its current signature, but internally sets the `iter` and `refresh` arguments in the `method$control` list, based on the `method$n_samples` and the `quiet` arguments.
- The `fit_mcmc` function processes the additional `control` output instead of the currently hard coded settings, and passes it internally to `rstan::sampling()`

License considerations:

- The `rstan` package is licensed under GPL-3.0, while `rbmi` is licensed under the MIT license.
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
- Only using the `rstan::sampling()` function and specifying reasonable default values for some of its control arguments should not require to change the license of the `rbmi` package. This is because this does not constitute a derivative work. Instead, the `rbmi` package would continue to be "linking" to the GPL package.
- Importantly, the `rbmi` package only has the `rstan` package in the `Suggests` field of the `DESCRIPTION` file, not in the `Imports` field. This means that the `rstan` package is not required to be installed to use the `rbmi` package.

## Multiple chains

Currently just a single chain is used.

Idea:

- As part of the `control_bayes()` function, we could add an argument to specify the number of chains.
- If the number of chains is greater than 1, then the `init` argument of `rstan::sampling()` should be set to `random` (see [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html)).
- We can make some experiments to ensure that chains converge to the stationary distribution within reasonable time frame if started from random values. If the Stan random initial values are not good enough, we could later have the option to use initial values sampled from the MLE distribution as the initial values for the chains.
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved

We note that due to the definition of the prior distribution for the covariance matrix, we still need to fit the frequentist MMRM model in any case to get the required prior parameter.

- *Question*: What do we need to adapt downstream when using multiple chains?
- *Answer*: We will be able to test this in detail and address it, once we have implemented the possibility to use multiple chains in the development version. To avoid possible confusion, we can add a warning in the development version, which informs the user that multiple chains are not yet fully supported, while we are still working on it.

## Default prior

We would want to keep using a default prior, as it is currently implemented for the unstructured covariance structure already. The user should never have to specify the priors actively, because there should always be a reasonable default prior implemented by `rbmi`.

That implies that we will need to come up with such default priors for the other covariance structures. This will be handled in a separate, forthcoming design doc.

## Prototype

### Control function

This is how the control function could look like:

```r
control_bayes <- function(
burn_in = 200,
burn_between = 50,
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
chains = 1,
init = ife(chains > 1, "random", "mmrm"),
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
seed = sample.int(.Machine$integer.max, 1),
n_samples = NULL,
iter = NULL,
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
refresh = NULL,
...
) {
if (!is.null(n_samples) || !is.null(iter)) {
stop("Please provide the number of samples directly via the `n_samples` argument of `method_bayes()`")
}
is (!is.null(refresh)) {
stop("Instead of the `refresh` argument here, please provide the `quiet` argument directly to `draws()`")
}
list(
warmup = burn_in,
thin = burn_between,
chains = chains,
refresh = refresh,
init = init,
seed = seed,
additional = list(...)
)
}
```

The rationale for using this control function approach is:

- We provide sensible defaults to the user, however they can now override them if they want to.
- The user can now also pass additional arguments to the `rstan::sampling()` function, but they are not required to do so.
- If the `rstan` API changes, we can adapt the `control_bayes()` function accordingly, and the user does not have to change their code. We would be notified by CRAN reverse dependencies checks if the `rstan` API changes because we will include tests for the actual use of the `rstan` package in the `rbmi` package.
- We have checked the stability of the `rstan` API over the past, by having a look at the GitHub repository ([link](https://github.com/stan-dev/rstan/blob/develop/rstan/rstan/R/stanmodel-class.R#L507)). It seems stable over the last 6 years:
- Via the [blame view](https://github.com/stan-dev/rstan/blame/develop/rstan/rstan/R/stanmodel-class.R#L507) we can see that the signature of the `sampling` method has not changed during the last 6 years. The signature explicitly includes all of our arguments.
- The `refresh` argument is not directly accessible, because we want to keep the `quiet` interface of the `draws()` method. Internally, it will still be used (as now as well). The `refresh` argument is passed via `...`. We can see that the arguments provided via `...` are checked by `is_arg_deprecated()`, which only lists an argument `enable_random_init` as deprecated, and this has not changed in the last 9 years. Further we can see that in the `stan()` function's [blame view](https://github.com/stan-dev/rstan/blame/develop/rstan/rstan/R/rstan.R#L280) the `refresh` argument has been recognized for 11 years.
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved

### Method function

This is how the method function could look like:

```r
method_bayes <- function(
same_cov = TRUE,
n_samples = 500,
control = control_bayes(),
burn_in = NULL,
burn_between = NULL
) {
assertthat::assert_that(
is.null(burn_in), msg = paste("The `burn_in` argument to `method_bayes()` has been deprecated;",
"please specify this inside `control = control_bayes()` instead.", collapse = " "))
assertthat::assert_that(
is.null(burn_between), msg = paste("The `burn_between` argument to `method_bayes()` has been deprecated;",
"please specify this inside `control = control_bayes()` instead.", collapse = " "))
x <- list(same_cov = same_cov, n_samples = n_samples, control = control)
return(as_class(x, c("method", "bayes")))
}
```

- The `seed` argument has been deprecated in a previous release and can therefore be removed here.
- We deprecate here the `burn_in` and `burn_between` arguments, and refer the user to the `control_bayes()` function instead. We will remove these arguments in a future release.

### Fit function

This function's output `control` inside the `method` list can then be used in `fit_mcmc()` like this (copied and adapted function code):

```r
fit_mcmc <- function(
designmat,
outcome,
group,
subjid,
visit,
method,
quiet = FALSE) {

# Fit MMRM (needed for Sigma prior parameter and possibly initial values).
mmrm_initial <- fit_mmrm(
designmat = designmat,
outcome = outcome,
subjid = subjid,
visit = visit,
group = group,
cov_struct = "us",
REML = TRUE,
same_cov = method$same_cov
)

if (mmrm_initial$failed) {
stop("Fitting MMRM to original dataset failed")
}

stan_data <- prepare_stan_data(
ddat = designmat,
subjid = subjid,
visit = visit,
outcome = outcome,
group = ife(method$same_cov == TRUE, rep(1, length(group)), group)
)

stan_data$Sigma_init <- ife(
same_cov == TRUE,
list(mmrm_initial$sigma[[1]]),
mmrm_initial$sigma
)

control <- method$control
sampling_args <- c(
list(
object = get_stan_model(),
data = stan_data,
pars = c("beta", "Sigma"),
chains = control$chains,
warmup = control$warmup,
thin = control$thin,
iter = control$iter,
init = ife(
control$init == "mmrm",
list(list(
theta = as.vector(stan_data$R %*% mmrm_initial$beta),
sigma = mmrm_initial$sigma
)),
control$init
),
refresh = control$refresh,
seed = control$seed
),
control$additional
)

stan_fit <- record({
do.call(rstan::sampling, sampling_args)
})

if (!is.null(stan_fit$errors)) {
stop(stan_fit$errors)
}

ignorable_warnings <- c(
"Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#bulk-ess",
"Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#tail-ess"
)

# handle warning: display only warnings if
# 1) the warning is not in ignorable_warnings
warnings <- stan_fit$warnings
warnings_not_allowed <- warnings[!warnings %in% ignorable_warnings]
for (i in warnings_not_allowed) warning(warnings_not_allowed)

fit <- stan_fit$results
check_mcmc(fit, n_imputations)

draws <- extract_draws(fit)

ret_obj <- list(
"samples" = draws,
"fit" = fit
)

return(ret_obj)
}
```

### Draws method

Finally, the user can call the `draws()` method as before. The only differences are:

- The `method` argument called by `method_bayes()` now also comprises the `control` argument.
- Internally in the `draws()` method, additional elements are inserted in the `control` list:

```r
method$control$iter <- method$control$warmup + method$control$thin * method$n_samples
method$control$refresh <- ife(
quiet,
0,
method$control$iter / 10
)
```

This is to keep the pattern of having the `quiet` argument of the `draws` method. See above the corresponding checks we make in the `control_bayes()` function call to avoid problems if the user passes `refresh`, `iter` or `n_samples` to the `control_bayes()` function.

Note that it is important for easy usability that we just rely on all default arguments in the `control_bayes()` function. The user can then just override the defaults they need to override by
specifying the corresponding arguments. In addition, we can still pass the Stan `control` argument inside `control_bayes()`, and it is not a problem that this argument has the same name as the `draws()` argument.

So a call could e.g. just be:

```r
draws(
data = data,
vars = c("beta", "Sigma"),
method = method_bayes(
n_samples = 500,
control = control_bayes(
chains = 4,
burn_in = 1000,
burn_between = 100,
seed = 123
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
)
),
ncores = 4
)
```
Loading