From 8b7f8444585feb54805c5897a766090f69933f0e Mon Sep 17 00:00:00 2001 From: Fiona Seaton Date: Tue, 27 Feb 2024 10:18:46 +0000 Subject: [PATCH 1/5] add zip distr Plus also fix some ndraws bugs in pp functions --- DESCRIPTION | 2 +- R/jsdm_stancode.R | 58 ++++++++++---- R/posterior_predict.R | 112 ++++++++++++++++----------- R/pp_check.R | 2 +- R/prior.R | 11 ++- R/sim_data_funs.R | 25 ++++-- R/stan_jsdm.R | 12 +-- man/jsdm_prior.Rd | 6 +- man/jsdm_sim_data.Rd | 3 +- man/jsdm_stancode.Rd | 6 +- man/posterior_linpred.jsdmStanFit.Rd | 3 +- man/stan_gllvm.Rd | 7 +- man/stan_jsdm.Rd | 7 +- man/stan_mglmm.Rd | 7 +- 14 files changed, 171 insertions(+), 90 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a92b61d..7d76321 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: jsdmstan Title: Fitting jSDMs in Stan -Version: 0.3.0 +Version: 0.3.0.9000 Authors@R: person("Fiona", "Seaton", , "fseaton@ceh.ac.uk", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-2022-7451")) diff --git a/R/jsdm_stancode.R b/R/jsdm_stancode.R index 3a02b5a..0fc41dc 100644 --- a/R/jsdm_stancode.R +++ b/R/jsdm_stancode.R @@ -13,8 +13,10 @@ #' can be modified using the prior object. #' #' @param method The method, one of \code{"gllvm"} or \code{"mglmm"} -#' @param family The family, one of "\code{"gaussian"}, \code{"bernoulli"}, -#' \code{"poisson"} or \code{"neg_binomial"} +#' @param family is the response family, must be one of \code{"gaussian"}, +#' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +#' \code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +#' matching is supported. #' @param prior The prior, given as the result of a call to [jsdm_prior()] #' @param log_lik Whether the log likelihood should be calculated in the generated #' quantities (by default \code{TRUE}), required for loo @@ -37,7 +39,7 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), beta_param = "cor") { # checks family <- match.arg(family, c("gaussian", "bernoulli", "poisson", - "neg_binomial","binomial")) + "neg_binomial","binomial","zero_inflated_poisson")) method <- match.arg(method, c("gllvm", "mglmm")) beta_param <- match.arg(beta_param, c("cor","unstruct")) site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped")) @@ -81,6 +83,7 @@ ifelse(site_intercept == "grouped", "bernoulli" = "int", "neg_binomial" = "int", "poisson" = "int", + "zero_inflated_poisson" = "int", "binomial" = "int" ), "Y[N,S]; //Species matrix", ifelse(family == "binomial", @@ -131,7 +134,9 @@ ifelse(site_intercept == "grouped", "bernoulli" = "", "neg_binomial" = " real kappa[S]; // neg_binomial parameters", - "poisson" = "" + "poisson" = "", + "zero_inflated_poisson" = " + real zi[S]; // zero-inflation parameter" ) pars <- paste( @@ -263,10 +268,14 @@ ifelse(site_intercept == "grouped", "), "bern" = "", "poisson" = "", - "binomial" = "" + "binomial" = "", + "zero_inflated_poisson" = paste(" + //zero-inflation parameter + zi ~ ", prior[["zi"]], "; +") ) ) - model_pt2 <- paste( + model_pt2 <- if(family != "zero_inflated_poisson"){ paste( " for(i in 1:N) Y[i,] ~ ", switch(family, @@ -276,7 +285,21 @@ ifelse(site_intercept == "grouped", "poisson" = "poisson_log(mu[i,]);", "binomial" = "binomial_logit(Ntrials[i], mu[i,]);" ) - ) + )} else{" + for(n in 1:N){ + for(s in 1:S){ + if (Y[n,s] == 0){ + target += log_sum_exp(bernoulli_lpmf(1 | zi[s]), + bernoulli_lpmf(0 |zi[s]) + + poisson_log_lpmf(Y[n,s] | mu[n,s])); + } else { + target += bernoulli_lpmf(0 | zi[s]) + + poisson_log_lpmf(Y[n,s] | mu[n,s]); + } + } + } +" + } generated_quantities <- paste( ifelse(isTRUE(log_lik), " @@ -322,14 +345,21 @@ ifelse(site_intercept == "grouped", ), "; "))," for(i in 1:N) { - for(j in 1:S) { - log_lik[i, j] = ", + for(j in 1:S) {", switch(family, - "gaussian" = "normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);", - "bernoulli" = "bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);", - "neg_binomial" = "neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);", - "poisson" = "poisson_log_lpmf(Y[i, j] | linpred[i, j]);", - "binomial" = "binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);" + "gaussian" = "log_lik[i, j] = normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);", + "bernoulli" = "log_lik[i, j] = bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);", + "neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);", + "poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);", + "binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);", + "zero_inflated_poisson" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), + bernoulli_lpmf(0 |zi[j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j])); + } else { + log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }" )," } } diff --git a/R/posterior_predict.R b/R/posterior_predict.R index 50b58da..082df90 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -32,7 +32,8 @@ #' list will have length equal to the number of species with each element of #' the list being a draws x sites matrix. If the list_index is \code{"sites"} the #' list will have length equal to the number of sites with each element of the -#' list being a draws x species matrix. +#' list being a draws x species matrix. Note that in the zero-inflated case this is +#' only the linear predictor of the non-zero-inflated part of the model. #' #' @seealso [posterior_predict.jsdmStanFit()] #' @@ -95,39 +96,17 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, } model_est <- extract(object, pars = model_pars) - n_iter <- dim(model_est[[1]])[1] - if (!is.null(draw_ids)) { - if (max(draw_ids) > n_iter) { - stop(paste( - "Maximum of draw_ids (", max(draw_ids), - ") is greater than number of iterations (", n_iter, ")" - )) - } + draw_id <- draw_id_check(draw_ids = draw_ids, n_iter = n_iter, ndraws = ndraws) - draw_id <- draw_ids - } else { - if (!is.null(ndraws)) { - if (n_iter < ndraws) { - warning(paste( - "There are fewer samples than ndraws specified, defaulting", - "to using all iterations" - )) - ndraws <- n_iter - } - draw_id <- sample.int(n_iter, ndraws) - model_est <- lapply(model_est, function(x) { - switch(length(dim(x)), - `1` = x[draw_id, drop = FALSE], - `2` = x[draw_id, , drop = FALSE], - `3` = x[draw_id, , , drop = FALSE] - ) - }) - } else { - draw_id <- seq_len(n_iter) - } - } + model_est <- lapply(model_est, function(x) { + switch(length(dim(x)), + `1` = x[draw_id, drop = FALSE], + `2` = x[draw_id, , drop = FALSE], + `3` = x[draw_id, , , drop = FALSE] + ) + }) model_pred_list <- lapply(seq_along(draw_id), function(d) { if (method == "gllvm") { @@ -162,7 +141,8 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, "bernoulli" = inv_logit(x), "poisson" = exp(x), "neg_binomial" = exp(x), - "binomial" = inv_logit(x) + "binomial" = inv_logit(x), + "zero_inflated_poisson" = exp(x) ) }) } @@ -208,24 +188,37 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, list_index = "draws", Ntrials = NULL, ...) { transform <- ifelse(object$family == "gaussian", FALSE, TRUE) + if (!is.null(ndraws) & !is.null(draw_ids)) { + message("Both ndraws and draw_ids have been specified, ignoring ndraws") + } + if (!is.null(draw_ids)) { + if (any(!is.wholenumber(draw_ids))) { + stop("draw_ids must be a vector of positive integers") + } + } + n_iter <- length(object$fit@stan_args)*(object$fit@stan_args[[1]]$iter - + object$fit@stan_args[[1]]$warmup) + draw_id <- draw_id_check(draw_ids = draw_ids, n_iter = n_iter, ndraws = ndraws) + + post_linpred <- posterior_linpred(object, - newdata = newdata, ndraws = ndraws, - newdata_type = newdata_type, draw_ids = draw_ids, + newdata = newdata, + newdata_type = newdata_type, draw_ids = draw_id, transform = transform, list_index = "draws" ) if (object$family == "gaussian") { - mod_sigma <- rstan::extract(object$fit, pars = "sigma", permuted = FALSE) - } - if (object$family == "neg_binomial") { - mod_kappa <- rstan::extract(object$fit, pars = "kappa", permuted = FALSE) - } - if(object$family == "binomial"){ + mod_sigma <- extract(object, pars = "sigma")[[1]][draw_id,] + } else if (object$family == "neg_binomial") { + mod_kappa <- extract(object, pars = "kappa")[[1]][draw_id,] + } else if(object$family == "binomial"){ if(is.null(newdata)) { Ntrials <- object$data_list$Ntrials } else { Ntrials <- ntrials_check(Ntrials, nrow(newdata)) } + } else if(object$family == "zero_inflated_poisson"){ + mod_zi <- extract(object, pars = "zi")[[1]][draw_id,] } n_sites <- length(object$sites) @@ -243,11 +236,13 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, } else { for(i in seq_len(nrow(x2))){ for(j in seq_len(ncol(x2))){ - x2[i,j] <- switch(object$family, - "gaussian" = stats::rnorm(1, x2[i,j], mod_sigma[j]), - "bernoulli" = stats::rbinom(1, 1, x2[i,j]), - "poisson" = stats::rpois(1, x2[i,j]), - "neg_binomial" = rgampois(1, x2[i,j], mod_kappa[j]) + x2[i,j] <- switch( + object$family, + "gaussian" = stats::rnorm(1, x2[i,j], mod_sigma[x,j]), + "bernoulli" = stats::rbinom(1, 1, x2[i,j]), + "poisson" = stats::rpois(1, x2[i,j]), + "neg_binomial" = rgampois(1, x2[i,j], mod_kappa[x,j]), + "zero_inflated_poisson" = stats::rbinom(1, 1, mod_zi[x,j])*stats::rpois(1, x2[i,j]) ) } } @@ -297,3 +292,30 @@ switch_indices <- function(res_list, list_index) { stop("List index not valid") } } + +draw_id_check <- function(draw_ids, n_iter, ndraws){ + if (!is.null(draw_ids)) { + if (max(draw_ids) > n_iter) { + stop(paste( + "Maximum of draw_ids (", max(draw_ids), + ") is greater than number of iterations (", n_iter, ")" + )) + } + + draw_id <- draw_ids + } else { + if (!is.null(ndraws)) { + if (n_iter < ndraws) { + warning(paste( + "There are fewer samples than ndraws specified, defaulting", + "to using all iterations" + )) + ndraws <- n_iter + } + draw_id <- sample.int(n_iter, ndraws) + } else { + draw_id <- seq_len(n_iter) + } + } + return(draw_id) +} diff --git a/R/pp_check.R b/R/pp_check.R index 205d474..a32fe45 100644 --- a/R/pp_check.R +++ b/R/pp_check.R @@ -357,7 +357,7 @@ multi_pp_check <- function(object, plotfun = "dens_overlay", species = NULL, post_args$object <- object post_args$list_index <- "draws" post_args$draw_ids <- draw_ids - post_args$ndraws <- ndraws + # post_args$ndraws <- ndraws post_res <- do.call(post_fun, post_args) diff --git a/R/prior.R b/R/prior.R index 3108dcf..4a49119 100644 --- a/R/prior.R +++ b/R/prior.R @@ -54,6 +54,8 @@ #' to be positive (default standard normal) #' @param kappa For negative binomial response, the negative binomial variance #' parameter. Constrained to be positive (default standard normal) +#' @param zi For zero-inflated poisson, the proportion of inflated zeros (default +#' beta distribution with both alpha and beta parameters set to 1). #' #' @return An object of class \code{"jsdmprior"} taking the form of a named list #' @export @@ -76,7 +78,8 @@ jsdm_prior <- function(sigmas_preds = "normal(0,1)", L = "normal(0,1)", sigma_L = "normal(0,1)", sigma = "normal(0,1)", - kappa = "normal(0,1)") { + kappa = "normal(0,1)", + zi = "beta(1,1)") { res <- list( sigmas_preds = sigmas_preds, z_preds = z_preds, cor_preds = cor_preds, betas = betas, @@ -84,7 +87,7 @@ jsdm_prior <- function(sigmas_preds = "normal(0,1)", sigmas_species = sigmas_species, z_species = z_species, cor_species = cor_species, LV = LV, L = L, sigma_L = sigma_L, - sigma = sigma, kappa = kappa + sigma = sigma, kappa = kappa, zi = zi ) if (!(all(sapply(res, is.character)))) { stop("All arguments must be supplied as character vectors") @@ -107,11 +110,11 @@ print.jsdmprior <- function(x, ...) { rep("site_intercept", 3), rep("mglmm", 3), rep("gllvm", 3), - "gaussian", "neg_binomial" + "gaussian", "neg_binomial","zero_inflated_poisson" ), Constraint = c( "lower=0", rep("none", 5), rep("lower=0", 2), - rep("none", 4), rep("lower=0", 3) + rep("none", 4), rep("lower=0", 3),"lower=0,upper=1" ), Prior = unlist(unname(x)) ) diff --git a/R/sim_data_funs.R b/R/sim_data_funs.R index 44b151d..9766f5f 100644 --- a/R/sim_data_funs.R +++ b/R/sim_data_funs.R @@ -37,7 +37,8 @@ #' #' @param family is the response family, must be one of \code{"gaussian"}, #' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -#' or \code{"bernoulli"}. Regular expression matching is supported. +#' \code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +#' matching is supported. #' #' @param method is the jSDM method to use, currently either \code{"gllvm"} or #' \code{"mglmm"} - see details for more information. @@ -65,7 +66,7 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m beta_param = "unstruct", prior = jsdm_prior()) { response <- match.arg(family, c("gaussian", "neg_binomial", "poisson", - "bernoulli", "binomial")) + "bernoulli", "binomial", "zero_inflated_poisson")) site_intercept <- match.arg(site_intercept, c("none","ungrouped","grouped")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) if(site_intercept == "grouped"){ @@ -109,7 +110,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "lkj_corr", "student_t", "cauchy", - "gamma" + "gamma", + "beta" ) }))) { stop(paste( @@ -132,7 +134,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "lkj_corr" = "rlkj", "student_t" = "rstudentt", "cauchy" = "rcauchy", - "gamma" = "rgamma" + "gamma" = "rgamma", + "beta" = "rbeta" ) fun_arg1 <- switch(x, "sigmas_preds" = K + 1 * species_intercept, @@ -149,7 +152,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "L" = D1 * (S - D1) + (D1 * (D1 - 1) / 2) + D1, "sigma_L" = 1, "sigma" = S, - "kappa" = S + "kappa" = S, + "zi" = S ) fun_args <- as.list(c(fun_arg1, as.numeric(unlist(y[[1]][[1]])[-1]))) @@ -296,6 +300,11 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m match.fun(prior_func[["kappa"]][[1]]), prior_func[["kappa"]][[2]] )) + } else if (response == "zero_inflated_poisson") { + zi <- do.call( + match.fun(prior_func[["zi"]][[1]]), + prior_func[["zi"]][[2]] + ) } # print(str(sigma)) @@ -316,7 +325,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "gaussian" = stats::rnorm(1, mu_ij, sigma[j]), "poisson" = stats::rpois(1, exp(mu_ij)), "bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij)), - "binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij)) + "binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij)), + "zero_inflated_poisson" = stats::rbinom(1, 1, zi[j])*stats::rpois(1, exp(mu_ij)) ) } } @@ -362,6 +372,9 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m if(response == "neg_binomial"){ pars$kappa <- kappa } + if(response == "zero_inflated_poisson"){ + pars$zi <- zi + } if (isTRUE(species_intercept)) { if (K > 0) { x <- x[, 2:ncol(x)] diff --git a/R/stan_jsdm.R b/R/stan_jsdm.R index c8c91e9..8089c14 100644 --- a/R/stan_jsdm.R +++ b/R/stan_jsdm.R @@ -26,9 +26,10 @@ #' #' @param D The number of latent variables within a GLLVM model #' -#' @param family The response family for the model, required to be one of -#' \code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -#' or \code{"neg_binomial"} +#' @param family is the response family, must be one of \code{"gaussian"}, +#' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +#' \code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +#' matching is supported. #' #' @param species_intercept Whether the model should be fit with an intercept by #' species, by default \code{TRUE} @@ -105,7 +106,7 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth beta_param = "unstruct", Ntrials = NULL, save_data = TRUE, iter = 4000, log_lik = TRUE, ...) { family <- match.arg(family, c("gaussian", "bernoulli", "poisson", - "neg_binomial","binomial")) + "neg_binomial","binomial", "zero_inflated_poisson")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) stopifnot( @@ -341,7 +342,8 @@ validate_data <- function(Y, D, X, species_intercept, ))) { stop("Y matrix is not binary") } - } else if (family %in% c("poisson", "neg_binomial", "binomial")) { + } else if (family %in% c("poisson", "neg_binomial", "binomial", + "zero_inflated_poisson")) { if (!any(apply(data_list$Y, 1:2, is.wholenumber))) { stop("Y matrix is not composed of integers") } diff --git a/man/jsdm_prior.Rd b/man/jsdm_prior.Rd index 33d21b9..7d415b8 100644 --- a/man/jsdm_prior.Rd +++ b/man/jsdm_prior.Rd @@ -20,7 +20,8 @@ jsdm_prior( L = "normal(0,1)", sigma_L = "normal(0,1)", sigma = "normal(0,1)", - kappa = "normal(0,1)" + kappa = "normal(0,1)", + zi = "beta(1,1)" ) \method{print}{jsdmprior}(x, ...) @@ -69,6 +70,9 @@ to be positive (default standard normal)} \item{kappa}{For negative binomial response, the negative binomial variance parameter. Constrained to be positive (default standard normal)} +\item{zi}{For zero-inflated poisson, the proportion of inflated zeros (default +beta distribution with both alpha and beta parameters set to 1).} + \item{x}{Object of class \code{jsdmprior}} \item{...}{Currently unused} diff --git a/man/jsdm_sim_data.Rd b/man/jsdm_sim_data.Rd index 30fa5df..efafa0b 100644 --- a/man/jsdm_sim_data.Rd +++ b/man/jsdm_sim_data.Rd @@ -35,7 +35,8 @@ mglmm_sim_data(...) \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -or \code{"bernoulli"}. Regular expression matching is supported.} +\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +matching is supported.} \item{method}{is the jSDM method to use, currently either \code{"gllvm"} or \code{"mglmm"} - see details for more information.} diff --git a/man/jsdm_stancode.Rd b/man/jsdm_stancode.Rd index a37f910..fb0d5e2 100644 --- a/man/jsdm_stancode.Rd +++ b/man/jsdm_stancode.Rd @@ -19,8 +19,10 @@ jsdm_stancode( \arguments{ \item{method}{The method, one of \code{"gllvm"} or \code{"mglmm"}} -\item{family}{The family, one of "\code{"gaussian"}, \code{"bernoulli"}, -\code{"poisson"} or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +matching is supported.} \item{prior}{The prior, given as the result of a call to \code{\link[=jsdm_prior]{jsdm_prior()}}} diff --git a/man/posterior_linpred.jsdmStanFit.Rd b/man/posterior_linpred.jsdmStanFit.Rd index 437112b..8d1e130 100644 --- a/man/posterior_linpred.jsdmStanFit.Rd +++ b/man/posterior_linpred.jsdmStanFit.Rd @@ -46,7 +46,8 @@ the list being a site x species matrix. If the list_index is \code{"species"} th list will have length equal to the number of species with each element of the list being a draws x sites matrix. If the list_index is \code{"sites"} the list will have length equal to the number of sites with each element of the -list being a draws x species matrix. +list being a draws x species matrix. Note that in the zero-inflated case this is +only the linear predictor of the non-zero-inflated part of the model. } \description{ Extract the posterior draws of the linear predictor, possibly transformed by diff --git a/man/stan_gllvm.Rd b/man/stan_gllvm.Rd index 37ede1f..9755ab6 100644 --- a/man/stan_gllvm.Rd +++ b/man/stan_gllvm.Rd @@ -42,9 +42,10 @@ species, by default \code{TRUE}} Y, X, N, S, K, and site_intercept. See output of \code{\link[=jsdm_sim_data]{jsdm_sim_data()}} for an example of how this can be formatted.} -\item{family}{The response family for the model, required to be one of -\code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept diff --git a/man/stan_jsdm.Rd b/man/stan_jsdm.Rd index 478f0f1..93c3f5e 100644 --- a/man/stan_jsdm.Rd +++ b/man/stan_jsdm.Rd @@ -47,9 +47,10 @@ species, by default \code{TRUE}} Y, X, N, S, K, and site_intercept. See output of \code{\link[=jsdm_sim_data]{jsdm_sim_data()}} for an example of how this can be formatted.} -\item{family}{The response family for the model, required to be one of -\code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept diff --git a/man/stan_mglmm.Rd b/man/stan_mglmm.Rd index a25bb1a..16ad027 100644 --- a/man/stan_mglmm.Rd +++ b/man/stan_mglmm.Rd @@ -39,9 +39,10 @@ species, by default \code{TRUE}} Y, X, N, S, K, and site_intercept. See output of \code{\link[=jsdm_sim_data]{jsdm_sim_data()}} for an example of how this can be formatted.} -\item{family}{The response family for the model, required to be one of -\code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept From d54ec97140b0d4f0d745058923a2458253d0a96b Mon Sep 17 00:00:00 2001 From: Fiona Seaton Date: Mon, 13 May 2024 17:15:03 +0100 Subject: [PATCH 2/5] Add zi negbin --- DESCRIPTION | 2 +- R/jsdm_stancode.R | 99 +++++++++++++++----- R/{jsdmstan-package.R => jsdmstan_PACKAGE.R} | 1 - R/posterior_predict.R | 61 +++++++----- R/pp_check.R | 9 +- R/prior.R | 7 +- R/sim_data_funs.R | 29 +++++- R/stan_jsdm.R | 28 +++++- man/jsdm_prior.Rd | 5 +- man/jsdm_sim_data.Rd | 3 +- man/jsdm_stancode.Rd | 3 +- man/jsdmstan-package.Rd | 3 +- man/posterior_predict.jsdmStanFit.Rd | 22 +++-- man/stan_gllvm.Rd | 2 +- man/stan_jsdm.Rd | 2 +- man/stan_mglmm.Rd | 2 +- 16 files changed, 197 insertions(+), 81 deletions(-) rename R/{jsdmstan-package.R => jsdmstan_PACKAGE.R} (97%) diff --git a/DESCRIPTION b/DESCRIPTION index 7d76321..d272ba4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -12,7 +12,7 @@ License: GPL (>= 3) Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 Biarch: true Depends: R (>= 3.4.0) diff --git a/R/jsdm_stancode.R b/R/jsdm_stancode.R index 0fc41dc..6aeed29 100644 --- a/R/jsdm_stancode.R +++ b/R/jsdm_stancode.R @@ -15,7 +15,8 @@ #' @param method The method, one of \code{"gllvm"} or \code{"mglmm"} #' @param family is the response family, must be one of \code{"gaussian"}, #' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -#' \code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +#' \code{"bernoulli"}, \code{"zi_poisson"}, or +#' \code{"zi_neg_binomial"}. Regular expression #' matching is supported. #' @param prior The prior, given as the result of a call to [jsdm_prior()] #' @param log_lik Whether the log likelihood should be calculated in the generated @@ -39,7 +40,8 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), beta_param = "cor") { # checks family <- match.arg(family, c("gaussian", "bernoulli", "poisson", - "neg_binomial","binomial","zero_inflated_poisson")) + "neg_binomial","binomial","zi_poisson", + "zi_neg_binomial")) method <- match.arg(method, c("gllvm", "mglmm")) beta_param <- match.arg(beta_param, c("cor","unstruct")) site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped")) @@ -83,13 +85,23 @@ ifelse(site_intercept == "grouped", "bernoulli" = "int", "neg_binomial" = "int", "poisson" = "int", - "zero_inflated_poisson" = "int", + "zi_poisson" = "int", + "zi_neg_binomial" = "int", "binomial" = "int" ), "Y[N,S]; //Species matrix", ifelse(family == "binomial", " - int Ntrials[N]; // Number of trials","") - ) + int Ntrials[N]; // Number of trials",""), + ifelse(grepl("zi_", family)," + int N_zero[S]; // number of zeros per species + int N_nonzero[S]; //number of nonzeros per species + int Sum_nonzero; //Total number of nonzeros across all species + int Sum_zero; //Total number of zeros across all species + int Y_nz[Sum_nonzero]; //Y values for nonzeros + int ss[Sum_nonzero]; //species index for Y_nz + int nn[Sum_nonzero]; //site index for Y_nz + int sz[Sum_zero]; //species index for Y_z + int nz[Sum_zero]; //site index for Y_z","")) transformed_data <- ifelse(method == "gllvm", " // Ensures identifiability of the model - no rotation of factors int M; @@ -135,7 +147,10 @@ ifelse(site_intercept == "grouped", "neg_binomial" = " real kappa[S]; // neg_binomial parameters", "poisson" = "", - "zero_inflated_poisson" = " + "zi_poisson" = " + real zi[S]; // zero-inflation parameter", + "zi_neg_binomial" = " + real kappa[S]; // neg_binomial parameters real zi[S]; // zero-inflation parameter" ) @@ -220,10 +235,22 @@ ifelse(site_intercept == "grouped", ") model <- paste(" matrix[N,S] mu; - ", switch(method, + ", ifelse(grepl("zi_",family)," + real mu_nz[Sum_nonzero]; + real mu_z[Sum_zero]; + int pos; + int neg;",""), + switch(method, "gllvm" = gllvm_model, "mglmm" = mglmm_model - )) + ),ifelse(grepl("zi_",family)," + for(i in 1:Sum_nonzero){ + mu_nz[i] = mu[nn[i],ss[i]]; + } + for(i in 1:Sum_zero){ + mu_z[i] = mu[nz[i],sz[i]]; + } + ","")) model_priors <- paste( ifelse(site_intercept %in% c("ungrouped","grouped"), paste(" // Site-level intercept priors @@ -269,13 +296,18 @@ ifelse(site_intercept == "grouped", "bern" = "", "poisson" = "", "binomial" = "", - "zero_inflated_poisson" = paste(" + "zi_poisson" = paste(" + //zero-inflation parameter + zi ~ ", prior[["zi"]], "; +"), +"zi_neg_binomial" = paste(" //zero-inflation parameter zi ~ ", prior[["zi"]], "; + kappa ~ ", prior[["kappa"]], "; ") ) ) - model_pt2 <- if(family != "zero_inflated_poisson"){ paste( + model_pt2 <- if(!grepl("zi_", family)){ paste( " for(i in 1:N) Y[i,] ~ ", switch(family, @@ -285,20 +317,29 @@ ifelse(site_intercept == "grouped", "poisson" = "poisson_log(mu[i,]);", "binomial" = "binomial_logit(Ntrials[i], mu[i,]);" ) - )} else{" - for(n in 1:N){ - for(s in 1:S){ - if (Y[n,s] == 0){ - target += log_sum_exp(bernoulli_lpmf(1 | zi[s]), - bernoulli_lpmf(0 |zi[s]) - + poisson_log_lpmf(Y[n,s] | mu[n,s])); - } else { - target += bernoulli_lpmf(0 | zi[s]) - + poisson_log_lpmf(Y[n,s] | mu[n,s]); - } - } - } -" + )} else{paste(" + pos = 1; + neg = 1; + for(s in 1:S){ + target + += N_zero[s] + * log_sum_exp(log(zi[s]), + log1m(zi[s]) + +", + switch(family, + "zi_poisson" = "poisson_log_lpmf(0 | segment(mu_z, neg, N_zero[s])));", + "zi_neg_binomial" = "neg_binomial_2_log_lpmf(0 | segment(mu_z, neg, N_zero[s]), kappa[s]));")," + target += N_nonzero[s] * log1m(zi[s]); + target +=", + switch(family, + "zi_poisson" = "poisson_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) | + segment(mu_nz, pos, N_nonzero[s]));", + "zi_neg_binomial" = "neg_binomial_2_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) | + segment(mu_nz, pos, N_nonzero[s]), kappa[s]);")," + pos = pos + N_nonzero[s]; + neg = neg + N_zero[s]; + } +") } generated_quantities <- paste( @@ -352,13 +393,21 @@ ifelse(site_intercept == "grouped", "neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);", "poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);", "binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);", - "zero_inflated_poisson" = "if (Y[i,j] == 0){ + "zi_poisson" = "if (Y[i,j] == 0){ log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), bernoulli_lpmf(0 |zi[j]) + poisson_log_lpmf(Y[i,j] | linpred[i,j])); } else { log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }", + "zi_neg_binomial" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), + bernoulli_lpmf(0 |zi[j]) + + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j])); + } else { + log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]); }" )," } diff --git a/R/jsdmstan-package.R b/R/jsdmstan_PACKAGE.R similarity index 97% rename from R/jsdmstan-package.R rename to R/jsdmstan_PACKAGE.R index 6d00119..fa7d9c1 100644 --- a/R/jsdmstan-package.R +++ b/R/jsdmstan_PACKAGE.R @@ -9,7 +9,6 @@ #' Summary functions are provided, as are interfaces to the \pkg{bayesplot} #' plotting functions #' -#' @docType package #' @name jsdmstan-package #' @aliases jsdmstan #' @import Rcpp diff --git a/R/posterior_predict.R b/R/posterior_predict.R index 082df90..908af1f 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -142,7 +142,8 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, "poisson" = exp(x), "neg_binomial" = exp(x), "binomial" = inv_logit(x), - "zero_inflated_poisson" = exp(x) + "zi_poisson" = exp(x), + "zi_neg_binomial" = exp(x) ) }) } @@ -157,36 +158,40 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, return(model_pred_list) } -#' Draw from the posterior predictive distribution +#'Draw from the posterior predictive distribution #' -#' Draw from the posterior predictive distribution of the outcome. +#'Draw from the posterior predictive distribution of the outcome. #' -#' @aliases posterior_predict +#'@aliases posterior_predict #' -#' @inheritParams posterior_linpred.jsdmStanFit +#'@inheritParams posterior_linpred.jsdmStanFit #' -#' @param Ntrials For the binomial distribution the number of trials, given as -#' either a single integer which is assumed to be constant across sites or as -#' a site-length vector of integers. +#'@param Ntrials For the binomial distribution the number of trials, given as either +#' a single integer which is assumed to be constant across sites or as a site-length +#' vector of integers. #' -#' @return A list of linear predictors. If list_index is \code{"draws"} (the default) -#' the list will have length equal to the number of draws with each element of -#' the list being a site x species matrix. If the list_index is \code{"species"} the -#' list will have length equal to the number of species with each element of -#' the list being a draws x sites matrix. If the list_index is \code{"sites"} the -#' list will have length equal to the number of sites with each element of -#' the list being a draws x species matrix. +#'@param include_zi For the zero-inflated poisson distribution, whether to include +#' the zero-inflation in the prediction. Defaults to \code{TRUE}. #' -#' @seealso [posterior_linpred.jsdmStanFit()] +#'@return A list of linear predictors. If list_index is \code{"draws"} (the default) +#' the list will have length equal to the number of draws with each element of the +#' list being a site x species matrix. If the list_index is \code{"species"} the +#' list will have length equal to the number of species with each element of the +#' list being a draws x sites matrix. If the list_index is \code{"sites"} the list +#' will have length equal to the number of sites with each element of the list being +#' a draws x species matrix. #' -#' @importFrom rstantools posterior_predict -#' @export posterior_predict -#' @export +#'@seealso [posterior_linpred.jsdmStanFit()] +#' +#'@importFrom rstantools posterior_predict +#'@export posterior_predict +#'@export posterior_predict.jsdmStanFit <- function(object, newdata = NULL, newdata_type = "X", ndraws = NULL, draw_ids = NULL, list_index = "draws", - Ntrials = NULL, ...) { + Ntrials = NULL, + include_zi = TRUE, ...) { transform <- ifelse(object$family == "gaussian", FALSE, TRUE) if (!is.null(ndraws) & !is.null(draw_ids)) { message("Both ndraws and draw_ids have been specified, ignoring ndraws") @@ -217,8 +222,11 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, } else { Ntrials <- ntrials_check(Ntrials, nrow(newdata)) } - } else if(object$family == "zero_inflated_poisson"){ + } else if(object$family == "zi_poisson"){ mod_zi <- extract(object, pars = "zi")[[1]][draw_id,] + } else if(object$family == "zi_neg_binomial"){ + mod_zi <- extract(object, pars = "zi")[[1]][draw_id,] + mod_kappa <- extract(object, pars = "kappa")[[1]][draw_id,] } n_sites <- length(object$sites) @@ -242,7 +250,16 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, "bernoulli" = stats::rbinom(1, 1, x2[i,j]), "poisson" = stats::rpois(1, x2[i,j]), "neg_binomial" = rgampois(1, x2[i,j], mod_kappa[x,j]), - "zero_inflated_poisson" = stats::rbinom(1, 1, mod_zi[x,j])*stats::rpois(1, x2[i,j]) + "zi_poisson" = if(isTRUE(include_zi)){ + (1-stats::rbinom(1, 1, mod_zi[x,j]))*stats::rpois(1, x2[i,j]) + } else { + stats::rpois(1, x2[i,j]) + }, + "zi_neg_binomial" = if(isTRUE(include_zi)){ + (1-stats::rbinom(1, 1, mod_zi[x,j]))*rgampois(1, x2[i,j], mod_kappa[x,j]) + } else { + rgampois(1, x2[i,j], mod_kappa[x,j]) + } ) } } diff --git a/R/pp_check.R b/R/pp_check.R index a32fe45..d4cfd4c 100644 --- a/R/pp_check.R +++ b/R/pp_check.R @@ -146,10 +146,13 @@ pp_check.jsdmStanFit <- function(object, plotfun = "dens_overlay", species = NUL # prepare plotting arguments ppc_args <- list(y = y, yrep = yrep) - for_pred <- union( - names(dots) %in% names(formals(jsdm_statsummary)), - names(dots) %in% names(formals(posterior_linpred.jsdmStanFit)) + for_pred <- names(dots) %in% union(union( + names(formals(jsdm_statsummary)), + names(formals(posterior_linpred.jsdmStanFit)) + ), + names(formals(posterior_predict.jsdmStanFit)) ) + ppc_args <- c(ppc_args, dots[!for_pred]) do.call(ppc_fun, ppc_args) diff --git a/R/prior.R b/R/prior.R index 4a49119..e1c1ce4 100644 --- a/R/prior.R +++ b/R/prior.R @@ -54,8 +54,9 @@ #' to be positive (default standard normal) #' @param kappa For negative binomial response, the negative binomial variance #' parameter. Constrained to be positive (default standard normal) -#' @param zi For zero-inflated poisson, the proportion of inflated zeros (default -#' beta distribution with both alpha and beta parameters set to 1). +#' @param zi For zero-inflated poisson or negative binomial, the proportion of +#' inflated zeros (default beta distribution with both alpha and beta parameters +#' set to 1). #' #' @return An object of class \code{"jsdmprior"} taking the form of a named list #' @export @@ -110,7 +111,7 @@ print.jsdmprior <- function(x, ...) { rep("site_intercept", 3), rep("mglmm", 3), rep("gllvm", 3), - "gaussian", "neg_binomial","zero_inflated_poisson" + "gaussian", "neg_binomial","zero_inflation" ), Constraint = c( "lower=0", rep("none", 5), rep("lower=0", 2), diff --git a/R/sim_data_funs.R b/R/sim_data_funs.R index 9766f5f..70423b6 100644 --- a/R/sim_data_funs.R +++ b/R/sim_data_funs.R @@ -37,7 +37,8 @@ #' #' @param family is the response family, must be one of \code{"gaussian"}, #' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -#' \code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +#' \code{"bernoulli"}, \code{"zi_poisson"}, or +#' \code{"zi_neg_binomial"}. Regular expression #' matching is supported. #' #' @param method is the jSDM method to use, currently either \code{"gllvm"} or @@ -66,9 +67,13 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m beta_param = "unstruct", prior = jsdm_prior()) { response <- match.arg(family, c("gaussian", "neg_binomial", "poisson", - "bernoulli", "binomial", "zero_inflated_poisson")) + "bernoulli", "binomial", "zi_poisson", + "zi_neg_binomial")) site_intercept <- match.arg(site_intercept, c("none","ungrouped","grouped")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) + if(missing(method)){ + stop("method argument needs to be specified") + } if(site_intercept == "grouped"){ stop("Grouped site intercept not supported") } @@ -300,11 +305,20 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m match.fun(prior_func[["kappa"]][[1]]), prior_func[["kappa"]][[2]] )) - } else if (response == "zero_inflated_poisson") { + } else if (response == "zi_poisson") { + zi <- do.call( + match.fun(prior_func[["zi"]][[1]]), + prior_func[["zi"]][[2]] + ) + } else if (response == "zi_neg_binomial") { zi <- do.call( match.fun(prior_func[["zi"]][[1]]), prior_func[["zi"]][[2]] ) + kappa <- abs(do.call( + match.fun(prior_func[["kappa"]][[1]]), + prior_func[["kappa"]][[2]] + )) } # print(str(sigma)) @@ -326,7 +340,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "poisson" = stats::rpois(1, exp(mu_ij)), "bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij)), "binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij)), - "zero_inflated_poisson" = stats::rbinom(1, 1, zi[j])*stats::rpois(1, exp(mu_ij)) + "zi_poisson" = (1-stats::rbinom(1, 1, zi[j]))*stats::rpois(1, exp(mu_ij)), + "zi_neg_binomial" = (1-stats::rbinom(1, 1, zi[j]))*rgampois(1, mu = exp(mu_ij), scale = kappa[j]) ) } } @@ -372,9 +387,13 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m if(response == "neg_binomial"){ pars$kappa <- kappa } - if(response == "zero_inflated_poisson"){ + if(response == "zi_poisson"){ pars$zi <- zi } + if(response == "zi_neg_binomial"){ + pars$zi <- zi + pars$kappa <- kappa + } if (isTRUE(species_intercept)) { if (K > 0) { x <- x[, 2:ncol(x)] diff --git a/R/stan_jsdm.R b/R/stan_jsdm.R index 8089c14..9a07989 100644 --- a/R/stan_jsdm.R +++ b/R/stan_jsdm.R @@ -28,7 +28,7 @@ #' #' @param family is the response family, must be one of \code{"gaussian"}, #' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -#' \code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +#' \code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression #' matching is supported. #' #' @param species_intercept Whether the model should be fit with an intercept by @@ -106,7 +106,8 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth beta_param = "unstruct", Ntrials = NULL, save_data = TRUE, iter = 4000, log_lik = TRUE, ...) { family <- match.arg(family, c("gaussian", "bernoulli", "poisson", - "neg_binomial","binomial", "zero_inflated_poisson")) + "neg_binomial","binomial", "zi_poisson", + "zi_neg_binomial")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) stopifnot( @@ -343,17 +344,38 @@ validate_data <- function(Y, D, X, species_intercept, stop("Y matrix is not binary") } } else if (family %in% c("poisson", "neg_binomial", "binomial", - "zero_inflated_poisson")) { + "zi_poisson", "zi_neg_binomial")) { if (!any(apply(data_list$Y, 1:2, is.wholenumber))) { stop("Y matrix is not composed of integers") } } + # check to make sure no completely blank columns in Y + if(any(apply(data_list$Y, 2, function(x) all(x == 0)))){ + stop("Y contains an empty column, which cannot work for this model") + } + # Check if Ntrials is appropriate given if(identical(family, "binomial")) { data_list$Ntrials <- ntrials_check(data_list$Ntrials, data_list$N) } + # create zeros/non-zeros for zero-inflated poisson + if(grepl("zi_",family)) { + if(any(apply(data_list$Y, 2, min)>0)){ + stop("Zero-inflated distributions require zeros to be present in all Y values.") + } + data_list$N_zero <- colSums(data_list$Y==0) + data_list$N_nonzero <- colSums(data_list$Y>0) + data_list$Sum_nonzero <- sum(data_list$N_nonzero) + data_list$Sum_zero <- sum(data_list$N_zero) + data_list$Y_nz <- c(as.matrix(data_list$Y))[c(as.matrix(data_list$Y))>0] + data_list$nn <- rep(1:data_list$N,data_list$S)[c(data_list$Y>0)] + data_list$ss <- rep(1:data_list$S,each=data_list$N)[c(data_list$Y>0)] + data_list$nz <- rep(1:data_list$N,data_list$S)[c(data_list$Y==0)] + data_list$sz <- rep(1:data_list$S,each=data_list$N)[c(data_list$Y==0)] + } + return(data_list) } diff --git a/man/jsdm_prior.Rd b/man/jsdm_prior.Rd index 7d415b8..d870dac 100644 --- a/man/jsdm_prior.Rd +++ b/man/jsdm_prior.Rd @@ -70,8 +70,9 @@ to be positive (default standard normal)} \item{kappa}{For negative binomial response, the negative binomial variance parameter. Constrained to be positive (default standard normal)} -\item{zi}{For zero-inflated poisson, the proportion of inflated zeros (default -beta distribution with both alpha and beta parameters set to 1).} +\item{zi}{For zero-inflated poisson or negative binomial, the proportion of +inflated zeros (default beta distribution with both alpha and beta parameters +set to 1).} \item{x}{Object of class \code{jsdmprior}} diff --git a/man/jsdm_sim_data.Rd b/man/jsdm_sim_data.Rd index efafa0b..05eaecc 100644 --- a/man/jsdm_sim_data.Rd +++ b/man/jsdm_sim_data.Rd @@ -35,7 +35,8 @@ mglmm_sim_data(...) \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +\code{"bernoulli"}, \code{"zi_poisson"}, or +\code{"zi_neg_binomial"}. Regular expression matching is supported.} \item{method}{is the jSDM method to use, currently either \code{"gllvm"} or diff --git a/man/jsdm_stancode.Rd b/man/jsdm_stancode.Rd index fb0d5e2..4800fbe 100644 --- a/man/jsdm_stancode.Rd +++ b/man/jsdm_stancode.Rd @@ -21,7 +21,8 @@ jsdm_stancode( \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +\code{"bernoulli"}, \code{"zi_poisson"}, or +\code{"zi_neg_binomial"}. Regular expression matching is supported.} \item{prior}{The prior, given as the result of a call to \code{\link[=jsdm_prior]{jsdm_prior()}}} diff --git a/man/jsdmstan-package.Rd b/man/jsdmstan-package.Rd index cc7f6bc..ef481b4 100644 --- a/man/jsdmstan-package.Rd +++ b/man/jsdmstan-package.Rd @@ -1,6 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/jsdmstan-package.R -\docType{package} +% Please edit documentation in R/jsdmstan_PACKAGE.R \name{jsdmstan-package} \alias{jsdmstan-package} \alias{jsdmstan} diff --git a/man/posterior_predict.jsdmStanFit.Rd b/man/posterior_predict.jsdmStanFit.Rd index ca501f3..59d2f27 100644 --- a/man/posterior_predict.jsdmStanFit.Rd +++ b/man/posterior_predict.jsdmStanFit.Rd @@ -13,6 +13,7 @@ draw_ids = NULL, list_index = "draws", Ntrials = NULL, + include_zi = TRUE, ... ) } @@ -33,20 +34,23 @@ number of samples.} \item{list_index}{Whether to return the output list indexed by the number of draws (default), species, or site.} -\item{Ntrials}{For the binomial distribution the number of trials, given as -either a single integer which is assumed to be constant across sites or as -a site-length vector of integers.} +\item{Ntrials}{For the binomial distribution the number of trials, given as either +a single integer which is assumed to be constant across sites or as a site-length +vector of integers.} + +\item{include_zi}{For the zero-inflated poisson distribution, whether to include +the zero-inflation in the prediction. Defaults to \code{TRUE}.} \item{...}{Currently unused} } \value{ A list of linear predictors. If list_index is \code{"draws"} (the default) -the list will have length equal to the number of draws with each element of -the list being a site x species matrix. If the list_index is \code{"species"} the -list will have length equal to the number of species with each element of -the list being a draws x sites matrix. If the list_index is \code{"sites"} the -list will have length equal to the number of sites with each element of -the list being a draws x species matrix. +the list will have length equal to the number of draws with each element of the +list being a site x species matrix. If the list_index is \code{"species"} the +list will have length equal to the number of species with each element of the +list being a draws x sites matrix. If the list_index is \code{"sites"} the list +will have length equal to the number of sites with each element of the list being +a draws x species matrix. } \description{ Draw from the posterior predictive distribution of the outcome. diff --git a/man/stan_gllvm.Rd b/man/stan_gllvm.Rd index 9755ab6..9578b01 100644 --- a/man/stan_gllvm.Rd +++ b/man/stan_gllvm.Rd @@ -44,7 +44,7 @@ example of how this can be formatted.} \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +\code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential diff --git a/man/stan_jsdm.Rd b/man/stan_jsdm.Rd index 93c3f5e..674bb3a 100644 --- a/man/stan_jsdm.Rd +++ b/man/stan_jsdm.Rd @@ -49,7 +49,7 @@ example of how this can be formatted.} \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +\code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential diff --git a/man/stan_mglmm.Rd b/man/stan_mglmm.Rd index 16ad027..05f8413 100644 --- a/man/stan_mglmm.Rd +++ b/man/stan_mglmm.Rd @@ -41,7 +41,7 @@ example of how this can be formatted.} \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -\code{"bernoulli"}, or \code{"zero_inflated_poisson"}. Regular expression +\code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential From f324279f569aabda02eb1de8f7145c0ff6e646c4 Mon Sep 17 00:00:00 2001 From: Fiona Seaton Date: Tue, 13 Aug 2024 16:34:01 +0100 Subject: [PATCH 3/5] Update tests and fix D1 issues --- DESCRIPTION | 3 +- R/posterior_predict.R | 7 +++- R/sim_data_funs.R | 13 ++++-- R/stan_jsdm.R | 7 ++++ man/stan_jsdm.Rd | 7 ++++ tests/testthat/test-posterior_predict.R | 54 ++++++++++++++++++++++++- tests/testthat/test-sim_data_funs.R | 13 +++++- 7 files changed, 97 insertions(+), 7 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index d272ba4..b8b0784 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -12,7 +12,7 @@ License: GPL (>= 3) Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Biarch: true Depends: R (>= 3.4.0) @@ -33,4 +33,5 @@ Suggests: rmarkdown, ggplot2 Config/testthat/edition: 3 +Config/testthat/parallel: true VignetteBuilder: knitr diff --git a/R/posterior_predict.R b/R/posterior_predict.R index 908af1f..2806862 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -111,7 +111,12 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, model_pred_list <- lapply(seq_along(draw_id), function(d) { if (method == "gllvm") { if (orig_data_used) { - LV_sum <- t((model_est$Lambda[d, , ] * model_est$sigma_L[d]) %*% model_est$LV[d, , ]) + if(object$n_latent>1){ + LV_sum <- t((model_est$Lambda[d, , ] * model_est$sigma_L[d]) %*% model_est$LV[d, , ]) + } else{ + LV_sum <- t((matrix(model_est$Lambda[d, , ], ncol = 1) * model_est$sigma_L[d]) + %*% matrix(model_est$LV[d, , ], nrow = 1)) + } } else { LV_sum <- 0 } diff --git a/R/sim_data_funs.R b/R/sim_data_funs.R index 70423b6..bb20fe3 100644 --- a/R/sim_data_funs.R +++ b/R/sim_data_funs.R @@ -269,9 +269,11 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m L <- matrix(nrow = S, ncol = D) idx2 <- 0 - for (i in 1:(D - 1)) { - for (j in (i + 1):(D)) { - L[i, j] <- 0 + if(D > 1){ + for (i in 1:(D - 1)) { + for (j in (i + 1):(D)) { + L[i, j] <- 0 + } } } for (j in 1:D) { @@ -346,6 +348,11 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m } } # print(str(Y)) + if(any(apply(Y, 2, function(x) all(x == 0)))){ + message(paste("Y contains an entirely empty column, which will not work for", + "jsdm fitting, it is recommended that the simulation is run again.")) + } + pars <- list( diff --git a/R/stan_jsdm.R b/R/stan_jsdm.R index 9a07989..f5a899d 100644 --- a/R/stan_jsdm.R +++ b/R/stan_jsdm.R @@ -11,6 +11,13 @@ #' \code{"unstruct"} parameterisation all covariate effects are assumed to draw #' from a simple distribution with no correlation structure. Both parameterisations #' can be modified using the prior object. +#' Families supported are the Gaussian family, the negative binomial family, +#' the Poisson family, the binomial family (with number of trials specificied +#' using the \code{Ntrials} parameter), the Bernoulli family (the special case +#' of the binomial family where number of trials is equal to one), the +#' zero-inflated Poisson and the zero-inflated negative binomial. For both +#' zero-inflated families the zero-inflation is assumed to be a species-specific +#' constant. #' #' @param formula The formula of covariates that the species means are modelled from #' diff --git a/man/stan_jsdm.Rd b/man/stan_jsdm.Rd index 674bb3a..7c2d468 100644 --- a/man/stan_jsdm.Rd +++ b/man/stan_jsdm.Rd @@ -101,6 +101,13 @@ to be constrained by a correlation matrix between the covariates. With the \code{"unstruct"} parameterisation all covariate effects are assumed to draw from a simple distribution with no correlation structure. Both parameterisations can be modified using the prior object. +Families supported are the Gaussian family, the negative binomial family, +the Poisson family, the binomial family (with number of trials specificied +using the \code{Ntrials} parameter), the Bernoulli family (the special case +of the binomial family where number of trials is equal to one), the +zero-inflated Poisson and the zero-inflated negative binomial. For both +zero-inflated families the zero-inflation is assumed to be a species-specific +constant. } \section{Methods (by class)}{ \itemize{ diff --git a/tests/testthat/test-posterior_predict.R b/tests/testthat/test-posterior_predict.R index 6fd3697..cb0e4fc 100644 --- a/tests/testthat/test-posterior_predict.R +++ b/tests/testthat/test-posterior_predict.R @@ -128,7 +128,7 @@ bino_sim_data <- gllvm_sim_data(N = 100, S = 9, K = 2, family = "binomial", Ntrials = 20) bino_pred_data <- matrix(rnorm(100 * 2), nrow = 100) colnames(bino_pred_data) <- c("V1", "V2") -suppressWarnings(bino_fit <- stan_mglmm( +suppressWarnings(bino_fit <- stan_gllvm( dat_list = bino_sim_data, family = "binomial", refresh = 0, chains = 2, iter = 500 )) @@ -150,3 +150,55 @@ test_that("posterior_(lin)pred works with gllvm and bino", { expect_false(any(sapply(bino_pred2, function(x) x < 0))) expect_false(any(sapply(bino_pred2, function(x) x > 16))) }) + +set.seed(86738873) +zip_sim_data <- gllvm_sim_data(N = 100, S = 7, K = 2, family = "zi_poisson", + site_intercept = "ungrouped", D = 1) +zip_pred_data <- matrix(rnorm(100 * 2), nrow = 100) +colnames(zip_pred_data) <- c("V1", "V2") +suppressWarnings(zip_fit <- stan_gllvm( + dat_list = zip_sim_data, family = "zi_poisson", + refresh = 0, chains = 2, iter = 500 +)) +test_that("posterior_(lin)pred works with gllvm and zip", { + zip_pred <- posterior_predict(zip_fit, ndraws = 100) + + expect_length(zip_pred, 100) + expect_false(any(sapply(zip_pred, anyNA))) + expect_false(any(sapply(zip_pred, function(x) x < 0))) + + zip_pred2 <- posterior_predict(zip_fit, + newdata = zip_pred_data, + ndraws = 50, list_index = "species" + ) + + expect_length(zip_pred2, 7) + expect_false(any(sapply(zip_pred2, anyNA))) + expect_false(any(sapply(zip_pred2, function(x) x < 0))) +}) + +set.seed(9598098) +zinb_sim_data <- mglmm_sim_data(N = 100, S = 7, K = 2, family = "zi_neg_binomial", + site_intercept = "ungrouped") +zinb_pred_data <- matrix(rnorm(100 * 2), nrow = 100) +colnames(zinb_pred_data) <- c("V1", "V2") +suppressWarnings(zinb_fit <- stan_mglmm( + dat_list = zinb_sim_data, family = "zi_neg_binomial", + refresh = 0, chains = 2, iter = 500 +)) +test_that("posterior_(lin)pred works with gllvm and zinb", { + zinb_pred <- posterior_predict(zinb_fit, ndraws = 100) + + expect_length(zinb_pred, 100) + expect_false(any(sapply(zinb_pred, anyNA))) + expect_false(any(sapply(zinb_pred, function(x) x < 0))) + + zinb_pred2 <- posterior_predict(zinb_fit, + newdata = zinb_pred_data, + ndraws = 50, list_index = "species" + ) + + expect_length(zinb_pred2, 7) + expect_false(any(sapply(zinb_pred2, anyNA))) + expect_false(any(sapply(zinb_pred2, function(x) x < 0))) +}) diff --git a/tests/testthat/test-sim_data_funs.R b/tests/testthat/test-sim_data_funs.R index 10793ac..2e3356f 100644 --- a/tests/testthat/test-sim_data_funs.R +++ b/tests/testthat/test-sim_data_funs.R @@ -87,6 +87,12 @@ test_that("mglmm_sim_data returns a list of correct length", { "Y", "pars", "N", "S", "D", "K", "X", "Ntrials" )) expect_length(gllvm_sim$Ntrials, 100) + gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "zi_neg_binomial", method = "gllvm", + Ntrials = 19) + expect_named(gllvm_sim, c( + "Y", "pars", "N", "S", "D", "K", "X" + )) + expect_equal(dim(gllvm_sim$Y),c(100,12)) }) test_that("jsdm_sim_data returns all appropriate pars", { @@ -103,7 +109,12 @@ test_that("jsdm_sim_data returns all appropriate pars", { "betas","a_bar","sigma_a","a","L","LV","sigma_L","kappa" )) - + gllvm_sim2 <- jsdm_sim_data(100,12,D=2,family = "zi_poisson", method = "gllvm", + beta_param = "unstruct", + site_intercept = "ungrouped") + expect_named(gllvm_sim2$pars, c( + "betas","a_bar","sigma_a","a","L","LV","sigma_L","zi" + )) }) From 5159b7059fa504c7eec9c82eb07bda34417fa29b Mon Sep 17 00:00:00 2001 From: Fiona Seaton Date: Wed, 21 Aug 2024 09:53:59 +0100 Subject: [PATCH 4/5] Update zi to covariate response Also creation of jsdmStanFamily class to support this, and associated changes to accessory functions and documentation --- DESCRIPTION | 1 - NAMESPACE | 2 + R/jsdm_stancode.R | 104 +++++++++++--- R/jsdmstan-families.R | 66 +++++++++ R/jsdmstanfit-class.R | 5 +- R/loo.R | 3 +- R/posterior_predict.R | 182 +++++++++++++++++++----- R/prior.R | 16 ++- R/sim_data_funs.R | 100 ++++++++++--- R/stan_jsdm.R | 73 +++++++++- R/update.R | 24 +++- man/jsdmStanFamily.Rd | 42 ++++++ man/jsdmStanFit.Rd | 2 +- man/jsdm_prior.Rd | 10 +- man/jsdm_sim_data.Rd | 13 ++ man/jsdm_stancode.Rd | 7 +- man/loo.jsdmStanFit.Rd | 3 +- man/posterior_linpred.jsdmStanFit.Rd | 4 - man/posterior_predict.jsdmStanFit.Rd | 4 - man/posterior_zipred.Rd | 52 +++++++ man/print.jsdmStanFamily.Rd | 16 +++ man/stan_jsdm.Rd | 11 ++ man/update.jsdmStanFit.Rd | 6 + tests/testthat/test-posterior_predict.R | 13 +- tests/testthat/test-sim_data_funs.R | 20 ++- tests/testthat/test-update.R | 24 ++++ 26 files changed, 682 insertions(+), 121 deletions(-) create mode 100644 R/jsdmstan-families.R create mode 100644 man/jsdmStanFamily.Rd create mode 100644 man/posterior_zipred.Rd create mode 100644 man/print.jsdmStanFamily.Rd diff --git a/DESCRIPTION b/DESCRIPTION index b8b0784..04b159b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -33,5 +33,4 @@ Suggests: rmarkdown, ggplot2 Config/testthat/edition: 3 -Config/testthat/parallel: true VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index 2e7b848..8c528db 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,7 @@ S3method(plot,jsdmStanFit) S3method(posterior_linpred,jsdmStanFit) S3method(posterior_predict,jsdmStanFit) S3method(pp_check,jsdmStanFit) +S3method(print,jsdmStanFamily) S3method(print,jsdmStanFit) S3method(print,jsdmprior) S3method(print,jsdmstan_model) @@ -40,6 +41,7 @@ export(nuts_params) export(ordiplot) export(posterior_linpred) export(posterior_predict) +export(posterior_zipred) export(pp_check) export(rgampois) export(rhat) diff --git a/R/jsdm_stancode.R b/R/jsdm_stancode.R index 6aeed29..a4efe53 100644 --- a/R/jsdm_stancode.R +++ b/R/jsdm_stancode.R @@ -27,6 +27,9 @@ #' grouping) #' @param beta_param The parameterisation of the environmental covariate effects, by #' default \code{"cor"}. See details for further information. +#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter +#' is a species-specific constant (default, \code{"constant"}), or varies by +#' environmental covariates (\code{"covariate"}). #' #' @return A character vector of Stan code, class "jsdmstan_model" #' @export @@ -37,7 +40,7 @@ #' jsdm_stancode <- function(method, family, prior = jsdm_prior(), log_lik = TRUE, site_intercept = "none", - beta_param = "cor") { + beta_param = "cor", zi_param = "constant") { # checks family <- match.arg(family, c("gaussian", "bernoulli", "poisson", "neg_binomial","binomial","zi_poisson", @@ -45,6 +48,7 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), method <- match.arg(method, c("gllvm", "mglmm")) beta_param <- match.arg(beta_param, c("cor","unstruct")) site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped")) + zi_param <- match.arg(zi_param, c("constant","covariate")) if (class(prior)[1] != "jsdmprior") { stop("Prior must be given as a jsdmprior object") } @@ -53,7 +57,7 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), scode <- .modelcode( method = method, family = family, phylo = FALSE, prior = prior, log_lik = log_lik, site_intercept = site_intercept, - beta_param = beta_param + beta_param = beta_param, zi_param = zi_param ) class(scode) <- c("jsdmstan_model", "character") return(scode) @@ -61,7 +65,7 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), .modelcode <- function(method, family, phylo, prior, log_lik, site_intercept, - beta_param) { + beta_param, zi_param) { model_functions <- " " data <- paste( @@ -101,7 +105,13 @@ ifelse(site_intercept == "grouped", int ss[Sum_nonzero]; //species index for Y_nz int nn[Sum_nonzero]; //site index for Y_nz int sz[Sum_zero]; //species index for Y_z - int nz[Sum_zero]; //site index for Y_z","")) + int nz[Sum_zero]; //site index for Y_z",""), +ifelse(grepl("zi_", family) & zi_param == "covariate"," + int zi_k; //number of covariates for env effects on zi + matrix[N, zi_k] zi_X; //environmental covariate matrix for zi","") +) + + transformed_data <- ifelse(method == "gllvm", " // Ensures identifiability of the model - no rotation of factors int M; @@ -147,11 +157,18 @@ ifelse(site_intercept == "grouped", "neg_binomial" = " real kappa[S]; // neg_binomial parameters", "poisson" = "", - "zi_poisson" = " + "zi_poisson" = switch(zi_param, + "constant" = " + real zi[S]; // zero-inflation parameter", + "covariate" = " + matrix[zi_k,S] zi_betas; //environmental effects for zi"), + "zi_neg_binomial" = switch(zi_param, + "constant" = " + real kappa[S]; // neg_binomial parameters real zi[S]; // zero-inflation parameter", - "zi_neg_binomial" = " + "covariate" = " real kappa[S]; // neg_binomial parameters - real zi[S]; // zero-inflation parameter" + matrix[zi_k,S] zi_betas; //environmental effects for zi") ) pars <- paste( @@ -235,22 +252,32 @@ ifelse(site_intercept == "grouped", ") model <- paste(" matrix[N,S] mu; - ", ifelse(grepl("zi_",family)," + ", ifelse(grepl("zi_",family),paste0(" real mu_nz[Sum_nonzero]; real mu_z[Sum_zero]; int pos; - int neg;",""), + int neg;",switch(zi_param,"constant" = "", + "covariate" = " + real zi_nz[Sum_nonzero]; + real zi_z[Sum_zero];")),""), switch(method, "gllvm" = gllvm_model, "mglmm" = mglmm_model - ),ifelse(grepl("zi_",family)," + ),ifelse(grepl("zi_",family),paste0(ifelse(zi_param == "covariate", " + matrix[N,S] zi = zi_X * zi_betas;","")," for(i in 1:Sum_nonzero){ - mu_nz[i] = mu[nn[i],ss[i]]; + mu_nz[i] = mu[nn[i],ss[i]];", + switch(zi_param, "constant" = "", + "covariate" = " + zi_nz[i] = zi[nn[i],ss[i]];")," } for(i in 1:Sum_zero){ - mu_z[i] = mu[nz[i],sz[i]]; + mu_z[i] = mu[nz[i],sz[i]];", + switch(zi_param, "constant" = "", + "covariate" = " + zi_z[i] = zi[nz[i],sz[i]];")," } - ","")) + "),"")) model_priors <- paste( ifelse(site_intercept %in% c("ungrouped","grouped"), paste(" // Site-level intercept priors @@ -296,17 +323,24 @@ ifelse(site_intercept == "grouped", "bern" = "", "poisson" = "", "binomial" = "", - "zi_poisson" = paste(" + "zi_poisson" = switch(zi_param,"constant" = paste(" //zero-inflation parameter zi ~ ", prior[["zi"]], "; -"), -"zi_neg_binomial" = paste(" +"), "covariate" = paste(" + //zero-inflation parameter + to_vector(zi_betas) ~ ", prior[["zi_betas"]], "; +")), +"zi_neg_binomial" = switch(zi_param, "constant" = paste(" //zero-inflation parameter zi ~ ", prior[["zi"]], "; kappa ~ ", prior[["kappa"]], "; +"), "covariate" = paste(" + //zero-inflation parameter + to_vector(zi_betas) ~ ", prior[["zi_betas"]], "; + kappa ~ ", prior[["kappa"]], "; ") ) - ) + )) model_pt2 <- if(!grepl("zi_", family)){ paste( " for(i in 1:N) Y[i,] ~ ", @@ -323,13 +357,19 @@ ifelse(site_intercept == "grouped", for(s in 1:S){ target += N_zero[s] - * log_sum_exp(log(zi[s]), + * log_sum_exp(", + switch(zi_param,"constant" = "log(zi[s]), log1m(zi[s]) +", + "covariate" = "bernoulli_logit_lpmf(1 | segment(zi_z, neg, N_zero[s])), + bernoulli_logit_lpmf(0 | segment(zi_z, neg, N_zero[s])) + +"), switch(family, "zi_poisson" = "poisson_log_lpmf(0 | segment(mu_z, neg, N_zero[s])));", "zi_neg_binomial" = "neg_binomial_2_log_lpmf(0 | segment(mu_z, neg, N_zero[s]), kappa[s]));")," - target += N_nonzero[s] * log1m(zi[s]); + target += N_nonzero[s] * ",switch(zi_param, + "constant" = "log1m(zi[s]);", + "covariate" = "bernoulli_logit_lpmf(0 | segment(zi_nz, pos, N_nonzero[s]));")," target +=", switch(family, "zi_poisson" = "poisson_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) | @@ -362,7 +402,9 @@ ifelse(site_intercept == "grouped", }", ""), ifelse(isTRUE(log_lik), paste( " { - matrix[N, S] linpred;", switch(site_intercept, "ungrouped" = paste(" + matrix[N, S] linpred;",ifelse(grepl("zi", family) & zi_param == "covariate"," + matrix[N,S] zi = zi_X * zi_betas;",""), + switch(site_intercept, "ungrouped" = paste(" linpred = rep_matrix(a_bar + a * sigma_a, S) + (X * betas) +", switch(method, "gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'", @@ -393,7 +435,7 @@ ifelse(site_intercept == "grouped", "neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);", "poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);", "binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);", - "zi_poisson" = "if (Y[i,j] == 0){ + "zi_poisson" = switch(zi_param,"constant" = "if (Y[i,j] == 0){ log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), bernoulli_lpmf(0 |zi[j]) + poisson_log_lpmf(Y[i,j] | linpred[i,j])); @@ -401,14 +443,30 @@ ifelse(site_intercept == "grouped", log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + poisson_log_lpmf(Y[i,j] | linpred[i,j]); }", - "zi_neg_binomial" = "if (Y[i,j] == 0){ + "covariate" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]), + bernoulli_logit_lpmf(0 |zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j])); + } else { + log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }"), + "zi_neg_binomial" = switch(zi_param,"constant" = "if (Y[i,j] == 0){ log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), bernoulli_lpmf(0 |zi[j]) + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j])); } else { log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]); - }" + }", + "covariate" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]), + bernoulli_logit_lpmf(0 |zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j])); + } else { + log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }") )," } } diff --git a/R/jsdmstan-families.R b/R/jsdmstan-families.R new file mode 100644 index 0000000..2d2f56f --- /dev/null +++ b/R/jsdmstan-families.R @@ -0,0 +1,66 @@ +#' jsdmStanFamily class +#' +#' This is the jsdmStanFamily class, which occupies a slot within any +#' jsdmStanFit object. +#' +#' @name jsdmStanFamily +#' +#' @section Elements for \code{jsdmStanFamily} objects: +#' \describe{ +#' \item{\code{family}}{ +#' A length one character vector describing family used to fit object. Options +#' are \code{"gaussian"}, \code{"poisson"}, \code{"bernoulli"}, +#' \code{"neg_binomial"}, \code{"binomial"}, \code{"zi_poisson"}, +#' \code{"zi_neg_binomial"}, or \code{"multiple"}. +#' } +#' \item{\code{params}}{ +#' A character vector that includes all the names of the family-specific parameters. +#' } +#' \item{\code{params_dataresp}}{ +#' A character vector that includes any named family-specific parameters that are +#' modelled in response to data. +#' } +#' \item{\code{preds}}{ +#' A character vector of the measured predictors included if family parameters +#' are modelled in response to data. If family parameters are not modelled in +#' response to data this is left empty. +#' } +#' \item{\code{data_list}}{ +#' A list containing the original data used to fit the model +#' (empty when save_data is set to \code{FALSE} or family parameters are not +#' modelled in response to data). +#' } +#' } +#' +jsdmStanFamily_empty <- function(){ + res <- list(family = character(), + params = character(), + params_dataresp= character(), + preds = character(), + data_list = list()) + class(res) <- "jsdmStanFamily" + return(res) +} + +# jsdmStanFamily methods + +#' Print jsdmStanFamily object +#' +#' @param x A jsdmStanFamily object +#' @param ... Other arguments, not used at this stage. +#' +#' @export +print.jsdmStanFamily <- function(x, ...){ + cat(paste("Family:", x$family, "\n", + ifelse(length(x$params)>0, + paste("With parameters:", + paste0(x$params, sep = ", "),"\n"), + ""))) + if(length(x$params_dataresp)>0){ + cat(paste("Family-specific parameter", + paste0(x$params_dataresp,sep=", "), + "is modelled in response to", length(x$preds), + "predictors. These are named:", + paste0(x$preds, sep = ", "))) + } +} diff --git a/R/jsdmstanfit-class.R b/R/jsdmstanfit-class.R index fa0a4be..d364211 100644 --- a/R/jsdmstanfit-class.R +++ b/R/jsdmstanfit-class.R @@ -10,7 +10,7 @@ #' A length one character vector describing type of jSDM #' } #' \item{\code{family}}{ -#' A character vector describing response family +#' A jsdmStanFamily object describing characteristics of family #' } #' \item{\code{species}}{ #' A character vector of the species names @@ -35,7 +35,7 @@ jsdmStanFit_empty <- function() { res <- list( jsdm_type = "None", - family = character(), + family = jsdmStanFamily_empty(), species = character(), sites = character(), preds = character(), @@ -77,6 +77,7 @@ print.jsdmStanFit <- function(x, ...) { " Number of species: ", length(x$species), "\n", " Number of sites: ", length(x$sites), "\n", " Number of predictors: ", length(x$preds), "\n", + print(x$family), "\n", "Model run on ", length(x$fit@stan_args), " chains with ", x$fit@stan_args[[1]]$iter, " iterations per chain (", diff --git a/R/loo.R b/R/loo.R index a91b14a..754241d 100644 --- a/R/loo.R +++ b/R/loo.R @@ -2,7 +2,8 @@ #' #' This function uses the \pkg{loo} package to compute PSIS-LOO CV, efficient #' approximate leave-one-out (LOO) cross-validation for Bayesian models using Pareto -#' smoothed importance sampling (PSIS). +#' smoothed importance sampling (PSIS). This requires that the model was fit using +#' \code{log_lik = TRUE}. #' #' @param x The jsdmStanFit model object #' @param ... Other arguments passed to the \code{\link[loo]{loo}} function diff --git a/R/posterior_predict.R b/R/posterior_predict.R index 2806862..5fd6303 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -19,9 +19,6 @@ #' #' @param draw_ids The IDs of the draws to be used, as a numeric vector #' -#' @param newdata_type What form is the new data in, at the moment only -#' supplying covariates is supported. -#' #' @param list_index Whether to return the output list indexed by the number of #' draws (default), species, or site. #' @param ... Currently unused @@ -42,23 +39,15 @@ #' @export posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, newdata = NULL, ndraws = NULL, - draw_ids = NULL, newdata_type = "X", + draw_ids = NULL, list_index = "draws", ...) { - if (newdata_type != "X") { - stop("Currently only data on covariates is supported.") - } stopifnot(is.logical(transform)) - if (isTRUE(transform) & object$family == "gaussian") { + if (isTRUE(transform) & object$family$family == "gaussian") { warning("No inverse-link transform performed for Gaussian response models.") } - if (!is.null(ndraws) & !is.null(draw_ids)) { - message("Both ndraws and draw_ids have been specified, ignoring ndraws") - } - if (!is.null(draw_ids)) { - if (any(!is.wholenumber(draw_ids))) { - stop("draw_ids must be a vector of positive integers") - } - } + + foopred_checks(object = object, transform = transform, newdata = newdata, + ndraws = ndraws, draw_ids = draw_ids, list_index = list_index) list_index <- match.arg(list_index, c("draws", "species", "sites")) @@ -69,19 +58,12 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, method <- object$jsdm_type orig_data_used <- is.null(newdata) - if (is.null(newdata) & length(object$data_list) == 0) { - stop(paste( - "Original data must be included in model object if no new data", - "is provided." - )) - } if (is.null(newdata)) { newdata <- object$data_list$X } newdata <- validate_newdata(newdata, - preds = object$preds, - newdata_type = newdata_type + preds = object$preds ) model_pars <- "betas" @@ -141,7 +123,7 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, ) if (isTRUE(transform)) { mu <- apply(mu, 1:2, function(x) { - switch(object$family, + switch(object$family$family, "gaussian" = x, "bernoulli" = inv_logit(x), "poisson" = exp(x), @@ -192,12 +174,12 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, #'@export posterior_predict #'@export posterior_predict.jsdmStanFit <- function(object, newdata = NULL, - newdata_type = "X", ndraws = NULL, + ndraws = NULL, draw_ids = NULL, list_index = "draws", Ntrials = NULL, include_zi = TRUE, ...) { - transform <- ifelse(object$family == "gaussian", FALSE, TRUE) + transform <- ifelse(object$family$family == "gaussian", FALSE, TRUE) if (!is.null(ndraws) & !is.null(draw_ids)) { message("Both ndraws and draw_ids have been specified, ignoring ndraws") } @@ -213,32 +195,39 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, post_linpred <- posterior_linpred(object, newdata = newdata, - newdata_type = newdata_type, draw_ids = draw_id, + draw_ids = draw_id, transform = transform, list_index = "draws" ) - if (object$family == "gaussian") { + if (object$family$family == "gaussian") { mod_sigma <- extract(object, pars = "sigma")[[1]][draw_id,] - } else if (object$family == "neg_binomial") { - mod_kappa <- extract(object, pars = "kappa")[[1]][draw_id,] - } else if(object$family == "binomial"){ + } else if(object$family$family == "binomial"){ if(is.null(newdata)) { Ntrials <- object$data_list$Ntrials } else { Ntrials <- ntrials_check(Ntrials, nrow(newdata)) } - } else if(object$family == "zi_poisson"){ - mod_zi <- extract(object, pars = "zi")[[1]][draw_id,] - } else if(object$family == "zi_neg_binomial"){ + } + if(grepl("zi",object$family$family) & !("zi" %in% object$family$params_dataresp) & isTRUE(include_zi)){ mod_zi <- extract(object, pars = "zi")[[1]][draw_id,] + } + if(grepl("neg_binomial",object$family$family)) { mod_kappa <- extract(object, pars = "kappa")[[1]][draw_id,] } + if("zi" %in% object$family$params_dataresp & isTRUE(include_zi)){ + post_zipred <- posterior_zipred(object, + newdata = newdata, + draw_ids = draw_id, + transform = transform, list_index = "draws" + ) + } + n_sites <- length(object$sites) n_species <- length(object$species) post_pred <- lapply(seq_along(post_linpred), - function(x, family = object$family) { + function(x, family = object$family$family) { x2 <- post_linpred[[x]] if(family == "binomial"){ for(i in 1:nrow(x2)){ @@ -246,11 +235,23 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, x2[i,j] <- stats::rbinom(1, Ntrials[i], x2[i,j]) } } + } else if (grepl("zi_",family) & "zi" %in% object$family$params_dataresp & + isTRUE(include_zi)){ + zi2 <- post_zipred[[x]] + for(i in seq_len(nrow(x2))){ + for(j in seq_len(ncol(x2))){ + x2[i,j] <- switch( + object$family$family, + "zi_poisson" = (1-stats::rbinom(1, 1, zi2[x,j]))*stats::rpois(1, x2[i,j]), + "zi_neg_binomial" = (1-stats::rbinom(1, 1, zi2[x,j]))*rgampois(1, x2[i,j], mod_kappa[x,j]) + ) + } + } } else { for(i in seq_len(nrow(x2))){ for(j in seq_len(ncol(x2))){ x2[i,j] <- switch( - object$family, + object$family$family, "gaussian" = stats::rnorm(1, x2[i,j], mod_sigma[x,j]), "bernoulli" = stats::rbinom(1, 1, x2[i,j]), "poisson" = stats::rpois(1, x2[i,j]), @@ -280,10 +281,93 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, } +#' Access the posterior distribution of the linear predictor for zero-inflation +#' parameter +#' +#' Extract the posterior draws of the linear predictor for the zero-inflation +#' parameter, possibly transformed by the inverse-link function. +#' +#' +#' @inheritParams posterior_linpred.jsdmStanFit +#' +#' @return A list of linear predictors. If list_index is \code{"draws"} (the +#' default) the list will have length equal to the number of draws with each +#' element of the list being a site x species matrix. If the list_index is +#' \code{"species"} the list will have length equal to the number of species +#' with each element of the list being a draws x sites matrix. If the +#' list_index is \code{"sites"} the list will have length equal to the number +#' of sites with each element of the list being a draws x species matrix. Note +#' that in the zero-inflated case this is only the linear predictor of the +#' non-zero-inflated part of the model. +#' +#' @seealso [posterior_predict.jsdmStanFit()] +#' +#' @export +posterior_zipred <- function(object, transform = FALSE, + newdata = NULL, ndraws = NULL, + draw_ids = NULL, + list_index = "draws"){ + if(!("zi" %in% object$family$params_dataresp)){ + stop(paste("This function only works upon models with a zero-inflated family", + "and where the zero-inflation is responsive to covariates.")) + } + foopred_checks(object = object, transform = transform, newdata = newdata, + ndraws = ndraws, draw_ids = draw_ids, list_index = list_index) + list_index <- match.arg(list_index, c("draws", "species", "sites")) + + n_sites <- length(object$sites) + n_species <- length(object$species) + n_preds <- length(object$family$preds) + method <- object$jsdm_type + orig_data_used <- is.null(newdata) + + if (is.null(newdata)) { + newdata <- object$family$data_list$zi_X + } + + newdata <- validate_newdata(newdata, + preds = object$family$preds + ) + + model_pars <- "zi_betas" + + model_est <- extract(object, pars = model_pars) + n_iter <- dim(model_est[[1]])[1] + + draw_id <- draw_id_check(draw_ids = draw_ids, n_iter = n_iter, ndraws = ndraws) + + model_est <- lapply(model_est, function(x) { + switch(length(dim(x)), + `1` = x[draw_id, drop = FALSE], + `2` = x[draw_id, , drop = FALSE], + `3` = x[draw_id, , , drop = FALSE] + ) + }) + + model_pred_list <- lapply(seq_along(draw_id), function(d) { + + if (is.vector(newdata)) { + newdata <- matrix(newdata, ncol = 1) + } + zi <- newdata %*% model_est$zi_betas[d, , ] + + if (isTRUE(transform)) { + zi <- apply(zi, c(1,2), inv_logit) + } + + return(zi) + }) + + if (list_index != "draws") { + model_pred_list <- switch_indices(model_pred_list, list_index) + } + + return(model_pred_list) +} # internal ~~~~ -validate_newdata <- function(newdata, preds, newdata_type) { +validate_newdata <- function(newdata, preds) { preds_nointercept <- preds[preds != "(Intercept)"] if (!all(preds_nointercept %in% colnames(newdata))) { @@ -295,7 +379,7 @@ validate_newdata <- function(newdata, preds, newdata_type) { newdata <- newdata[, preds_nointercept] if ("(Intercept)" %in% preds) { - newdata <- cbind(`(Intercept)` = 1, newdata) + newdata <- cbind("(Intercept)" = 1, newdata) newdata <- newdata[, preds] } return(newdata) @@ -341,3 +425,23 @@ draw_id_check <- function(draw_ids, n_iter, ndraws){ } return(draw_id) } + +foopred_checks <- function(object, transform, draw_ids, ndraws, list_index, + newdata){ + stopifnot(is.logical(transform)) + if (!is.null(ndraws) & !is.null(draw_ids)) { + message("Both ndraws and draw_ids have been specified, ignoring ndraws") + } + if (!is.null(draw_ids)) { + if (any(!is.wholenumber(draw_ids))) { + stop("draw_ids must be a vector of positive integers") + } + } + if (is.null(newdata) & length(object$data_list) == 0) { + stop(paste( + "Original data must be included in model object if no new data", + "is provided." + )) + } + return(NULL) +} diff --git a/R/prior.R b/R/prior.R index e1c1ce4..b643385 100644 --- a/R/prior.R +++ b/R/prior.R @@ -54,9 +54,13 @@ #' to be positive (default standard normal) #' @param kappa For negative binomial response, the negative binomial variance #' parameter. Constrained to be positive (default standard normal) -#' @param zi For zero-inflated poisson or negative binomial, the proportion of +#' @param zi For zero-inflated poisson or negative binomial with no environmental +#' covariate effects upon the zero-inflation, the proportion of #' inflated zeros (default beta distribution with both alpha and beta parameters #' set to 1). +#' @param zi_betas For zero-inflated poisson or negative binomial with +#' environmental effects upon the zero-inflation, the covariate effects on the +#' zero-inflation on the logit scale #' #' @return An object of class \code{"jsdmprior"} taking the form of a named list #' @export @@ -80,7 +84,8 @@ jsdm_prior <- function(sigmas_preds = "normal(0,1)", sigma_L = "normal(0,1)", sigma = "normal(0,1)", kappa = "normal(0,1)", - zi = "beta(1,1)") { + zi = "beta(1,1)", + zi_betas = "normal(0,1)") { res <- list( sigmas_preds = sigmas_preds, z_preds = z_preds, cor_preds = cor_preds, betas = betas, @@ -88,7 +93,8 @@ jsdm_prior <- function(sigmas_preds = "normal(0,1)", sigmas_species = sigmas_species, z_species = z_species, cor_species = cor_species, LV = LV, L = L, sigma_L = sigma_L, - sigma = sigma, kappa = kappa, zi = zi + sigma = sigma, kappa = kappa, zi = zi, + zi_betas = zi_betas ) if (!(all(sapply(res, is.character)))) { stop("All arguments must be supplied as character vectors") @@ -111,11 +117,11 @@ print.jsdmprior <- function(x, ...) { rep("site_intercept", 3), rep("mglmm", 3), rep("gllvm", 3), - "gaussian", "neg_binomial","zero_inflation" + "gaussian", "neg_binomial",rep("zero_inflation",2) ), Constraint = c( "lower=0", rep("none", 5), rep("lower=0", 2), - rep("none", 4), rep("lower=0", 3),"lower=0,upper=1" + rep("none", 4), rep("lower=0", 3),"lower=0,upper=1","none" ), Prior = unlist(unname(x)) ) diff --git a/R/sim_data_funs.R b/R/sim_data_funs.R index bb20fe3..2b52a2d 100644 --- a/R/sim_data_funs.R +++ b/R/sim_data_funs.R @@ -59,18 +59,31 @@ #' @param beta_param The parameterisation of the environmental covariate #' effects, by default \code{"unstruct"}. See details for further information. #' +#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter +#' is a species-specific constant (default, \code{"constant"}), or varies by +#' environmental covariates (\code{"covariate"}). +#' +#' @param zi_k If \code{zi="covariate"}, the number of environmental covariates +#' that the zero-inflation parameter responds to. The default (\code{NULL}) is +#' that the zero-inflation parameter responds to exactly the same covariate matrix +#' as the mean parameter. Otherwise, a different set of random environmental +#' covariates are generated, plus an intercept (not included in zi_k) and used +#' to predict zero-inflation +#' #' @param prior Set of prior specifications from call to [jsdm_prior()] jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "mglmm"), species_intercept = TRUE, Ntrials = NULL, site_intercept = "none", beta_param = "unstruct", + zi_param = "constant", zi_k = NULL, prior = jsdm_prior()) { response <- match.arg(family, c("gaussian", "neg_binomial", "poisson", "bernoulli", "binomial", "zi_poisson", "zi_neg_binomial")) site_intercept <- match.arg(site_intercept, c("none","ungrouped","grouped")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) + zi_param <- match.arg(zi_param, c("constant","covariate")) if(missing(method)){ stop("method argument needs to be specified") } @@ -105,6 +118,18 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m Ntrials <- ntrials_check(Ntrials = Ntrials, N = N) } + if (!is.null(zi_k)) { + if(zi_k < 1 | zi_k %% 1 != 0){ + stop("zi_k must be either NULL or a positive integer") + } + } + + if(is.null(zi_k)){ + ZI_K <- K + } else{ + ZI_K <- zi_k + } + # prior object breakdown prior_split <- lapply(prior, strsplit, split = "\\(|\\)|,") if (!all(sapply(prior_split, function(x) { @@ -158,7 +183,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "sigma_L" = 1, "sigma" = S, "kappa" = S, - "zi" = S + "zi" = S, + "zi_betas" = S*(ZI_K+1) ) fun_args <- as.list(c(fun_arg1, as.numeric(unlist(y[[1]][[1]])[-1]))) @@ -184,12 +210,12 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m # now do covariates - if K = NULL then do intercept only if (K == 0) { x <- matrix(1, nrow = N, ncol = 1) - colnames(x) <- "Intercept" + colnames(x) <- "(Intercept)" J <- 1 } else if (isTRUE(species_intercept)) { x <- matrix(stats::rnorm(N * K), ncol = K, nrow = N) colnames(x) <- paste0("V", 1:K) - x <- cbind(Intercept = 1, x) + x <- cbind("(Intercept)" = 1, x) J <- K + 1 } else if (isFALSE(species_intercept)) { x <- matrix(stats::rnorm(N * K), ncol = K, nrow = N) @@ -308,15 +334,29 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m prior_func[["kappa"]][[2]] )) } else if (response == "zi_poisson") { - zi <- do.call( - match.fun(prior_func[["zi"]][[1]]), - prior_func[["zi"]][[2]] - ) + if(zi_param == "covariate"){ + zi_betas <- matrix(do.call( + match.fun(prior_func[["zi_betas"]][[1]]), + prior_func[["zi_betas"]][[2]] + ), ncol = S) + } else { + zi <- do.call( + match.fun(prior_func[["zi"]][[1]]), + prior_func[["zi"]][[2]] + ) + } } else if (response == "zi_neg_binomial") { - zi <- do.call( - match.fun(prior_func[["zi"]][[1]]), - prior_func[["zi"]][[2]] - ) + if(zi_param == "covariate"){ + zi_betas <- matrix(do.call( + match.fun(prior_func[["zi_betas"]][[1]]), + prior_func[["zi_betas"]][[2]] + ), ncol = S) + } else { + zi <- do.call( + match.fun(prior_func[["zi"]][[1]]), + prior_func[["zi"]][[2]] + ) + } kappa <- abs(do.call( match.fun(prior_func[["kappa"]][[1]]), prior_func[["kappa"]][[2]] @@ -324,6 +364,19 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m } # print(str(sigma)) + # zero-inflation in case of covariates + if(grepl("zi_", response) & zi_param == "covariate"){ + if(is.null(zi_k)){ + zi_X <- x + } else { + zi_X <- matrix(stats::rnorm(N * zi_k), ncol = zi_k, nrow = N) + colnames(zi_X) <- paste0("V", 1:zi_k) + zi_X <- cbind("(Intercept)" = 1, zi_X) + } + zi <- inv_logit(zi_X %*% zi_betas) + } + + # generate Y Y <- matrix(nrow = N, ncol = S) for (i in 1:N) { for (j in 1:S) { @@ -342,12 +395,18 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "poisson" = stats::rpois(1, exp(mu_ij)), "bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij)), "binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij)), - "zi_poisson" = (1-stats::rbinom(1, 1, zi[j]))*stats::rpois(1, exp(mu_ij)), - "zi_neg_binomial" = (1-stats::rbinom(1, 1, zi[j]))*rgampois(1, mu = exp(mu_ij), scale = kappa[j]) + "zi_poisson" = (1-stats::rbinom( + 1, 1, + ifelse(zi_param == "covariate", + zi[i,j],zi[j])))*stats::rpois(1, exp(mu_ij)), + "zi_neg_binomial" = (1-stats::rbinom( + 1, 1, + ifelse(zi_param == "covariate", + zi[i,j],zi[j])))*rgampois(1, mu = exp(mu_ij), scale = kappa[j]) ) } } - # print(str(Y)) + if(any(apply(Y, 2, function(x) all(x == 0)))){ message(paste("Y contains an entirely empty column, which will not work for", "jsdm fitting, it is recommended that the simulation is run again.")) @@ -394,11 +453,14 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m if(response == "neg_binomial"){ pars$kappa <- kappa } - if(response == "zi_poisson"){ - pars$zi <- zi + if(grepl("zi_", response)){ + if(zi_param == "constant"){ + pars$zi <- zi + } else if(zi_param == "covariate"){ + pars$zi_betas <- zi_betas + } } if(response == "zi_neg_binomial"){ - pars$zi <- zi pars$kappa <- kappa } if (isTRUE(species_intercept)) { @@ -414,6 +476,10 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m if(response == "binomial"){ output$Ntrials <- Ntrials } + if(grepl("zi_", response) & zi_param == "covariate"){ + output$zi_k <- ZI_K + 1 + output$zi_X <- zi_X + } return(output) } diff --git a/R/stan_jsdm.R b/R/stan_jsdm.R index f5a899d..9d21625 100644 --- a/R/stan_jsdm.R +++ b/R/stan_jsdm.R @@ -71,6 +71,15 @@ #' @param beta_param The parameterisation of the environmental covariate effects, by #' default \code{"unstruct"}. See details for further information. #' +#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter +#' is a species-specific constant (default, \code{"constant"}), or varies by +#' environmental covariates (\code{"covariate"}). +#' +#' @param zi_X If \code{zi = "covariate"}, the matrix of environmental predictors +#' that the zero-inflation is modelled in response to. If there is not already +#' an intercept column (identified by all values being equal to one), one will +#' be added to the front of the matrix. +#' #' @param ... Arguments passed to [rstan::sampling()] #' #' @return A \code{jsdmStanFit} object, comprising a list including the StanFit @@ -111,11 +120,22 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth dat_list = NULL, family, site_intercept = "none", D = NULL, prior = jsdm_prior(), site_groups = NULL, beta_param = "unstruct", Ntrials = NULL, + zi_param = "constant", zi_X = NULL, save_data = TRUE, iter = 4000, log_lik = TRUE, ...) { family <- match.arg(family, c("gaussian", "bernoulli", "poisson", "neg_binomial","binomial", "zi_poisson", "zi_neg_binomial")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) + zi_param <- match.arg(zi_param, c("constant","covariate")) + if(grepl("zi", family) & zi_param == "covariate"){ + if(is.null(zi_X) & is.null(dat_list)){ + message("If zi_param = 'covariate' and no zi_X matrix is supplied then the X matrix is used") + zi_X <- X + } + } else{ + zi_X <- NULL + } + stopifnot( is.logical(species_intercept), @@ -129,7 +149,8 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth Y = Y, X = X, species_intercept = species_intercept, D = D, site_intercept = site_intercept, site_groups = site_groups, dat_list = dat_list, - family = family, method = method, Ntrials = Ntrials + family = family, method = method, Ntrials = Ntrials, + zi_X = zi_X ) # Create stancode @@ -137,7 +158,7 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth family = family, method = method, prior = prior, log_lik = log_lik, site_intercept = site_intercept, - beta_param = beta_param + beta_param = beta_param, zi_param = zi_param ) # Compile model @@ -235,7 +256,8 @@ stan_gllvm.formula <- function(formula, data = list(), ...) { validate_data <- function(Y, D, X, species_intercept, dat_list, family, site_intercept, - method, site_groups, Ntrials) { + method, site_groups, Ntrials, + zi_X) { method <- match.arg(method, c("gllvm", "mglmm")) # do things if data not given as list: @@ -273,6 +295,16 @@ validate_data <- function(Y, D, X, species_intercept, colnames(X)[1] <- "(Intercept)" } } + if(grepl("zi", family) & !is.null(zi_X)){ + if(is.null(colnames(zi_X))){ + message("No column names specified for zi_X, assigning names") + colnames(zi_X) <- paste0("V",seq_len(ncol(zi_X))) + } + if(!any(apply(zi_X, 2, function(x) all(x == 1)))){ + zi_X <- cbind("(Intercept)" = 1, zi_X) + } + zi_k <- ncol(zi_X) + } if(site_intercept == "grouped"){ if(length(site_groups) != N) @@ -299,6 +331,14 @@ validate_data <- function(Y, D, X, species_intercept, if(family == "binomial"){ data_list$Ntrials <- Ntrials } + if(grepl("zi_",family) & !is.null(zi_X)){ + data_list$zi_k <- zi_k + data_list$zi_X <- zi_X + + if(nrow(zi_X) != N){ + stop("Number of rows of zi_X must be equal to number of rows of X") + } + } } else { if (!all(c("Y", "K", "S", "N", "X") %in% names(dat_list))) { stop("If supplying data as a list must have entries Y, K, S, N, X") @@ -314,6 +354,12 @@ validate_data <- function(Y, D, X, species_intercept, } } + if (grepl("zi_", family) & !is.null(zi_X)) { + if (!all(c("zi_X","zi_k") %in% names(dat_list))) { + stop("Zero-inflated models with the covariate parameterisation of zi require zi_X and zi_k in dat_list") + } + } + if (site_intercept == "grouped") { if (!all(c("ngrp","grps") %in% names(dat_list))) { stop("Grouped site intercept models require ngrp and grps in dat_list") @@ -412,11 +458,30 @@ model_to_jsdmstanfit <- function(model_fit, method, data_list, species_intercept } else { list() } + fam <- list(family = family, + params = switch(family, + "gaussian" = "sigma", + "bernoulli" = character(), + "poisson" = character(), + "neg_binomial" = "kappa", + "binomial" = character(), + "zi_poisson" = "zi", + "zi_neg_binomial" = c("kappa","zi")), + params_dataresp= character(), + preds = character(), + data_list = list()) + class(fam) <- "jsdmStanFamily" + if(isTRUE(save_data) & ("zi_X" %in% names(data_list))){ + fam$params_dataresp <- "zi" + fam$preds <- colnames(data_list$zi_X) + fam$data_list <- list(zi_X = data_list$zi_X) + } + print(str(fam$preds)) model_output <- list( fit = model_fit, jsdm_type = method, - family = family, + family = fam, species = species, sites = sites, preds = preds, diff --git a/R/update.R b/R/update.R index e735119..183e3d4 100644 --- a/R/update.R +++ b/R/update.R @@ -12,6 +12,10 @@ #' @param newD New number of latent variables, by default \code{NULL} #' @param newNtrials New number of trials (binomial model only), by default #' \code{NULL} +#' @param newZi_X New predictor data for the zi parameter in zero-inflated models, +#' by default \code{NULL}. In cases where the model was originally fit with the +#' same X and zi_X data and only newX is supplied to update.jsdmStanFit the zi_X +#' data will also be set to newX. #' @param save_data Whether to save the data in the jsdmStanFit object, by default #' \code{TRUE} #' @param ... Arguments passed to [rstan::sampling()] @@ -51,7 +55,7 @@ #' gllvm_fit2 #' } update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, - newNtrials = NULL, + newNtrials = NULL, newZi_X = NULL, save_data = TRUE, ...) { if (length(object$data_list) == 0) { stop("Update requires the original data to be saved in the model object") @@ -72,7 +76,7 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, } else { Y <- newY } - family <- object$family + family <- object$family$family method <- object$jsdm_type if(!is.null(newD)){ D <- newD @@ -86,6 +90,19 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, Ntrials <- object$data_list$Ntrials } } + if ("zi" %in% object$family$params_dataresp){ + if(is.null(newZi_X)) { + if(isTRUE(all.equal(object$data_list$X, object$family$data_list$zi_X)) & !is.null(newX)){ + zi_X <- newX + } else{ + zi_X <- object$family$data_list$zi_X + } + } else { + zi_X <- newZi_X + } + } else{ + zi_X <- NULL + } species_intercept <- "(Intercept)" %in% colnames(object$data_list$X) @@ -100,7 +117,8 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, Y = Y, X = X, species_intercept = species_intercept, D = D, site_intercept = site_intercept, site_groups = site_groups, dat_list = NULL, - family = family, method = method, Ntrials = Ntrials + family = family, method = method, Ntrials = Ntrials, + zi_X = zi_X ) # get original stan model diff --git a/man/jsdmStanFamily.Rd b/man/jsdmStanFamily.Rd new file mode 100644 index 0000000..eaf6283 --- /dev/null +++ b/man/jsdmStanFamily.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/jsdmstan-families.R +\name{jsdmStanFamily} +\alias{jsdmStanFamily} +\alias{jsdmStanFamily_empty} +\title{jsdmStanFamily class} +\usage{ +jsdmStanFamily_empty() +} +\description{ +This is the jsdmStanFamily class, which occupies a slot within any +jsdmStanFit object. +} +\section{Elements for \code{jsdmStanFamily} objects}{ + +\describe{ +\item{\code{family}}{ +A length one character vector describing family used to fit object. Options +are \code{"gaussian"}, \code{"poisson"}, \code{"bernoulli"}, +\code{"neg_binomial"}, \code{"binomial"}, \code{"zi_poisson"}, +\code{"zi_neg_binomial"}, or \code{"multiple"}. +} +\item{\code{params}}{ +A character vector that includes all the names of the family-specific parameters. +} +\item{\code{params_dataresp}}{ +A character vector that includes any named family-specific parameters that are +modelled in response to data. +} +\item{\code{preds}}{ +A character vector of the measured predictors included if family parameters +are modelled in response to data. If family parameters are not modelled in +response to data this is left empty. +} +\item{\code{data_list}}{ +A list containing the original data used to fit the model +(empty when save_data is set to \code{FALSE} or family parameters are not +modelled in response to data). +} +} +} + diff --git a/man/jsdmStanFit.Rd b/man/jsdmStanFit.Rd index f8d2eae..0c66f0e 100644 --- a/man/jsdmStanFit.Rd +++ b/man/jsdmStanFit.Rd @@ -17,7 +17,7 @@ This is the jsdmStanFit class, which stan_gllvm and stan_mglmm both create. A length one character vector describing type of jSDM } \item{\code{family}}{ -A character vector describing response family +A jsdmStanFamily object describing characteristics of family } \item{\code{species}}{ A character vector of the species names diff --git a/man/jsdm_prior.Rd b/man/jsdm_prior.Rd index d870dac..d122f96 100644 --- a/man/jsdm_prior.Rd +++ b/man/jsdm_prior.Rd @@ -21,7 +21,8 @@ jsdm_prior( sigma_L = "normal(0,1)", sigma = "normal(0,1)", kappa = "normal(0,1)", - zi = "beta(1,1)" + zi = "beta(1,1)", + zi_betas = "normal(0,1)" ) \method{print}{jsdmprior}(x, ...) @@ -70,10 +71,15 @@ to be positive (default standard normal)} \item{kappa}{For negative binomial response, the negative binomial variance parameter. Constrained to be positive (default standard normal)} -\item{zi}{For zero-inflated poisson or negative binomial, the proportion of +\item{zi}{For zero-inflated poisson or negative binomial with no environmental +covariate effects upon the zero-inflation, the proportion of inflated zeros (default beta distribution with both alpha and beta parameters set to 1).} +\item{zi_betas}{For zero-inflated poisson or negative binomial with +environmental effects upon the zero-inflation, the covariate effects on the +zero-inflation on the logit scale} + \item{x}{Object of class \code{jsdmprior}} \item{...}{Currently unused} diff --git a/man/jsdm_sim_data.Rd b/man/jsdm_sim_data.Rd index 05eaecc..4538ed2 100644 --- a/man/jsdm_sim_data.Rd +++ b/man/jsdm_sim_data.Rd @@ -17,6 +17,8 @@ jsdm_sim_data( Ntrials = NULL, site_intercept = "none", beta_param = "unstruct", + zi_param = "constant", + zi_k = NULL, prior = jsdm_prior() ) @@ -57,6 +59,17 @@ supported currently.} \item{beta_param}{The parameterisation of the environmental covariate effects, by default \code{"unstruct"}. See details for further information.} +\item{zi_param}{For the zero-inflated families, whether the zero-inflation parameter +is a species-specific constant (default, \code{"constant"}), or varies by +environmental covariates (\code{"covariate"}).} + +\item{zi_k}{If \code{zi="covariate"}, the number of environmental covariates +that the zero-inflation parameter responds to. The default (\code{NULL}) is +that the zero-inflation parameter responds to exactly the same covariate matrix +as the mean parameter. Otherwise, a different set of random environmental +covariates are generated, plus an intercept (not included in zi_k) and used +to predict zero-inflation} + \item{prior}{Set of prior specifications from call to \code{\link[=jsdm_prior]{jsdm_prior()}}} \item{...}{Arguments passed to jsdm_sim_data} diff --git a/man/jsdm_stancode.Rd b/man/jsdm_stancode.Rd index 4800fbe..4ff0eec 100644 --- a/man/jsdm_stancode.Rd +++ b/man/jsdm_stancode.Rd @@ -11,7 +11,8 @@ jsdm_stancode( prior = jsdm_prior(), log_lik = TRUE, site_intercept = "none", - beta_param = "cor" + beta_param = "cor", + zi_param = "constant" ) \method{print}{jsdmstan_model}(x, ...) @@ -38,6 +39,10 @@ grouping)} \item{beta_param}{The parameterisation of the environmental covariate effects, by default \code{"cor"}. See details for further information.} +\item{zi_param}{For the zero-inflated families, whether the zero-inflation parameter +is a species-specific constant (default, \code{"constant"}), or varies by +environmental covariates (\code{"covariate"}).} + \item{x}{The jsdm_stancode object} \item{...}{Currently unused} diff --git a/man/loo.jsdmStanFit.Rd b/man/loo.jsdmStanFit.Rd index 01b0f68..53c024b 100644 --- a/man/loo.jsdmStanFit.Rd +++ b/man/loo.jsdmStanFit.Rd @@ -19,5 +19,6 @@ A list with class \code{c("psis_loo","loo")}, as detailed in the \description{ This function uses the \pkg{loo} package to compute PSIS-LOO CV, efficient approximate leave-one-out (LOO) cross-validation for Bayesian models using Pareto -smoothed importance sampling (PSIS). +smoothed importance sampling (PSIS). This requires that the model was fit using +\code{log_lik = TRUE}. } diff --git a/man/posterior_linpred.jsdmStanFit.Rd b/man/posterior_linpred.jsdmStanFit.Rd index 8d1e130..bda65e2 100644 --- a/man/posterior_linpred.jsdmStanFit.Rd +++ b/man/posterior_linpred.jsdmStanFit.Rd @@ -11,7 +11,6 @@ newdata = NULL, ndraws = NULL, draw_ids = NULL, - newdata_type = "X", list_index = "draws", ... ) @@ -31,9 +30,6 @@ number of samples.} \item{draw_ids}{The IDs of the draws to be used, as a numeric vector} -\item{newdata_type}{What form is the new data in, at the moment only -supplying covariates is supported.} - \item{list_index}{Whether to return the output list indexed by the number of draws (default), species, or site.} diff --git a/man/posterior_predict.jsdmStanFit.Rd b/man/posterior_predict.jsdmStanFit.Rd index 59d2f27..2c5ec7f 100644 --- a/man/posterior_predict.jsdmStanFit.Rd +++ b/man/posterior_predict.jsdmStanFit.Rd @@ -8,7 +8,6 @@ \method{posterior_predict}{jsdmStanFit}( object, newdata = NULL, - newdata_type = "X", ndraws = NULL, draw_ids = NULL, list_index = "draws", @@ -22,9 +21,6 @@ \item{newdata}{New data, by default \code{NULL} and uses original data} -\item{newdata_type}{What form is the new data in, at the moment only -supplying covariates is supported.} - \item{ndraws}{Number of draws, by default the number of samples in the posterior. Will be sampled randomly from the chains if fewer than the number of samples.} diff --git a/man/posterior_zipred.Rd b/man/posterior_zipred.Rd new file mode 100644 index 0000000..a65874a --- /dev/null +++ b/man/posterior_zipred.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_predict.R +\name{posterior_zipred} +\alias{posterior_zipred} +\title{Access the posterior distribution of the linear predictor for zero-inflation +parameter} +\usage{ +posterior_zipred( + object, + transform = FALSE, + newdata = NULL, + ndraws = NULL, + draw_ids = NULL, + list_index = "draws" +) +} +\arguments{ +\item{object}{The model object} + +\item{transform}{Should the linear predictor be transformed using the +inverse-link function. The default is \code{FALSE}, in which case the +untransformed linear predictor is returned.} + +\item{newdata}{New data, by default \code{NULL} and uses original data} + +\item{ndraws}{Number of draws, by default the number of samples in the +posterior. Will be sampled randomly from the chains if fewer than the +number of samples.} + +\item{draw_ids}{The IDs of the draws to be used, as a numeric vector} + +\item{list_index}{Whether to return the output list indexed by the number of +draws (default), species, or site.} +} +\value{ +A list of linear predictors. If list_index is \code{"draws"} (the +default) the list will have length equal to the number of draws with each +element of the list being a site x species matrix. If the list_index is +\code{"species"} the list will have length equal to the number of species +with each element of the list being a draws x sites matrix. If the +list_index is \code{"sites"} the list will have length equal to the number +of sites with each element of the list being a draws x species matrix. Note +that in the zero-inflated case this is only the linear predictor of the +non-zero-inflated part of the model. +} +\description{ +Extract the posterior draws of the linear predictor for the zero-inflation +parameter, possibly transformed by the inverse-link function. +} +\seealso{ +\code{\link[=posterior_predict.jsdmStanFit]{posterior_predict.jsdmStanFit()}} +} diff --git a/man/print.jsdmStanFamily.Rd b/man/print.jsdmStanFamily.Rd new file mode 100644 index 0000000..89259eb --- /dev/null +++ b/man/print.jsdmStanFamily.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/jsdmstan-families.R +\name{print.jsdmStanFamily} +\alias{print.jsdmStanFamily} +\title{Print jsdmStanFamily object} +\usage{ +\method{print}{jsdmStanFamily}(x, ...) +} +\arguments{ +\item{x}{A jsdmStanFamily object} + +\item{...}{Other arguments, not used at this stage.} +} +\description{ +Print jsdmStanFamily object +} diff --git a/man/stan_jsdm.Rd b/man/stan_jsdm.Rd index 7c2d468..426bf35 100644 --- a/man/stan_jsdm.Rd +++ b/man/stan_jsdm.Rd @@ -21,6 +21,8 @@ stan_jsdm(X, ...) site_groups = NULL, beta_param = "unstruct", Ntrials = NULL, + zi_param = "constant", + zi_X = NULL, save_data = TRUE, iter = 4000, log_lik = TRUE, @@ -71,6 +73,15 @@ default \code{"unstruct"}. See details for further information.} either a single integer which is assumed to be constant across sites or as a site-length vector of integers.} +\item{zi_param}{For the zero-inflated families, whether the zero-inflation parameter +is a species-specific constant (default, \code{"constant"}), or varies by +environmental covariates (\code{"covariate"}).} + +\item{zi_X}{If \code{zi = "covariate"}, the matrix of environmental predictors +that the zero-inflation is modelled in response to. If there is not already +an intercept column (identified by all values being equal to one), one will +be added to the front of the matrix.} + \item{save_data}{If the data used to fit the model should be saved in the model object, by default TRUE.} diff --git a/man/update.jsdmStanFit.Rd b/man/update.jsdmStanFit.Rd index c167b5f..eeb8a56 100644 --- a/man/update.jsdmStanFit.Rd +++ b/man/update.jsdmStanFit.Rd @@ -10,6 +10,7 @@ newX = NULL, newD = NULL, newNtrials = NULL, + newZi_X = NULL, save_data = TRUE, ... ) @@ -26,6 +27,11 @@ \item{newNtrials}{New number of trials (binomial model only), by default \code{NULL}} +\item{newZi_X}{New predictor data for the zi parameter in zero-inflated models, +by default \code{NULL}. In cases where the model was originally fit with the +same X and zi_X data and only newX is supplied to update.jsdmStanFit the zi_X +data will also be set to newX.} + \item{save_data}{Whether to save the data in the jsdmStanFit object, by default \code{TRUE}} diff --git a/tests/testthat/test-posterior_predict.R b/tests/testthat/test-posterior_predict.R index cb0e4fc..bac41ec 100644 --- a/tests/testthat/test-posterior_predict.R +++ b/tests/testthat/test-posterior_predict.R @@ -8,11 +8,6 @@ suppressWarnings(bern_fit <- stan_gllvm( )) test_that("posterior linpred errors appropriately", { - expect_error( - posterior_linpred(bern_fit, newdata_type = "F"), - "Currently only data on covariates is supported." - ) - expect_error(posterior_linpred(bern_fit, list_index = "bored")) @@ -29,10 +24,6 @@ test_that("posterior linpred errors appropriately", { }) test_that("posterior predictive errors appropriately", { - expect_error( - posterior_predict(bern_fit, newdata_type = "F"), - "Currently only data on covariates is supported." - ) expect_error( posterior_predict(bern_fit, draw_ids = c(-5,-1,0)), @@ -179,11 +170,11 @@ test_that("posterior_(lin)pred works with gllvm and zip", { set.seed(9598098) zinb_sim_data <- mglmm_sim_data(N = 100, S = 7, K = 2, family = "zi_neg_binomial", - site_intercept = "ungrouped") + site_intercept = "ungrouped", zi_param = "covariate") zinb_pred_data <- matrix(rnorm(100 * 2), nrow = 100) colnames(zinb_pred_data) <- c("V1", "V2") suppressWarnings(zinb_fit <- stan_mglmm( - dat_list = zinb_sim_data, family = "zi_neg_binomial", + dat_list = zinb_sim_data, family = "zi_neg_binomial",zi_param="covariate", refresh = 0, chains = 2, iter = 500 )) test_that("posterior_(lin)pred works with gllvm and zinb", { diff --git a/tests/testthat/test-sim_data_funs.R b/tests/testthat/test-sim_data_funs.R index 2e3356f..f038477 100644 --- a/tests/testthat/test-sim_data_funs.R +++ b/tests/testthat/test-sim_data_funs.R @@ -1,3 +1,4 @@ +set.seed(23359) test_that("gllvm_sim_data errors with bad inputs", { expect_error( gllvm_sim_data( @@ -23,6 +24,12 @@ test_that("gllvm_sim_data errors with bad inputs", { "Ntrials must be a positive integer" ) + expect_error( + gllvm_sim_data(N = 200, S = 8, D = 2, family = "zi_poisson", + zi_param = "covariate", zi_k = -2), + "zi_k must be either NULL or a positive integer" + ) + }) test_that("gllvm_sim_data returns a list of correct length", { @@ -88,9 +95,9 @@ test_that("mglmm_sim_data returns a list of correct length", { )) expect_length(gllvm_sim$Ntrials, 100) gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "zi_neg_binomial", method = "gllvm", - Ntrials = 19) + zi_param = "covariate") expect_named(gllvm_sim, c( - "Y", "pars", "N", "S", "D", "K", "X" + "Y", "pars", "N", "S", "D", "K", "X", "zi_k", "zi_X" )) expect_equal(dim(gllvm_sim$Y),c(100,12)) }) @@ -116,6 +123,15 @@ test_that("jsdm_sim_data returns all appropriate pars", { "betas","a_bar","sigma_a","a","L","LV","sigma_L","zi" )) + + gllvm_sim3 <- jsdm_sim_data(100,9,K=2,D=2,family = "zi_neg_bin", method = "gllvm", + beta_param = "unstruct", + site_intercept = "ungrouped", zi_param = "covariate", + zi_k = 1) + expect_named(gllvm_sim3$pars, c( + "betas","a_bar","sigma_a","a","L","LV","sigma_L","zi_betas","kappa" + )) + }) test_that("prior specification works", { diff --git a/tests/testthat/test-update.R b/tests/testthat/test-update.R index 1bee5e1..419b689 100644 --- a/tests/testthat/test-update.R +++ b/tests/testthat/test-update.R @@ -62,3 +62,27 @@ test_that("binomial models update", { )) expect_s3_class(gllvm_fit2, "jsdmStanFit") }) + +zip_data <- gllvm_sim_data(97,9,D = 2, family = "zi_poisson", + zi_param = "covariate") + +test_that("zi models update", { + suppressWarnings(gllvm_fit <- stan_gllvm(dat_list = zip_data, + family = "zi_poisson", + refresh = 0, chains = 1, iter = 200 + )) + expect_s3_class(gllvm_fit, "jsdmStanFit") + + expect_error(gllvm_fit2 <- update(gllvm_fit, newD = 3, + newZi_X = matrix(1:74, nrow = 74), + refresh = 0, chains = 1, iter = 200 + ), "Number of rows of zi_X") + + + suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 3, + newZi_X = matrix(rnorm(97), nrow = 97), + refresh = 0, chains = 1, iter = 200 + )) + expect_equal(ncol(gllvm_fit2$data_list$zi_X),2) + expect_s3_class(gllvm_fit2, "jsdmStanFit") +}) From 807ffe0e5b60c9fbb5011a5c758d06031f3004c2 Mon Sep 17 00:00:00 2001 From: Fiona Seaton Date: Wed, 21 Aug 2024 14:47:46 +0100 Subject: [PATCH 5/5] Improve testing and README --- R/stan_jsdm.R | 1 - README.md | 12 ++++++++++++ tests/testthat/test-jsdm_stancode.R | 22 ++++++++++++++++++++++ tests/testthat/test-posterior_predict.R | 5 +++++ tests/testthat/test-stan_jsdm.R | 20 ++++++++++++++++++++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/test-jsdm_stancode.R diff --git a/R/stan_jsdm.R b/R/stan_jsdm.R index 9d21625..d32a56e 100644 --- a/R/stan_jsdm.R +++ b/R/stan_jsdm.R @@ -476,7 +476,6 @@ model_to_jsdmstanfit <- function(model_fit, method, data_list, species_intercept fam$preds <- colnames(data_list$zi_X) fam$data_list <- list(zi_X = data_list$zi_X) } - print(str(fam$preds)) model_output <- list( fit = model_fit, diff --git a/README.md b/README.md index 815876c..c76b811 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,24 @@ [![R-CMD-check](https://github.com/NERC-CEH/jsdmstan/workflows/R-CMD-check/badge.svg)](https://github.com/NERC-CEH/jsdmstan/actions) [![Codecov test coverage](https://codecov.io/gh/NERC-CEH/jsdmstan/branch/main/graph/badge.svg)](https://codecov.io/gh/NERC-CEH/jsdmstan?branch=main) +[![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) This is an R package for running joint Species Distribution Models (jSDM) in [Stan](https://mc-stan.org/). jSDMs are models where multiple response variables (i.e. species) are fit at the same time, and the covariance between these species are used to inform the model results. For a review of jSDMs see Warton et al. (2015) So many variables: joint modelling in community ecology. *TREE*, 30:766-779 DOI: [10.1016/j.tree.2015.09.007](http://doi.org/10.1016/j.tree.2015.09.007). This package can fit data to a Multivariate Generalised Linear Mixed Model (MGLMM) or a Generalised Linear Latent Variable Model (GLLVM), and also provides functionality for simulating data under these scenarios and an interface to the [bayesplot](https://mc-stan.org/bayesplot/) package for a wide variety of plotting options. +# Installation + +This package can be installed using the [remotes](https://remotes.r-lib.org/index.html) package using the following code: + +``` +# install.packages("remotes") +remotes::install_github("NERC-CEH/jsdmstan") +``` + +# Using jsdmstan + Example code: ``` diff --git a/tests/testthat/test-jsdm_stancode.R b/tests/testthat/test-jsdm_stancode.R new file mode 100644 index 0000000..39feae7 --- /dev/null +++ b/tests/testthat/test-jsdm_stancode.R @@ -0,0 +1,22 @@ +# jsdm_stancode checks +test_that("jsdm_stancode errors appropriately", { + expect_error(jsdm_stancode(method = "mglmm", family = "nothing")) + + expect_error(jsdm_stancode(method = "mglmm", family = "poisson", + prior = list("betas" = "normal(0,1)")), + "Prior must be given as a jsdmprior object") +}) + +jsdm_code <- jsdm_stancode(method = "mglmm", family = "zi_neg_binomial", + log_lik = TRUE, site_intercept = "grouped", + zi_param = "covariate") + +test_that("jsdm_stancode print works", { + expect_output(print(jsdm_code)) +}) + + +# jsdmStanFamily checks +test_that("jsdmStanFamily print works", { + expect_output(print(jsdmStanFamily_empty())) +}) diff --git a/tests/testthat/test-posterior_predict.R b/tests/testthat/test-posterior_predict.R index bac41ec..8db03c8 100644 --- a/tests/testthat/test-posterior_predict.R +++ b/tests/testthat/test-posterior_predict.R @@ -177,6 +177,11 @@ suppressWarnings(zinb_fit <- stan_mglmm( dat_list = zinb_sim_data, family = "zi_neg_binomial",zi_param="covariate", refresh = 0, chains = 2, iter = 500 )) +test_that("zi_neg_bin print works okay", { + expect_output(print(zinb_fit$family), + "is modelled in response to") +}) + test_that("posterior_(lin)pred works with gllvm and zinb", { zinb_pred <- posterior_predict(zinb_fit, ndraws = 100) diff --git a/tests/testthat/test-stan_jsdm.R b/tests/testthat/test-stan_jsdm.R index 6dfab1f..f19764f 100644 --- a/tests/testthat/test-stan_jsdm.R +++ b/tests/testthat/test-stan_jsdm.R @@ -299,3 +299,23 @@ test_that("site intercept models run", { expect_s3_class(mglmm_fit, "jsdmStanFit") }) + + +set.seed(9598098) +zip_sim_data <- gllvm_sim_data(N = 100, S = 7, D = 2, K = 2, family = "zi_poisson", + zi_param = "covariate") + +test_that("zi_poisson works okay", { + suppressWarnings(zip_fit <- stan_gllvm( + dat_list = zip_sim_data, family = "zi_poisson",zi_param="covariate", + refresh = 0, chains = 2, iter = 200 + )) + expect_s3_class(zip_fit, "jsdmStanFit") + expect_s3_class(zip_fit$family, "jsdmStanFamily") + expect_output(print(zip_fit$family), + "is modelled in response to") + expect_named(zip_fit$family, + c("family" ,"params" ,"params_dataresp","preds","data_list")) + expect_named(zip_fit$family$data_list, "zi_X") +}) +