Skip to content

Commit

Permalink
Refactor m3 distributions (#49)
Browse files Browse the repository at this point in the history
- rename act_funs_m3versions to construct_m3_act_funs
- cleanup documentation for m3 distribution functions
- add examples and tests for m3 distribution functions
- the logic inside dm3 and rm3 was exactly the same and the only thing that differed was the call to either dmultinom or - rmultinom. So I abstracted the contrustion of the probability vector into a new shared function - compute_m3_probability_vector(). That function is also much simpler than the body of dm3 before, but it achieves the same output
- move some of the checks that were performed inside dm3 to the m3() call where they should happen instead.
  • Loading branch information
venpopov authored Jan 27, 2025
1 parent 9914679 commit fb527a5
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 171 deletions.
172 changes: 59 additions & 113 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -574,134 +574,80 @@ rimm <- function(n, mu = c(0, 2, -1.5), dist = c(0, 0.5, 2),

#' @title Distribution functions for the Memory Measurement Model (M3)
#'
#' @description Density and random generation functions for the
#' memory measurement model. Please note that these functions are currently not vectorized.
#' @description Density and random generation functions for the memory
#' measurement model. Please note that these functions are currently not
#' vectorized.
#'
#' @name m3dist
#'
#' @param x Numeric vector of observed responses
#' @param n Numeric. Number of observations to generate data for
#' @param size Number of trials per observation
#' @param x Integer vector of length `K` where K is the number of response categories
#' and each value is the number of observed responses per category
#' @param n Integer. Number of observations to generate data for
#' @param size The total number of observations in all categories
#' @param pars A named vector of parameters of the memory measurement model
#' @param m3_model A `bmmodel` object specifying the m3 model that densities or random samples should be generated for
#' @param act_funs A `bmmformula` object specifying the activation functions for the different response categories for the "custom" version of the M3.
#' @param log Logical; if `TRUE`, values are returned on the log scale.
#' @param ... can be used to pass additional variables that are used in the activation functions,
#' but not parameters of the model
#' @param m3_model A `bmmodel` object specifying the m3 model that densities or
#' random samples should be generated for
#' @param act_funs A `bmmformula` object specifying the activation functions for
#' the different response categories for the "custom" version of the M3. The
#' default will attempt to construct the standard activation functions for the
#' "ss" and "cs" model version. For a custom m3 model you need to specify the
#' act_funs argument manually
#' @param log Logical; if `TRUE` (default), values are returned on the log scale.
#' @param ... can be used to pass additional variables that are used in the
#' activation functions, but not parameters of the model
#'
#' @keywords distribution
#'
#' @references Oberauer, K., & Lewandowsky, S. (2019). Simple measurement models for complex working-memory tasks.
#' Psychological Review, 126(6), 880–932. https://doi.org/10.1037/rev0000159
#' @references Oberauer, K., & Lewandowsky, S. (2019). Simple measurement models
#' for complex working-memory tasks. Psychological Review, 126(6), 880–932.
#' https://doi.org/10.1037/rev0000159
#'
#' @return `dm3` gives the density of the memory measurement model,
#' and `rm3` gives the random generation function for the
#' memory measurement model.
#' @return `dm3` gives the density of the memory measurement model, and `rm3`
#' gives the random generation function for the memory measurement model.
#'
#' @examples
#' model <- m3(
#' resp_cats = c("corr", "other", "npl"),
#' num_options = c(1, 4, 5),
#' choice_rule = "simple",
#' version = "ss"
#' )
#' dm3(x = c(20, 10, 10), pars = c(a = 1, b = 1, c = 2), m3_model = model)
#' @export
#'
dm3 <- function(x, pars, m3_model, act_funs = NULL, log = TRUE, ...) {
# unpack additional arguments
dots <- list(...)
if (length(dots) != 0) pars <- c(pars, unlist(dots))

stopif(is.null(names(pars)),
glue("Unnamed vectors passed to \"pars\".\n",
"Please name the \"pars\" vector with the parameter names used in the activation functions."))

stopif(!m3_model$version %in% c("ss","cs","custom"),
glue("Unsupported Version: \"", version,"\" \n",
"Please choose one of the following options for the version argument:\n",
"\"ss\",\"cs\", or \"custom\""))

if (is.null(act_funs)) act_funs <- act_funs_m3versions(m3_model, warnings = FALSE)
if (is.null(act_funs)) {
stop2(glue("No activation functions for version \"custom\" provided.\n",
"Please pass activation functions for the different response categories\n",
"using the \"act_funs\" argument."))
}

if (!length(rhs_vars(act_funs)) == length(pars)) {
stop2(glue("The number of parameters used in the activation functions \n",
"mismatches the number of parameters (\"pars\") and additional arguments (i.e. ...) passed to the function."))
} else if (!all(rhs_vars(act_funs) %in% names(pars))) {
stop2(glue("Some parameters used in the activation functions are not specified in the \"pars\" argument."))
}

resp_cats <- m3_model$resp_var$resp_cats
acts <- numeric()
for (i in resp_cats) {
act_fun <- act_funs[[i]]
# extract right-side formula
right_from = as.character(act_fun[-2]) %>% stringr::str_remove("~")
acts[i] <- with(data.frame(rbind(pars)), eval(parse(text = right_from[2])))
}
acts

num_options <- m3_model$other_vars$num_options
if (tolower(m3_model$other_vars$choice_rule) == "simple") {
probs <- (acts*num_options)/sum(acts*num_options)
} else if (tolower(m3_model$other_vars$choice_rule) == "softmax") {
probs <- (exp(acts)*num_options)/sum(exp(acts)*num_options)
} else {
stop2(glue("Unsupported choice rule: \" ", m3_model$other_vars$choice_rule,"\"\n",
"Please select either \"simple\" or \"softmax\" as choice_rule."))
}

density <- dmultinom(x, size = sum(x), prob = probs, log = TRUE)
if (!log) return(exp(density))
density
dm3 <- function(x, pars, m3_model, act_funs = construct_m3_act_funs(m3_model, warnings = FALSE),
log = TRUE, ...) {
probs <- .compute_m3_probability_vector(pars, m3_model, act_funs, ...)
dmultinom(x, prob = probs, log = log)
}

#' @rdname m3dist
#' @export
rm3 <- function (n, size, pars, m3_model, act_funs = NULL, ...) {
# unpack additional arguments
dots <- list(...)
if (length(dots) != 0) pars <- c(pars, unlist(dots))

stopif(is.null(names(pars)),
glue("Unnamed vectors passed to \"pars\".\n",
"Please name the \"pars\" vector with the parameter names used in the activation functions."))

stopif(!m3_model$version %in% c("ss","cs","custom"),
glue("Unsupported Version: \"", version,"\" \n",
"Please choose one of the following options for the version argument:\n",
"\"ss\",\"cs\", or \"custom\""))

if (is.null(act_funs)) act_funs <- act_funs_m3versions(m3_model, warnings = FALSE)
if (is.null(act_funs)) {
stop2(glue("No activation functions for version \"custom\" provided.\n",
"Please pass activation functions for the different response categories\n",
"using the \"act_funs\" argument."))
}
rm3 <- function(n, size, pars, m3_model, act_funs = construct_m3_act_funs(m3_model, warnings = FALSE),
...) {
probs <- .compute_m3_probability_vector(pars, m3_model, act_funs, ...)
t(rmultinom(n, size = size, prob = probs))
}

if (!length(rhs_vars(act_funs)) == length(pars)) {
stop2(glue("The number of parameters used in the activation functions \n",
"mismatches the number of parameters (\"pars\") and additional arguments (i.e. ...) passed to the function."))
} else if (!all(rhs_vars(act_funs) %in% names(pars))) {
stop2(glue("Some parameters used in the activation functions are not specified in the \"pars\" argument."))
}
.compute_m3_probability_vector <-
function(pars, m3_model, act_funs = construct_m3_act_funs(m3_model, warnings = FALSE), ...) {
pars <- c(pars, unlist(list(...)))
stopif(
is_try_error(try(act_funs, silent = TRUE)),
'No activation functions for version "custom" provided.
Please pass activation functions for the different response categories
using the "act_funs" argument.'
)
stopif(
!identical(sort(rhs_vars(act_funs)), sort(names(pars))),
'The names or number of parameters used in the activation functions mismatch the names or number
of parameters ("pars") and additional arguments (i.e. ...) passed to the function.'
)

resp_cats <- m3_model$resp_var$resp_cats
acts <- numeric()
for (i in resp_cats) {
act_fun <- act_funs[[i]]
# extract right-side formula
right_from = as.character(act_fun[-2]) %>% stringr::str_remove("~")
acts[i] <- with(data.frame(rbind(pars)), eval(parse(text = right_from[2])))
}
acts
acts <- sapply(act_funs, function(pform) eval(pform[[length(pform)]], envir = as.list(pars)))

num_options <- m3_model$other_vars$num_options
if (tolower(m3_model$other_vars$choice_rule) == "simple") {
probs <- (acts*num_options)/sum(acts*num_options)
} else if (tolower(m3_model$other_vars$choice_rule) == "softmax") {
probs <- (exp(acts)*num_options)/sum(exp(acts)*num_options)
} else {
stop2(glue("Unsupported choice rule: \" ", m3_model$other_vars$choice_rule,"\"\n",
"Please select either \"simple\" or \"softmax\" as choice_rule."))
num_options <- m3_model$other_vars$num_options
choice_rule <- tolower(m3_model$other_vars$choice_rule)
if (choice_rule == "softmax") acts <- exp(acts)
acts <- acts * num_options
probs <- acts / sum(acts)
}

t(rmultinom(n, size = size, prob = probs))
}
91 changes: 55 additions & 36 deletions R/model_m3.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@
#' @export
m3 <- function(resp_cats, num_options, choice_rule = "softmax", version = "custom", ...) {
stop_missing_args()
stopif(
!version %in% c("custom", "cs", "ss"),
'Unknown version: {version}. It should be one of "ss", "cs" or "custom"'
)
stopif(
!tolower(choice_rule) %in% c("softmax", "simple"),
'Unsupported choice rule "{choice_rule}. Must be one of "simple" or "softmax"'
)

.model_m3(
resp_cats = resp_cats, num_options = num_options,
choice_rule = choice_rule, version = version, ...
Expand Down Expand Up @@ -341,7 +350,7 @@ check_data.m3 <- function(model, data, formula) {
#' @export
check_formula.m3 <- function(model, data, formula) {
if (model$version != "custom") {
formula <- act_funs_m3versions(model, warnings = FALSE) + formula
formula <- construct_m3_act_funs(model, warnings = FALSE) + formula
}

formula <- apply_links(formula, model$links)
Expand Down Expand Up @@ -470,7 +479,6 @@ configure_model.m3 <- function(model, data, formula) {
#' Measurement Model (m3) implemented in the `bmm` package. If no `bmmodel` object is
#' passed then it will print the available model versions.
#'
#'
#' @param model A bmmodel object that specifies the M3 model for which the
#' activation functions should be generated. If no model is passed the available
#' M3 versions will be printed to the console.
Expand All @@ -484,55 +492,66 @@ configure_model.m3 <- function(model, data, formula) {
#' @examplesIf isTRUE(Sys.getenv("BMM_EXAMPLES"))
#' model <- m3(
#' resp_cats = c("correct","other", "npl"),
#' num_options = c(1,4,5),
#' num_options = c(1, 4, 5),
#' version = "ss"
#' )
#'
#' ss_act_funs <- act_funs_m3versions(model, warnings = FALSE)
#'
#' ss_act_funs <- construct_m3_act_funs(model, warnings = FALSE)
#'
#' @export
act_funs_m3versions <- function(model = NULL, warnings = TRUE) {
construct_m3_act_funs <- function(model = NULL, warnings = TRUE) {
if (is.null(model)) {
msg <- glue("Available m3 versions with pre-defined activation functions are:\n",
" - \"ss\" for simple span tasks: 3 response categories (correct, other, npl) \n" ,
" - \"cs\" for complex span tasks. 5 response categories (correct, dist_context, other, dist_other, npl)")
return(print(msg))
message2(
'Available m3 versions with pre-defined activation functions are:
- "ss" for simple span tasks: 3 response categories (correct, other, npl)
- "cs" for complex span tasks. 5 response categories (correct, dist_context, other, dist_other, npl)'
)
return(invisible())
}

stopif(!"m3" %in% class(model),
glue("Activation functions can only be generated for \"m3\" models."))
stopif(
!inherits(model, "m3") || !model$version %in% c("ss", "cs"),
'Activation functions can only be generated for "m3" models "ss" and "cs"'
)

resp_cats <- model$resp_vars$resp_cats
if (model$version == "ss") {
warnif(warnings,
glue("\n","The \"ss\" version of the m3 requires that response categories are ordered as follows:\n",
" 1) correct: correct responses\n",
" 2) other: other list responses\n",
" 3) npl: not presented lures"))
warnif(
warnings,
glue(
'\nThe "ss" version of the m3 requires that response categories are ordered as follows:
1) correct: correct responses
2) other: other list responses
3) npl: not presented lures'
)
)

act_funs <- bmf(
formula(glue(model$resp_vars$resp_cats[1], "~ b + a + c")),
formula(glue(model$resp_vars$resp_cats[2], "~ b + a")),
formula(glue(model$resp_vars$resp_cats[3], "~ b"))
formula(glue("{resp_cats[1]} ~ b + a + c")),
formula(glue("{resp_cats[2]} ~ b + a")),
formula(glue("{resp_cats[3]} ~ b"))
)
} else if (model$version == "cs") {
warnif(warnings,
glue("\n","The \"cs\" version of the m3 requires that response categories are ordered as follows:\n",
" 1) correct: correct responses\n",
" 2) dist_context: distractor responses close in context to the correct item\n",
" 3) other: other list responses\n",
" 4) dist_other: all distractor responses not close in context to the correct item\n",
" 5) npl: not presented lures"))
warnif(
warnings,
glue(
"\n", "The \"cs\" version of the m3 requires that response categories are ordered as follows:\n",
" 1) correct: correct responses\n",
" 2) dist_context: distractor responses close in context to the correct item\n",
" 3) other: other list responses\n",
" 4) dist_other: all distractor responses not close in context to the correct item\n",
" 5) npl: not presented lures"
)
)

act_funs <- bmf(
formula(glue(model$resp_vars$resp_cats[1], "~ b + a + c")),
formula(glue(model$resp_vars$resp_cats[2], "~ b + f * a + f * c")),
formula(glue(model$resp_vars$resp_cats[3], "~ b + a")),
formula(glue(model$resp_vars$resp_cats[4], "~ b + f * a")),
formula(glue(model$resp_vars$resp_cats[5], "~ b"))
formula(glue("{resp_cats[1]} ~ b + a + c")),
formula(glue("{resp_cats[2]} ~ b + f * a + f * c")),
formula(glue("{resp_cats[3]} ~ b + a")),
formula(glue("{resp_cats[4]} ~ b + f * a")),
formula(glue("{resp_cats[5]} ~ b"))
)
} else {
act_funs <- bmf()
}
}

return(act_funs)
reset_env(act_funs)
}
10 changes: 5 additions & 5 deletions man/act_funs_m3versions.Rd → man/construct_m3_act_funs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fb527a5

Please sign in to comment.