Skip to content

Commit

Permalink
Ensure consistent style throughout codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
venpopov committed Jan 20, 2025
1 parent 7c8c33c commit bf1bf75
Show file tree
Hide file tree
Showing 30 changed files with 937 additions and 766 deletions.
52 changes: 33 additions & 19 deletions R/bmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,35 @@
#' ff <- bmmformula(c ~ 1, kappa ~ 1)
#'
#' # fit the model
#' fit <- bmm(formula = ff,
#' data = dat,
#' model = sdm(resp_error = "y"),
#' cores = 4,
#' backend = 'cmdstanr')
#' fit <- bmm(
#' formula = ff,
#' data = dat,
#' model = sdm(resp_error = "y"),
#' cores = 4,
#' backend = "cmdstanr"
#' )
bmm <- function(formula, data, model,
prior = NULL,
sort_data = getOption('bmm.sort_data', "check"),
silent = getOption('bmm.silent', 1),
backend = getOption('brms.backend', NULL),
sort_data = getOption("bmm.sort_data", "check"),
silent = getOption("bmm.silent", 1),
backend = getOption("brms.backend", NULL),
file = NULL, file_compress = TRUE,
file_refit = getOption('bmm.file_refit', FALSE), ...) {
file_refit = getOption("bmm.file_refit", FALSE), ...) {
deprecated_args(...)
dots <- list(...)

# check if the model has been previously fit and return it if requested
x <- read_bmmfit(file, file_refit)
if (!is.null(x)) return(x)
if (!is.null(x)) {
return(x)
}

# set temporary global options and return modified arguments for brms
configure_opts <- nlist(sort_data, silent, backend, parallel = dots$parallel,
cores = dots$cores)
configure_opts <- nlist(
sort_data, silent, backend,
parallel = dots$parallel,
cores = dots$cores
)
opts <- configure_options(configure_opts)
dots$parallel <- NULL

Expand All @@ -134,8 +142,12 @@ bmm <- function(formula, data, model,
fit <- call_brm(fit_args)

# model post-processing
fit <- postprocess_brm(model, fit, fit_args = fit_args, user_formula = user_formula,
configure_opts = configure_opts)
fit <- postprocess_brm(
model, fit,
fit_args = fit_args,
user_formula = user_formula,
configure_opts = configure_opts
)

# save the fitted model object if !is.null
save_bmmfit(fit, file, compress = file_compress)
Expand All @@ -147,11 +159,13 @@ bmm <- function(formula, data, model,
#' @export
fit_model <- function(formula, data, model,
prior = NULL,
sort_data = getOption('bmm.sort_data', "check"),
silent = getOption('bmm.silent', 1),
backend = getOption('brms.backend', NULL),
sort_data = getOption("bmm.sort_data", "check"),
silent = getOption("bmm.silent", 1),
backend = getOption("brms.backend", NULL),
...) {
message("You are using the deprecated `fit_model()` function. Please use `bmm()` instead.")
bmm(formula = formula, data = data, model = model, prior = prior,
sort_data = sort_data, silent = silent, backend = backend, ...)
bmm(
formula = formula, data = data, model = model, prior = prior,
sort_data = sort_data, silent = silent, backend = backend, ...
)
}
71 changes: 40 additions & 31 deletions R/bmmformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
#' kappa ~ 0 + set_size + (0 + set_size | id)
#' )
#' identical(imm_formula, imm_formula2)
bmmformula <- function(...){
bmmformula <- function(...) {
dots <- list(...)
formula <- list()
for (i in seq_along(dots)) {
Expand Down Expand Up @@ -104,21 +104,25 @@ bmf <- function(...) {

# method for adding formulas to a bmmformula
#' @export
"+.bmmformula" <- function(f1,f2) {
"+.bmmformula" <- function(f1, f2) {
stopif(!is_bmmformula(f1), "The first argument must be a bmmformula.")

if (is_formula(f2)) {
par2 <- all.vars(f2)[1]
if (par2 %in% names(f1)) {
message2(paste("The parameter", par2, "is already part of the formula.",
"Overwriting the initial formula."))
message2(
"The parameter {par2} is already part of the formula.
Overwriting the initial formula."
)
}
f1[[par2]] <- f2
} else if (is_bmmformula(f2)) {
for (par2 in names(f2)) {
if (par2 %in% names(f1)) {
message2(paste("The parameter", par2, "is already part of the formula.",
"Overwriting the initial formula."))
message2(
"The parameter is already part of the formula.
Overwriting the initial formula."
)
}
f1[[par2]] <- f2[[par2]]
}
Expand Down Expand Up @@ -169,18 +173,22 @@ bmf <- function(...) {
#' @return the formula object
#' @keywords internal developer
check_formula <- function(model, data, formula) {
UseMethod('check_formula')
UseMethod("check_formula")
}

#' @export
check_formula.bmmodel <- function(model, data, formula) {
stopif(is_brmsformula(formula),
"The provided formula is a brms formula. Please use the bmf() function. E.g.:
bmmformula(kappa ~ 1, thetat ~ 1) or bmf(kappa ~ 1, thetat ~ 1)")

stopif(!is_bmmformula(formula),
"The provided formula is not a bmm formula. Please use the bmf() function. E.g.:
bmmformula(kappa ~ 1, thetat ~ 1) or bmf(kappa ~ 1, thetat ~ 1)")
stopif(
is_brmsformula(formula),
"The provided formula is a brms formula. Please use the bmf() function. E.g.:
bmmformula(kappa ~ 1, thetat ~ 1) or bmf(kappa ~ 1, thetat ~ 1)"
)

stopif(
!is_bmmformula(formula),
"The provided formula is not a bmm formula. Please use the bmf() function. E.g.:
bmmformula(kappa ~ 1, thetat ~ 1) or bmf(kappa ~ 1, thetat ~ 1)"
)

wpar <- wrong_parameters(model, formula)
stopif(length(wpar), "Unrecognized model parameters: {collapse_comma(wpar)}")
Expand All @@ -201,11 +209,13 @@ check_formula.non_targets <- function(model, data, formula) {
has_set_size <- sapply(pred_list, function(x) set_size_var %in% x)
ss_forms <- formula[has_set_size]
intercepts <- sapply(ss_forms, has_intercept)
stopif(any(intercepts),
"The formula for parameter(s) {names(ss_forms)[intercepts]} contains \\
an intercept and also uses set_size as a predictor. This model requires \\
that the intercept is supressed when set_size is used as predictor. \\
Try using 0 + {set_size_var} instead.")
stopif(
any(intercepts),
"The formula for parameter(s) {names(ss_forms)[intercepts]} contains \\
an intercept and also uses set_size as a predictor. This model requires \\
that the intercept is supressed when set_size is used as predictor. \\
Try using 0 + {set_size_var} instead."
)
NextMethod("check_formula")
}

Expand All @@ -220,14 +230,14 @@ check_formula.non_targets <- function(model, data, formula) {
#' formulas for the specified `bmmodel`
#' @keywords internal developer
#' @examples
#' model <- mixture2p(resp_error = "error")
#' model <- mixture2p(resp_error = "error")
#'
#' formula <- bmmformula(
#' thetat ~ 0 + set_size + (0 + set_size | id),
#' kappa ~ 1 + (1 | id)
#' )
#' formula <- bmmformula(
#' thetat ~ 0 + set_size + (0 + set_size | id),
#' kappa ~ 1 + (1 | id)
#' )
#'
#' brms_formula <- bmf2bf(model, formula)
#' brms_formula <- bmf2bf(model, formula)
#' @export
bmf2bf <- function(model, formula) {
UseMethod("bmf2bf")
Expand All @@ -253,7 +263,7 @@ bmf2bf.bmmodel <- function(model, formula) {

# paste first line of the brms formula for all bmmodels with 1 response variable
#' @export
bmf2bf.default <- function(model, formula){
bmf2bf.default <- function(model, formula) {
# set base brms formula based on response
brms::bf(paste0(model$resp_vars[[1]], "~ 1"))
}
Expand All @@ -266,20 +276,19 @@ add_missing_parameters <- function(model, formula, replace_fixed = TRUE) {
if (replace_fixed) {
formula_pars <- formula_pars[!formula_pars %in% fixed_pars]
}
missing_pars <- setdiff(model_pars,formula_pars)
missing_pars <- setdiff(model_pars, formula_pars)
is_fixed <- missing_pars %in% fixed_pars
names(is_fixed) <- missing_pars
for (mpar in missing_pars) {
add <- stats::as.formula(paste(mpar,"~ 1"))
add <- stats::as.formula(paste(mpar, "~ 1"))
if (is_fixed[mpar]) {
attr(add, "constant") <- TRUE
} else {
message2("No formula for parameter {mpar} provided. Only a fixed \\
intercept will be estimated.")
message2("No formula for parameter {mpar} provided. Only a fixed intercept will be estimated.")
}
formula[mpar] <- list(add)
}
all_pars <- unique(c(model_pars,formula_pars))
all_pars <- unique(c(model_pars, formula_pars))
formula[all_pars] # reorder formula to match model parameters order
}

Expand Down
57 changes: 38 additions & 19 deletions R/brms-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ nlist <- function(...) {
dots <- list(...)
no_names <- is.null(names(dots))
has_name <- if (no_names) FALSE else nzchar(names(dots))
if (all(has_name)) return(dots)
if (all(has_name)) {
return(dots)
}
nms <- as.character(m)[-1]
if (no_names) {
names(dots) <- nms
Expand Down Expand Up @@ -172,8 +174,9 @@ has_intercept <- function(formula) {
# @param pattern regex that must be matches by the object names
# @return a character vector of object names
lsp <- function(package, what = "all", pattern = ".*") {
if (!is.character(substitute(package)))
if (!is.character(substitute(package))) {
package <- deparse0(substitute(package))
}
ns <- asNamespace(package)

## base package does not have NAMESPACE
Expand All @@ -182,30 +185,44 @@ lsp <- function(package, what = "all", pattern = ".*") {
return(res[grep(pattern, res, perl = TRUE, ignore.case = TRUE)])
} else {
## for non base packages
if (exists('.__NAMESPACE__.', envir = ns, inherits = FALSE)) {
wh <- get('.__NAMESPACE__.', inherits = FALSE,
envir = asNamespace(package, base.OK = FALSE))
what <- if (missing(what)) 'all'
else if ('?' %in% what) return(ls(wh))
else ls(wh)[pmatch(what[1], ls(wh))]
if (!is.null(what) && !any(what %in% c('all', ls(wh))))
stop('\'what\' should be one of ',
paste0(shQuote(ls(wh)), collapse = ', '),
', or \'all\'', domain = NA)
if (exists(".__NAMESPACE__.", envir = ns, inherits = FALSE)) {
wh <- get(".__NAMESPACE__.",
inherits = FALSE,
envir = asNamespace(package, base.OK = FALSE)
)
what <- if (missing(what)) {
"all"
} else if ("?" %in% what) {
return(ls(wh))
} else {
ls(wh)[pmatch(what[1], ls(wh))]
}
if (!is.null(what) && !any(what %in% c("all", ls(wh)))) {
stop("'what' should be one of ",
paste0(shQuote(ls(wh)), collapse = ", "),
", or 'all'",
domain = NA
)
}
res <- sapply(ls(wh), function(x) getNamespaceInfo(ns, x))
res <- rapply(res, ls, classes = 'environment',
how = 'replace', all.names = TRUE)
if (is.null(what))
res <- rapply(res, ls,
classes = "environment",
how = "replace", all.names = TRUE
)
if (is.null(what)) {
return(res[grep(pattern, res, perl = TRUE, ignore.case = TRUE)])
if (what %in% 'all') {
}
if (what %in% "all") {
res <- ls(getNamespace(package), all.names = TRUE)
return(res[grep(pattern, res, perl = TRUE, ignore.case = TRUE)])
}
if (any(what %in% ls(wh))) {
res <- res[[what]]
return(res[grep(pattern, res, perl = TRUE, ignore.case = TRUE)])
}
} else stop(sprintf('no NAMESPACE file found for package %s', package))
} else {
stop(sprintf("no NAMESPACE file found for package %s", package))
}
}
}

Expand Down Expand Up @@ -281,8 +298,10 @@ rename <- function(x, pattern = NULL, replacement = NULL,
dup <- duplicated(out)
if (check_dup && any(dup)) {
dup <- x[out %in% out[dup]]
stop2("Internal renaming led to duplicated names. \n",
"Occured for: ", collapse_comma(dup))
stop2(
"Internal renaming led to duplicated names. \n",
"Occured for: ", collapse_comma(dup)
)
}
out
}
Loading

0 comments on commit bf1bf75

Please sign in to comment.