diff --git a/DESCRIPTION b/DESCRIPTION index 1d06d51..04b159b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: jsdmstan Title: Fitting jSDMs in Stan -Version: 0.3.0 +Version: 0.3.0.9000 Authors@R: person("Fiona", "Seaton", , "fseaton@ceh.ac.uk", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-2022-7451")) @@ -12,7 +12,7 @@ License: GPL (>= 3) Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Biarch: true Depends: R (>= 3.4.0) diff --git a/NAMESPACE b/NAMESPACE index 2e7b848..8c528db 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,7 @@ S3method(plot,jsdmStanFit) S3method(posterior_linpred,jsdmStanFit) S3method(posterior_predict,jsdmStanFit) S3method(pp_check,jsdmStanFit) +S3method(print,jsdmStanFamily) S3method(print,jsdmStanFit) S3method(print,jsdmprior) S3method(print,jsdmstan_model) @@ -40,6 +41,7 @@ export(nuts_params) export(ordiplot) export(posterior_linpred) export(posterior_predict) +export(posterior_zipred) export(pp_check) export(rgampois) export(rhat) diff --git a/R/jsdm_stancode.R b/R/jsdm_stancode.R index 3a02b5a..a4efe53 100644 --- a/R/jsdm_stancode.R +++ b/R/jsdm_stancode.R @@ -13,8 +13,11 @@ #' can be modified using the prior object. #' #' @param method The method, one of \code{"gllvm"} or \code{"mglmm"} -#' @param family The family, one of "\code{"gaussian"}, \code{"bernoulli"}, -#' \code{"poisson"} or \code{"neg_binomial"} +#' @param family is the response family, must be one of \code{"gaussian"}, +#' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +#' \code{"bernoulli"}, \code{"zi_poisson"}, or +#' \code{"zi_neg_binomial"}. Regular expression +#' matching is supported. #' @param prior The prior, given as the result of a call to [jsdm_prior()] #' @param log_lik Whether the log likelihood should be calculated in the generated #' quantities (by default \code{TRUE}), required for loo @@ -24,6 +27,9 @@ #' grouping) #' @param beta_param The parameterisation of the environmental covariate effects, by #' default \code{"cor"}. See details for further information. +#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter +#' is a species-specific constant (default, \code{"constant"}), or varies by +#' environmental covariates (\code{"covariate"}). #' #' @return A character vector of Stan code, class "jsdmstan_model" #' @export @@ -34,13 +40,15 @@ #' jsdm_stancode <- function(method, family, prior = jsdm_prior(), log_lik = TRUE, site_intercept = "none", - beta_param = "cor") { + beta_param = "cor", zi_param = "constant") { # checks family <- match.arg(family, c("gaussian", "bernoulli", "poisson", - "neg_binomial","binomial")) + "neg_binomial","binomial","zi_poisson", + "zi_neg_binomial")) method <- match.arg(method, c("gllvm", "mglmm")) beta_param <- match.arg(beta_param, c("cor","unstruct")) site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped")) + zi_param <- match.arg(zi_param, c("constant","covariate")) if (class(prior)[1] != "jsdmprior") { stop("Prior must be given as a jsdmprior object") } @@ -49,7 +57,7 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), scode <- .modelcode( method = method, family = family, phylo = FALSE, prior = prior, log_lik = log_lik, site_intercept = site_intercept, - beta_param = beta_param + beta_param = beta_param, zi_param = zi_param ) class(scode) <- c("jsdmstan_model", "character") return(scode) @@ -57,7 +65,7 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(), .modelcode <- function(method, family, phylo, prior, log_lik, site_intercept, - beta_param) { + beta_param, zi_param) { model_functions <- " " data <- paste( @@ -81,12 +89,29 @@ ifelse(site_intercept == "grouped", "bernoulli" = "int", "neg_binomial" = "int", "poisson" = "int", + "zi_poisson" = "int", + "zi_neg_binomial" = "int", "binomial" = "int" ), "Y[N,S]; //Species matrix", ifelse(family == "binomial", " - int Ntrials[N]; // Number of trials","") - ) + int Ntrials[N]; // Number of trials",""), + ifelse(grepl("zi_", family)," + int N_zero[S]; // number of zeros per species + int N_nonzero[S]; //number of nonzeros per species + int Sum_nonzero; //Total number of nonzeros across all species + int Sum_zero; //Total number of zeros across all species + int Y_nz[Sum_nonzero]; //Y values for nonzeros + int ss[Sum_nonzero]; //species index for Y_nz + int nn[Sum_nonzero]; //site index for Y_nz + int sz[Sum_zero]; //species index for Y_z + int nz[Sum_zero]; //site index for Y_z",""), +ifelse(grepl("zi_", family) & zi_param == "covariate"," + int zi_k; //number of covariates for env effects on zi + matrix[N, zi_k] zi_X; //environmental covariate matrix for zi","") +) + + transformed_data <- ifelse(method == "gllvm", " // Ensures identifiability of the model - no rotation of factors int M; @@ -131,7 +156,19 @@ ifelse(site_intercept == "grouped", "bernoulli" = "", "neg_binomial" = " real kappa[S]; // neg_binomial parameters", - "poisson" = "" + "poisson" = "", + "zi_poisson" = switch(zi_param, + "constant" = " + real zi[S]; // zero-inflation parameter", + "covariate" = " + matrix[zi_k,S] zi_betas; //environmental effects for zi"), + "zi_neg_binomial" = switch(zi_param, + "constant" = " + real kappa[S]; // neg_binomial parameters + real zi[S]; // zero-inflation parameter", + "covariate" = " + real kappa[S]; // neg_binomial parameters + matrix[zi_k,S] zi_betas; //environmental effects for zi") ) pars <- paste( @@ -215,10 +252,32 @@ ifelse(site_intercept == "grouped", ") model <- paste(" matrix[N,S] mu; - ", switch(method, + ", ifelse(grepl("zi_",family),paste0(" + real mu_nz[Sum_nonzero]; + real mu_z[Sum_zero]; + int pos; + int neg;",switch(zi_param,"constant" = "", + "covariate" = " + real zi_nz[Sum_nonzero]; + real zi_z[Sum_zero];")),""), + switch(method, "gllvm" = gllvm_model, "mglmm" = mglmm_model - )) + ),ifelse(grepl("zi_",family),paste0(ifelse(zi_param == "covariate", " + matrix[N,S] zi = zi_X * zi_betas;","")," + for(i in 1:Sum_nonzero){ + mu_nz[i] = mu[nn[i],ss[i]];", + switch(zi_param, "constant" = "", + "covariate" = " + zi_nz[i] = zi[nn[i],ss[i]];")," + } + for(i in 1:Sum_zero){ + mu_z[i] = mu[nz[i],sz[i]];", + switch(zi_param, "constant" = "", + "covariate" = " + zi_z[i] = zi[nz[i],sz[i]];")," + } + "),"")) model_priors <- paste( ifelse(site_intercept %in% c("ungrouped","grouped"), paste(" // Site-level intercept priors @@ -263,10 +322,26 @@ ifelse(site_intercept == "grouped", "), "bern" = "", "poisson" = "", - "binomial" = "" + "binomial" = "", + "zi_poisson" = switch(zi_param,"constant" = paste(" + //zero-inflation parameter + zi ~ ", prior[["zi"]], "; +"), "covariate" = paste(" + //zero-inflation parameter + to_vector(zi_betas) ~ ", prior[["zi_betas"]], "; +")), +"zi_neg_binomial" = switch(zi_param, "constant" = paste(" + //zero-inflation parameter + zi ~ ", prior[["zi"]], "; + kappa ~ ", prior[["kappa"]], "; +"), "covariate" = paste(" + //zero-inflation parameter + to_vector(zi_betas) ~ ", prior[["zi_betas"]], "; + kappa ~ ", prior[["kappa"]], "; +") ) - ) - model_pt2 <- paste( + )) + model_pt2 <- if(!grepl("zi_", family)){ paste( " for(i in 1:N) Y[i,] ~ ", switch(family, @@ -276,7 +351,36 @@ ifelse(site_intercept == "grouped", "poisson" = "poisson_log(mu[i,]);", "binomial" = "binomial_logit(Ntrials[i], mu[i,]);" ) - ) + )} else{paste(" + pos = 1; + neg = 1; + for(s in 1:S){ + target + += N_zero[s] + * log_sum_exp(", + switch(zi_param,"constant" = "log(zi[s]), + log1m(zi[s]) + +", + "covariate" = "bernoulli_logit_lpmf(1 | segment(zi_z, neg, N_zero[s])), + bernoulli_logit_lpmf(0 | segment(zi_z, neg, N_zero[s])) + +"), + switch(family, + "zi_poisson" = "poisson_log_lpmf(0 | segment(mu_z, neg, N_zero[s])));", + "zi_neg_binomial" = "neg_binomial_2_log_lpmf(0 | segment(mu_z, neg, N_zero[s]), kappa[s]));")," + target += N_nonzero[s] * ",switch(zi_param, + "constant" = "log1m(zi[s]);", + "covariate" = "bernoulli_logit_lpmf(0 | segment(zi_nz, pos, N_nonzero[s]));")," + target +=", + switch(family, + "zi_poisson" = "poisson_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) | + segment(mu_nz, pos, N_nonzero[s]));", + "zi_neg_binomial" = "neg_binomial_2_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) | + segment(mu_nz, pos, N_nonzero[s]), kappa[s]);")," + pos = pos + N_nonzero[s]; + neg = neg + N_zero[s]; + } +") + } generated_quantities <- paste( ifelse(isTRUE(log_lik), " @@ -298,7 +402,9 @@ ifelse(site_intercept == "grouped", }", ""), ifelse(isTRUE(log_lik), paste( " { - matrix[N, S] linpred;", switch(site_intercept, "ungrouped" = paste(" + matrix[N, S] linpred;",ifelse(grepl("zi", family) & zi_param == "covariate"," + matrix[N,S] zi = zi_X * zi_betas;",""), + switch(site_intercept, "ungrouped" = paste(" linpred = rep_matrix(a_bar + a * sigma_a, S) + (X * betas) +", switch(method, "gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'", @@ -322,14 +428,45 @@ ifelse(site_intercept == "grouped", ), "; "))," for(i in 1:N) { - for(j in 1:S) { - log_lik[i, j] = ", + for(j in 1:S) {", switch(family, - "gaussian" = "normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);", - "bernoulli" = "bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);", - "neg_binomial" = "neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);", - "poisson" = "poisson_log_lpmf(Y[i, j] | linpred[i, j]);", - "binomial" = "binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);" + "gaussian" = "log_lik[i, j] = normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);", + "bernoulli" = "log_lik[i, j] = bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);", + "neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);", + "poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);", + "binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);", + "zi_poisson" = switch(zi_param,"constant" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), + bernoulli_lpmf(0 |zi[j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j])); + } else { + log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }", + "covariate" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]), + bernoulli_logit_lpmf(0 |zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j])); + } else { + log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }"), + "zi_neg_binomial" = switch(zi_param,"constant" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]), + bernoulli_lpmf(0 |zi[j]) + + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j])); + } else { + log_lik[i, j] = bernoulli_lpmf(0 | zi[j]) + + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]); + }", + "covariate" = "if (Y[i,j] == 0){ + log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]), + bernoulli_logit_lpmf(0 |zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j])); + } else { + log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j]) + + poisson_log_lpmf(Y[i,j] | linpred[i,j]); + }") )," } } diff --git a/R/jsdmstan-families.R b/R/jsdmstan-families.R new file mode 100644 index 0000000..2d2f56f --- /dev/null +++ b/R/jsdmstan-families.R @@ -0,0 +1,66 @@ +#' jsdmStanFamily class +#' +#' This is the jsdmStanFamily class, which occupies a slot within any +#' jsdmStanFit object. +#' +#' @name jsdmStanFamily +#' +#' @section Elements for \code{jsdmStanFamily} objects: +#' \describe{ +#' \item{\code{family}}{ +#' A length one character vector describing family used to fit object. Options +#' are \code{"gaussian"}, \code{"poisson"}, \code{"bernoulli"}, +#' \code{"neg_binomial"}, \code{"binomial"}, \code{"zi_poisson"}, +#' \code{"zi_neg_binomial"}, or \code{"multiple"}. +#' } +#' \item{\code{params}}{ +#' A character vector that includes all the names of the family-specific parameters. +#' } +#' \item{\code{params_dataresp}}{ +#' A character vector that includes any named family-specific parameters that are +#' modelled in response to data. +#' } +#' \item{\code{preds}}{ +#' A character vector of the measured predictors included if family parameters +#' are modelled in response to data. If family parameters are not modelled in +#' response to data this is left empty. +#' } +#' \item{\code{data_list}}{ +#' A list containing the original data used to fit the model +#' (empty when save_data is set to \code{FALSE} or family parameters are not +#' modelled in response to data). +#' } +#' } +#' +jsdmStanFamily_empty <- function(){ + res <- list(family = character(), + params = character(), + params_dataresp= character(), + preds = character(), + data_list = list()) + class(res) <- "jsdmStanFamily" + return(res) +} + +# jsdmStanFamily methods + +#' Print jsdmStanFamily object +#' +#' @param x A jsdmStanFamily object +#' @param ... Other arguments, not used at this stage. +#' +#' @export +print.jsdmStanFamily <- function(x, ...){ + cat(paste("Family:", x$family, "\n", + ifelse(length(x$params)>0, + paste("With parameters:", + paste0(x$params, sep = ", "),"\n"), + ""))) + if(length(x$params_dataresp)>0){ + cat(paste("Family-specific parameter", + paste0(x$params_dataresp,sep=", "), + "is modelled in response to", length(x$preds), + "predictors. These are named:", + paste0(x$preds, sep = ", "))) + } +} diff --git a/R/jsdmstanfit-class.R b/R/jsdmstanfit-class.R index fa0a4be..d364211 100644 --- a/R/jsdmstanfit-class.R +++ b/R/jsdmstanfit-class.R @@ -10,7 +10,7 @@ #' A length one character vector describing type of jSDM #' } #' \item{\code{family}}{ -#' A character vector describing response family +#' A jsdmStanFamily object describing characteristics of family #' } #' \item{\code{species}}{ #' A character vector of the species names @@ -35,7 +35,7 @@ jsdmStanFit_empty <- function() { res <- list( jsdm_type = "None", - family = character(), + family = jsdmStanFamily_empty(), species = character(), sites = character(), preds = character(), @@ -77,6 +77,7 @@ print.jsdmStanFit <- function(x, ...) { " Number of species: ", length(x$species), "\n", " Number of sites: ", length(x$sites), "\n", " Number of predictors: ", length(x$preds), "\n", + print(x$family), "\n", "Model run on ", length(x$fit@stan_args), " chains with ", x$fit@stan_args[[1]]$iter, " iterations per chain (", diff --git a/R/loo.R b/R/loo.R index a91b14a..754241d 100644 --- a/R/loo.R +++ b/R/loo.R @@ -2,7 +2,8 @@ #' #' This function uses the \pkg{loo} package to compute PSIS-LOO CV, efficient #' approximate leave-one-out (LOO) cross-validation for Bayesian models using Pareto -#' smoothed importance sampling (PSIS). +#' smoothed importance sampling (PSIS). This requires that the model was fit using +#' \code{log_lik = TRUE}. #' #' @param x The jsdmStanFit model object #' @param ... Other arguments passed to the \code{\link[loo]{loo}} function diff --git a/R/posterior_predict.R b/R/posterior_predict.R index 50b58da..5fd6303 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -19,9 +19,6 @@ #' #' @param draw_ids The IDs of the draws to be used, as a numeric vector #' -#' @param newdata_type What form is the new data in, at the moment only -#' supplying covariates is supported. -#' #' @param list_index Whether to return the output list indexed by the number of #' draws (default), species, or site. #' @param ... Currently unused @@ -32,7 +29,8 @@ #' list will have length equal to the number of species with each element of #' the list being a draws x sites matrix. If the list_index is \code{"sites"} the #' list will have length equal to the number of sites with each element of the -#' list being a draws x species matrix. +#' list being a draws x species matrix. Note that in the zero-inflated case this is +#' only the linear predictor of the non-zero-inflated part of the model. #' #' @seealso [posterior_predict.jsdmStanFit()] #' @@ -41,23 +39,15 @@ #' @export posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, newdata = NULL, ndraws = NULL, - draw_ids = NULL, newdata_type = "X", + draw_ids = NULL, list_index = "draws", ...) { - if (newdata_type != "X") { - stop("Currently only data on covariates is supported.") - } stopifnot(is.logical(transform)) - if (isTRUE(transform) & object$family == "gaussian") { + if (isTRUE(transform) & object$family$family == "gaussian") { warning("No inverse-link transform performed for Gaussian response models.") } - if (!is.null(ndraws) & !is.null(draw_ids)) { - message("Both ndraws and draw_ids have been specified, ignoring ndraws") - } - if (!is.null(draw_ids)) { - if (any(!is.wholenumber(draw_ids))) { - stop("draw_ids must be a vector of positive integers") - } - } + + foopred_checks(object = object, transform = transform, newdata = newdata, + ndraws = ndraws, draw_ids = draw_ids, list_index = list_index) list_index <- match.arg(list_index, c("draws", "species", "sites")) @@ -68,19 +58,12 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, method <- object$jsdm_type orig_data_used <- is.null(newdata) - if (is.null(newdata) & length(object$data_list) == 0) { - stop(paste( - "Original data must be included in model object if no new data", - "is provided." - )) - } if (is.null(newdata)) { newdata <- object$data_list$X } newdata <- validate_newdata(newdata, - preds = object$preds, - newdata_type = newdata_type + preds = object$preds ) model_pars <- "betas" @@ -95,44 +78,27 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, } model_est <- extract(object, pars = model_pars) - n_iter <- dim(model_est[[1]])[1] - if (!is.null(draw_ids)) { - if (max(draw_ids) > n_iter) { - stop(paste( - "Maximum of draw_ids (", max(draw_ids), - ") is greater than number of iterations (", n_iter, ")" - )) - } + draw_id <- draw_id_check(draw_ids = draw_ids, n_iter = n_iter, ndraws = ndraws) - draw_id <- draw_ids - } else { - if (!is.null(ndraws)) { - if (n_iter < ndraws) { - warning(paste( - "There are fewer samples than ndraws specified, defaulting", - "to using all iterations" - )) - ndraws <- n_iter - } - draw_id <- sample.int(n_iter, ndraws) - model_est <- lapply(model_est, function(x) { - switch(length(dim(x)), - `1` = x[draw_id, drop = FALSE], - `2` = x[draw_id, , drop = FALSE], - `3` = x[draw_id, , , drop = FALSE] - ) - }) - } else { - draw_id <- seq_len(n_iter) - } - } + model_est <- lapply(model_est, function(x) { + switch(length(dim(x)), + `1` = x[draw_id, drop = FALSE], + `2` = x[draw_id, , drop = FALSE], + `3` = x[draw_id, , , drop = FALSE] + ) + }) model_pred_list <- lapply(seq_along(draw_id), function(d) { if (method == "gllvm") { if (orig_data_used) { - LV_sum <- t((model_est$Lambda[d, , ] * model_est$sigma_L[d]) %*% model_est$LV[d, , ]) + if(object$n_latent>1){ + LV_sum <- t((model_est$Lambda[d, , ] * model_est$sigma_L[d]) %*% model_est$LV[d, , ]) + } else{ + LV_sum <- t((matrix(model_est$Lambda[d, , ], ncol = 1) * model_est$sigma_L[d]) + %*% matrix(model_est$LV[d, , ], nrow = 1)) + } } else { LV_sum <- 0 } @@ -157,12 +123,14 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, ) if (isTRUE(transform)) { mu <- apply(mu, 1:2, function(x) { - switch(object$family, + switch(object$family$family, "gaussian" = x, "bernoulli" = inv_logit(x), "poisson" = exp(x), "neg_binomial" = exp(x), - "binomial" = inv_logit(x) + "binomial" = inv_logit(x), + "zi_poisson" = exp(x), + "zi_neg_binomial" = exp(x) ) }) } @@ -177,62 +145,89 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE, return(model_pred_list) } -#' Draw from the posterior predictive distribution +#'Draw from the posterior predictive distribution #' -#' Draw from the posterior predictive distribution of the outcome. +#'Draw from the posterior predictive distribution of the outcome. #' -#' @aliases posterior_predict +#'@aliases posterior_predict #' -#' @inheritParams posterior_linpred.jsdmStanFit +#'@inheritParams posterior_linpred.jsdmStanFit #' -#' @param Ntrials For the binomial distribution the number of trials, given as -#' either a single integer which is assumed to be constant across sites or as -#' a site-length vector of integers. +#'@param Ntrials For the binomial distribution the number of trials, given as either +#' a single integer which is assumed to be constant across sites or as a site-length +#' vector of integers. #' -#' @return A list of linear predictors. If list_index is \code{"draws"} (the default) -#' the list will have length equal to the number of draws with each element of -#' the list being a site x species matrix. If the list_index is \code{"species"} the -#' list will have length equal to the number of species with each element of -#' the list being a draws x sites matrix. If the list_index is \code{"sites"} the -#' list will have length equal to the number of sites with each element of -#' the list being a draws x species matrix. +#'@param include_zi For the zero-inflated poisson distribution, whether to include +#' the zero-inflation in the prediction. Defaults to \code{TRUE}. #' -#' @seealso [posterior_linpred.jsdmStanFit()] +#'@return A list of linear predictors. If list_index is \code{"draws"} (the default) +#' the list will have length equal to the number of draws with each element of the +#' list being a site x species matrix. If the list_index is \code{"species"} the +#' list will have length equal to the number of species with each element of the +#' list being a draws x sites matrix. If the list_index is \code{"sites"} the list +#' will have length equal to the number of sites with each element of the list being +#' a draws x species matrix. #' -#' @importFrom rstantools posterior_predict -#' @export posterior_predict -#' @export +#'@seealso [posterior_linpred.jsdmStanFit()] +#' +#'@importFrom rstantools posterior_predict +#'@export posterior_predict +#'@export posterior_predict.jsdmStanFit <- function(object, newdata = NULL, - newdata_type = "X", ndraws = NULL, + ndraws = NULL, draw_ids = NULL, list_index = "draws", - Ntrials = NULL, ...) { - transform <- ifelse(object$family == "gaussian", FALSE, TRUE) + Ntrials = NULL, + include_zi = TRUE, ...) { + transform <- ifelse(object$family$family == "gaussian", FALSE, TRUE) + if (!is.null(ndraws) & !is.null(draw_ids)) { + message("Both ndraws and draw_ids have been specified, ignoring ndraws") + } + if (!is.null(draw_ids)) { + if (any(!is.wholenumber(draw_ids))) { + stop("draw_ids must be a vector of positive integers") + } + } + n_iter <- length(object$fit@stan_args)*(object$fit@stan_args[[1]]$iter - + object$fit@stan_args[[1]]$warmup) + draw_id <- draw_id_check(draw_ids = draw_ids, n_iter = n_iter, ndraws = ndraws) + + post_linpred <- posterior_linpred(object, - newdata = newdata, ndraws = ndraws, - newdata_type = newdata_type, draw_ids = draw_ids, + newdata = newdata, + draw_ids = draw_id, transform = transform, list_index = "draws" ) - if (object$family == "gaussian") { - mod_sigma <- rstan::extract(object$fit, pars = "sigma", permuted = FALSE) - } - if (object$family == "neg_binomial") { - mod_kappa <- rstan::extract(object$fit, pars = "kappa", permuted = FALSE) - } - if(object$family == "binomial"){ + if (object$family$family == "gaussian") { + mod_sigma <- extract(object, pars = "sigma")[[1]][draw_id,] + } else if(object$family$family == "binomial"){ if(is.null(newdata)) { Ntrials <- object$data_list$Ntrials } else { Ntrials <- ntrials_check(Ntrials, nrow(newdata)) } } + if(grepl("zi",object$family$family) & !("zi" %in% object$family$params_dataresp) & isTRUE(include_zi)){ + mod_zi <- extract(object, pars = "zi")[[1]][draw_id,] + } + if(grepl("neg_binomial",object$family$family)) { + mod_kappa <- extract(object, pars = "kappa")[[1]][draw_id,] + } + + if("zi" %in% object$family$params_dataresp & isTRUE(include_zi)){ + post_zipred <- posterior_zipred(object, + newdata = newdata, + draw_ids = draw_id, + transform = transform, list_index = "draws" + ) + } n_sites <- length(object$sites) n_species <- length(object$species) post_pred <- lapply(seq_along(post_linpred), - function(x, family = object$family) { + function(x, family = object$family$family) { x2 <- post_linpred[[x]] if(family == "binomial"){ for(i in 1:nrow(x2)){ @@ -240,14 +235,37 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, x2[i,j] <- stats::rbinom(1, Ntrials[i], x2[i,j]) } } + } else if (grepl("zi_",family) & "zi" %in% object$family$params_dataresp & + isTRUE(include_zi)){ + zi2 <- post_zipred[[x]] + for(i in seq_len(nrow(x2))){ + for(j in seq_len(ncol(x2))){ + x2[i,j] <- switch( + object$family$family, + "zi_poisson" = (1-stats::rbinom(1, 1, zi2[x,j]))*stats::rpois(1, x2[i,j]), + "zi_neg_binomial" = (1-stats::rbinom(1, 1, zi2[x,j]))*rgampois(1, x2[i,j], mod_kappa[x,j]) + ) + } + } } else { for(i in seq_len(nrow(x2))){ for(j in seq_len(ncol(x2))){ - x2[i,j] <- switch(object$family, - "gaussian" = stats::rnorm(1, x2[i,j], mod_sigma[j]), - "bernoulli" = stats::rbinom(1, 1, x2[i,j]), - "poisson" = stats::rpois(1, x2[i,j]), - "neg_binomial" = rgampois(1, x2[i,j], mod_kappa[j]) + x2[i,j] <- switch( + object$family$family, + "gaussian" = stats::rnorm(1, x2[i,j], mod_sigma[x,j]), + "bernoulli" = stats::rbinom(1, 1, x2[i,j]), + "poisson" = stats::rpois(1, x2[i,j]), + "neg_binomial" = rgampois(1, x2[i,j], mod_kappa[x,j]), + "zi_poisson" = if(isTRUE(include_zi)){ + (1-stats::rbinom(1, 1, mod_zi[x,j]))*stats::rpois(1, x2[i,j]) + } else { + stats::rpois(1, x2[i,j]) + }, + "zi_neg_binomial" = if(isTRUE(include_zi)){ + (1-stats::rbinom(1, 1, mod_zi[x,j]))*rgampois(1, x2[i,j], mod_kappa[x,j]) + } else { + rgampois(1, x2[i,j], mod_kappa[x,j]) + } ) } } @@ -263,10 +281,93 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL, } +#' Access the posterior distribution of the linear predictor for zero-inflation +#' parameter +#' +#' Extract the posterior draws of the linear predictor for the zero-inflation +#' parameter, possibly transformed by the inverse-link function. +#' +#' +#' @inheritParams posterior_linpred.jsdmStanFit +#' +#' @return A list of linear predictors. If list_index is \code{"draws"} (the +#' default) the list will have length equal to the number of draws with each +#' element of the list being a site x species matrix. If the list_index is +#' \code{"species"} the list will have length equal to the number of species +#' with each element of the list being a draws x sites matrix. If the +#' list_index is \code{"sites"} the list will have length equal to the number +#' of sites with each element of the list being a draws x species matrix. Note +#' that in the zero-inflated case this is only the linear predictor of the +#' non-zero-inflated part of the model. +#' +#' @seealso [posterior_predict.jsdmStanFit()] +#' +#' @export +posterior_zipred <- function(object, transform = FALSE, + newdata = NULL, ndraws = NULL, + draw_ids = NULL, + list_index = "draws"){ + if(!("zi" %in% object$family$params_dataresp)){ + stop(paste("This function only works upon models with a zero-inflated family", + "and where the zero-inflation is responsive to covariates.")) + } + foopred_checks(object = object, transform = transform, newdata = newdata, + ndraws = ndraws, draw_ids = draw_ids, list_index = list_index) + list_index <- match.arg(list_index, c("draws", "species", "sites")) + + n_sites <- length(object$sites) + n_species <- length(object$species) + n_preds <- length(object$family$preds) + method <- object$jsdm_type + orig_data_used <- is.null(newdata) + + if (is.null(newdata)) { + newdata <- object$family$data_list$zi_X + } + + newdata <- validate_newdata(newdata, + preds = object$family$preds + ) + + model_pars <- "zi_betas" + + model_est <- extract(object, pars = model_pars) + n_iter <- dim(model_est[[1]])[1] + + draw_id <- draw_id_check(draw_ids = draw_ids, n_iter = n_iter, ndraws = ndraws) + + model_est <- lapply(model_est, function(x) { + switch(length(dim(x)), + `1` = x[draw_id, drop = FALSE], + `2` = x[draw_id, , drop = FALSE], + `3` = x[draw_id, , , drop = FALSE] + ) + }) + + model_pred_list <- lapply(seq_along(draw_id), function(d) { + + if (is.vector(newdata)) { + newdata <- matrix(newdata, ncol = 1) + } + zi <- newdata %*% model_est$zi_betas[d, , ] + + if (isTRUE(transform)) { + zi <- apply(zi, c(1,2), inv_logit) + } + + return(zi) + }) + + if (list_index != "draws") { + model_pred_list <- switch_indices(model_pred_list, list_index) + } + + return(model_pred_list) +} # internal ~~~~ -validate_newdata <- function(newdata, preds, newdata_type) { +validate_newdata <- function(newdata, preds) { preds_nointercept <- preds[preds != "(Intercept)"] if (!all(preds_nointercept %in% colnames(newdata))) { @@ -278,7 +379,7 @@ validate_newdata <- function(newdata, preds, newdata_type) { newdata <- newdata[, preds_nointercept] if ("(Intercept)" %in% preds) { - newdata <- cbind(`(Intercept)` = 1, newdata) + newdata <- cbind("(Intercept)" = 1, newdata) newdata <- newdata[, preds] } return(newdata) @@ -297,3 +398,50 @@ switch_indices <- function(res_list, list_index) { stop("List index not valid") } } + +draw_id_check <- function(draw_ids, n_iter, ndraws){ + if (!is.null(draw_ids)) { + if (max(draw_ids) > n_iter) { + stop(paste( + "Maximum of draw_ids (", max(draw_ids), + ") is greater than number of iterations (", n_iter, ")" + )) + } + + draw_id <- draw_ids + } else { + if (!is.null(ndraws)) { + if (n_iter < ndraws) { + warning(paste( + "There are fewer samples than ndraws specified, defaulting", + "to using all iterations" + )) + ndraws <- n_iter + } + draw_id <- sample.int(n_iter, ndraws) + } else { + draw_id <- seq_len(n_iter) + } + } + return(draw_id) +} + +foopred_checks <- function(object, transform, draw_ids, ndraws, list_index, + newdata){ + stopifnot(is.logical(transform)) + if (!is.null(ndraws) & !is.null(draw_ids)) { + message("Both ndraws and draw_ids have been specified, ignoring ndraws") + } + if (!is.null(draw_ids)) { + if (any(!is.wholenumber(draw_ids))) { + stop("draw_ids must be a vector of positive integers") + } + } + if (is.null(newdata) & length(object$data_list) == 0) { + stop(paste( + "Original data must be included in model object if no new data", + "is provided." + )) + } + return(NULL) +} diff --git a/R/pp_check.R b/R/pp_check.R index 205d474..d4cfd4c 100644 --- a/R/pp_check.R +++ b/R/pp_check.R @@ -146,10 +146,13 @@ pp_check.jsdmStanFit <- function(object, plotfun = "dens_overlay", species = NUL # prepare plotting arguments ppc_args <- list(y = y, yrep = yrep) - for_pred <- union( - names(dots) %in% names(formals(jsdm_statsummary)), - names(dots) %in% names(formals(posterior_linpred.jsdmStanFit)) + for_pred <- names(dots) %in% union(union( + names(formals(jsdm_statsummary)), + names(formals(posterior_linpred.jsdmStanFit)) + ), + names(formals(posterior_predict.jsdmStanFit)) ) + ppc_args <- c(ppc_args, dots[!for_pred]) do.call(ppc_fun, ppc_args) @@ -357,7 +360,7 @@ multi_pp_check <- function(object, plotfun = "dens_overlay", species = NULL, post_args$object <- object post_args$list_index <- "draws" post_args$draw_ids <- draw_ids - post_args$ndraws <- ndraws + # post_args$ndraws <- ndraws post_res <- do.call(post_fun, post_args) diff --git a/R/prior.R b/R/prior.R index 3108dcf..b643385 100644 --- a/R/prior.R +++ b/R/prior.R @@ -54,6 +54,13 @@ #' to be positive (default standard normal) #' @param kappa For negative binomial response, the negative binomial variance #' parameter. Constrained to be positive (default standard normal) +#' @param zi For zero-inflated poisson or negative binomial with no environmental +#' covariate effects upon the zero-inflation, the proportion of +#' inflated zeros (default beta distribution with both alpha and beta parameters +#' set to 1). +#' @param zi_betas For zero-inflated poisson or negative binomial with +#' environmental effects upon the zero-inflation, the covariate effects on the +#' zero-inflation on the logit scale #' #' @return An object of class \code{"jsdmprior"} taking the form of a named list #' @export @@ -76,7 +83,9 @@ jsdm_prior <- function(sigmas_preds = "normal(0,1)", L = "normal(0,1)", sigma_L = "normal(0,1)", sigma = "normal(0,1)", - kappa = "normal(0,1)") { + kappa = "normal(0,1)", + zi = "beta(1,1)", + zi_betas = "normal(0,1)") { res <- list( sigmas_preds = sigmas_preds, z_preds = z_preds, cor_preds = cor_preds, betas = betas, @@ -84,7 +93,8 @@ jsdm_prior <- function(sigmas_preds = "normal(0,1)", sigmas_species = sigmas_species, z_species = z_species, cor_species = cor_species, LV = LV, L = L, sigma_L = sigma_L, - sigma = sigma, kappa = kappa + sigma = sigma, kappa = kappa, zi = zi, + zi_betas = zi_betas ) if (!(all(sapply(res, is.character)))) { stop("All arguments must be supplied as character vectors") @@ -107,11 +117,11 @@ print.jsdmprior <- function(x, ...) { rep("site_intercept", 3), rep("mglmm", 3), rep("gllvm", 3), - "gaussian", "neg_binomial" + "gaussian", "neg_binomial",rep("zero_inflation",2) ), Constraint = c( "lower=0", rep("none", 5), rep("lower=0", 2), - rep("none", 4), rep("lower=0", 3) + rep("none", 4), rep("lower=0", 3),"lower=0,upper=1","none" ), Prior = unlist(unname(x)) ) diff --git a/R/sim_data_funs.R b/R/sim_data_funs.R index 44b151d..2b52a2d 100644 --- a/R/sim_data_funs.R +++ b/R/sim_data_funs.R @@ -37,7 +37,9 @@ #' #' @param family is the response family, must be one of \code{"gaussian"}, #' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -#' or \code{"bernoulli"}. Regular expression matching is supported. +#' \code{"bernoulli"}, \code{"zi_poisson"}, or +#' \code{"zi_neg_binomial"}. Regular expression +#' matching is supported. #' #' @param method is the jSDM method to use, currently either \code{"gllvm"} or #' \code{"mglmm"} - see details for more information. @@ -57,17 +59,34 @@ #' @param beta_param The parameterisation of the environmental covariate #' effects, by default \code{"unstruct"}. See details for further information. #' +#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter +#' is a species-specific constant (default, \code{"constant"}), or varies by +#' environmental covariates (\code{"covariate"}). +#' +#' @param zi_k If \code{zi="covariate"}, the number of environmental covariates +#' that the zero-inflation parameter responds to. The default (\code{NULL}) is +#' that the zero-inflation parameter responds to exactly the same covariate matrix +#' as the mean parameter. Otherwise, a different set of random environmental +#' covariates are generated, plus an intercept (not included in zi_k) and used +#' to predict zero-inflation +#' #' @param prior Set of prior specifications from call to [jsdm_prior()] jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "mglmm"), species_intercept = TRUE, Ntrials = NULL, site_intercept = "none", beta_param = "unstruct", + zi_param = "constant", zi_k = NULL, prior = jsdm_prior()) { response <- match.arg(family, c("gaussian", "neg_binomial", "poisson", - "bernoulli", "binomial")) + "bernoulli", "binomial", "zi_poisson", + "zi_neg_binomial")) site_intercept <- match.arg(site_intercept, c("none","ungrouped","grouped")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) + zi_param <- match.arg(zi_param, c("constant","covariate")) + if(missing(method)){ + stop("method argument needs to be specified") + } if(site_intercept == "grouped"){ stop("Grouped site intercept not supported") } @@ -99,6 +118,18 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m Ntrials <- ntrials_check(Ntrials = Ntrials, N = N) } + if (!is.null(zi_k)) { + if(zi_k < 1 | zi_k %% 1 != 0){ + stop("zi_k must be either NULL or a positive integer") + } + } + + if(is.null(zi_k)){ + ZI_K <- K + } else{ + ZI_K <- zi_k + } + # prior object breakdown prior_split <- lapply(prior, strsplit, split = "\\(|\\)|,") if (!all(sapply(prior_split, function(x) { @@ -109,7 +140,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "lkj_corr", "student_t", "cauchy", - "gamma" + "gamma", + "beta" ) }))) { stop(paste( @@ -132,7 +164,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "lkj_corr" = "rlkj", "student_t" = "rstudentt", "cauchy" = "rcauchy", - "gamma" = "rgamma" + "gamma" = "rgamma", + "beta" = "rbeta" ) fun_arg1 <- switch(x, "sigmas_preds" = K + 1 * species_intercept, @@ -149,7 +182,9 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "L" = D1 * (S - D1) + (D1 * (D1 - 1) / 2) + D1, "sigma_L" = 1, "sigma" = S, - "kappa" = S + "kappa" = S, + "zi" = S, + "zi_betas" = S*(ZI_K+1) ) fun_args <- as.list(c(fun_arg1, as.numeric(unlist(y[[1]][[1]])[-1]))) @@ -175,12 +210,12 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m # now do covariates - if K = NULL then do intercept only if (K == 0) { x <- matrix(1, nrow = N, ncol = 1) - colnames(x) <- "Intercept" + colnames(x) <- "(Intercept)" J <- 1 } else if (isTRUE(species_intercept)) { x <- matrix(stats::rnorm(N * K), ncol = K, nrow = N) colnames(x) <- paste0("V", 1:K) - x <- cbind(Intercept = 1, x) + x <- cbind("(Intercept)" = 1, x) J <- K + 1 } else if (isFALSE(species_intercept)) { x <- matrix(stats::rnorm(N * K), ncol = K, nrow = N) @@ -260,9 +295,11 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m L <- matrix(nrow = S, ncol = D) idx2 <- 0 - for (i in 1:(D - 1)) { - for (j in (i + 1):(D)) { - L[i, j] <- 0 + if(D > 1){ + for (i in 1:(D - 1)) { + for (j in (i + 1):(D)) { + L[i, j] <- 0 + } } } for (j in 1:D) { @@ -296,9 +333,50 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m match.fun(prior_func[["kappa"]][[1]]), prior_func[["kappa"]][[2]] )) + } else if (response == "zi_poisson") { + if(zi_param == "covariate"){ + zi_betas <- matrix(do.call( + match.fun(prior_func[["zi_betas"]][[1]]), + prior_func[["zi_betas"]][[2]] + ), ncol = S) + } else { + zi <- do.call( + match.fun(prior_func[["zi"]][[1]]), + prior_func[["zi"]][[2]] + ) + } + } else if (response == "zi_neg_binomial") { + if(zi_param == "covariate"){ + zi_betas <- matrix(do.call( + match.fun(prior_func[["zi_betas"]][[1]]), + prior_func[["zi_betas"]][[2]] + ), ncol = S) + } else { + zi <- do.call( + match.fun(prior_func[["zi"]][[1]]), + prior_func[["zi"]][[2]] + ) + } + kappa <- abs(do.call( + match.fun(prior_func[["kappa"]][[1]]), + prior_func[["kappa"]][[2]] + )) } # print(str(sigma)) + # zero-inflation in case of covariates + if(grepl("zi_", response) & zi_param == "covariate"){ + if(is.null(zi_k)){ + zi_X <- x + } else { + zi_X <- matrix(stats::rnorm(N * zi_k), ncol = zi_k, nrow = N) + colnames(zi_X) <- paste0("V", 1:zi_k) + zi_X <- cbind("(Intercept)" = 1, zi_X) + } + zi <- inv_logit(zi_X %*% zi_betas) + } + + # generate Y Y <- matrix(nrow = N, ncol = S) for (i in 1:N) { for (j in 1:S) { @@ -316,11 +394,24 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m "gaussian" = stats::rnorm(1, mu_ij, sigma[j]), "poisson" = stats::rpois(1, exp(mu_ij)), "bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij)), - "binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij)) + "binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij)), + "zi_poisson" = (1-stats::rbinom( + 1, 1, + ifelse(zi_param == "covariate", + zi[i,j],zi[j])))*stats::rpois(1, exp(mu_ij)), + "zi_neg_binomial" = (1-stats::rbinom( + 1, 1, + ifelse(zi_param == "covariate", + zi[i,j],zi[j])))*rgampois(1, mu = exp(mu_ij), scale = kappa[j]) ) } } - # print(str(Y)) + + if(any(apply(Y, 2, function(x) all(x == 0)))){ + message(paste("Y contains an entirely empty column, which will not work for", + "jsdm fitting, it is recommended that the simulation is run again.")) + } + pars <- list( @@ -362,6 +453,16 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m if(response == "neg_binomial"){ pars$kappa <- kappa } + if(grepl("zi_", response)){ + if(zi_param == "constant"){ + pars$zi <- zi + } else if(zi_param == "covariate"){ + pars$zi_betas <- zi_betas + } + } + if(response == "zi_neg_binomial"){ + pars$kappa <- kappa + } if (isTRUE(species_intercept)) { if (K > 0) { x <- x[, 2:ncol(x)] @@ -375,6 +476,10 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m if(response == "binomial"){ output$Ntrials <- Ntrials } + if(grepl("zi_", response) & zi_param == "covariate"){ + output$zi_k <- ZI_K + 1 + output$zi_X <- zi_X + } return(output) } diff --git a/R/stan_jsdm.R b/R/stan_jsdm.R index c8c91e9..d32a56e 100644 --- a/R/stan_jsdm.R +++ b/R/stan_jsdm.R @@ -11,6 +11,13 @@ #' \code{"unstruct"} parameterisation all covariate effects are assumed to draw #' from a simple distribution with no correlation structure. Both parameterisations #' can be modified using the prior object. +#' Families supported are the Gaussian family, the negative binomial family, +#' the Poisson family, the binomial family (with number of trials specificied +#' using the \code{Ntrials} parameter), the Bernoulli family (the special case +#' of the binomial family where number of trials is equal to one), the +#' zero-inflated Poisson and the zero-inflated negative binomial. For both +#' zero-inflated families the zero-inflation is assumed to be a species-specific +#' constant. #' #' @param formula The formula of covariates that the species means are modelled from #' @@ -26,9 +33,10 @@ #' #' @param D The number of latent variables within a GLLVM model #' -#' @param family The response family for the model, required to be one of -#' \code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -#' or \code{"neg_binomial"} +#' @param family is the response family, must be one of \code{"gaussian"}, +#' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +#' \code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression +#' matching is supported. #' #' @param species_intercept Whether the model should be fit with an intercept by #' species, by default \code{TRUE} @@ -63,6 +71,15 @@ #' @param beta_param The parameterisation of the environmental covariate effects, by #' default \code{"unstruct"}. See details for further information. #' +#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter +#' is a species-specific constant (default, \code{"constant"}), or varies by +#' environmental covariates (\code{"covariate"}). +#' +#' @param zi_X If \code{zi = "covariate"}, the matrix of environmental predictors +#' that the zero-inflation is modelled in response to. If there is not already +#' an intercept column (identified by all values being equal to one), one will +#' be added to the front of the matrix. +#' #' @param ... Arguments passed to [rstan::sampling()] #' #' @return A \code{jsdmStanFit} object, comprising a list including the StanFit @@ -103,10 +120,22 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth dat_list = NULL, family, site_intercept = "none", D = NULL, prior = jsdm_prior(), site_groups = NULL, beta_param = "unstruct", Ntrials = NULL, + zi_param = "constant", zi_X = NULL, save_data = TRUE, iter = 4000, log_lik = TRUE, ...) { family <- match.arg(family, c("gaussian", "bernoulli", "poisson", - "neg_binomial","binomial")) + "neg_binomial","binomial", "zi_poisson", + "zi_neg_binomial")) beta_param <- match.arg(beta_param, c("cor", "unstruct")) + zi_param <- match.arg(zi_param, c("constant","covariate")) + if(grepl("zi", family) & zi_param == "covariate"){ + if(is.null(zi_X) & is.null(dat_list)){ + message("If zi_param = 'covariate' and no zi_X matrix is supplied then the X matrix is used") + zi_X <- X + } + } else{ + zi_X <- NULL + } + stopifnot( is.logical(species_intercept), @@ -120,7 +149,8 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth Y = Y, X = X, species_intercept = species_intercept, D = D, site_intercept = site_intercept, site_groups = site_groups, dat_list = dat_list, - family = family, method = method, Ntrials = Ntrials + family = family, method = method, Ntrials = Ntrials, + zi_X = zi_X ) # Create stancode @@ -128,7 +158,7 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth family = family, method = method, prior = prior, log_lik = log_lik, site_intercept = site_intercept, - beta_param = beta_param + beta_param = beta_param, zi_param = zi_param ) # Compile model @@ -226,7 +256,8 @@ stan_gllvm.formula <- function(formula, data = list(), ...) { validate_data <- function(Y, D, X, species_intercept, dat_list, family, site_intercept, - method, site_groups, Ntrials) { + method, site_groups, Ntrials, + zi_X) { method <- match.arg(method, c("gllvm", "mglmm")) # do things if data not given as list: @@ -264,6 +295,16 @@ validate_data <- function(Y, D, X, species_intercept, colnames(X)[1] <- "(Intercept)" } } + if(grepl("zi", family) & !is.null(zi_X)){ + if(is.null(colnames(zi_X))){ + message("No column names specified for zi_X, assigning names") + colnames(zi_X) <- paste0("V",seq_len(ncol(zi_X))) + } + if(!any(apply(zi_X, 2, function(x) all(x == 1)))){ + zi_X <- cbind("(Intercept)" = 1, zi_X) + } + zi_k <- ncol(zi_X) + } if(site_intercept == "grouped"){ if(length(site_groups) != N) @@ -290,6 +331,14 @@ validate_data <- function(Y, D, X, species_intercept, if(family == "binomial"){ data_list$Ntrials <- Ntrials } + if(grepl("zi_",family) & !is.null(zi_X)){ + data_list$zi_k <- zi_k + data_list$zi_X <- zi_X + + if(nrow(zi_X) != N){ + stop("Number of rows of zi_X must be equal to number of rows of X") + } + } } else { if (!all(c("Y", "K", "S", "N", "X") %in% names(dat_list))) { stop("If supplying data as a list must have entries Y, K, S, N, X") @@ -305,6 +354,12 @@ validate_data <- function(Y, D, X, species_intercept, } } + if (grepl("zi_", family) & !is.null(zi_X)) { + if (!all(c("zi_X","zi_k") %in% names(dat_list))) { + stop("Zero-inflated models with the covariate parameterisation of zi require zi_X and zi_k in dat_list") + } + } + if (site_intercept == "grouped") { if (!all(c("ngrp","grps") %in% names(dat_list))) { stop("Grouped site intercept models require ngrp and grps in dat_list") @@ -341,17 +396,39 @@ validate_data <- function(Y, D, X, species_intercept, ))) { stop("Y matrix is not binary") } - } else if (family %in% c("poisson", "neg_binomial", "binomial")) { + } else if (family %in% c("poisson", "neg_binomial", "binomial", + "zi_poisson", "zi_neg_binomial")) { if (!any(apply(data_list$Y, 1:2, is.wholenumber))) { stop("Y matrix is not composed of integers") } } + # check to make sure no completely blank columns in Y + if(any(apply(data_list$Y, 2, function(x) all(x == 0)))){ + stop("Y contains an empty column, which cannot work for this model") + } + # Check if Ntrials is appropriate given if(identical(family, "binomial")) { data_list$Ntrials <- ntrials_check(data_list$Ntrials, data_list$N) } + # create zeros/non-zeros for zero-inflated poisson + if(grepl("zi_",family)) { + if(any(apply(data_list$Y, 2, min)>0)){ + stop("Zero-inflated distributions require zeros to be present in all Y values.") + } + data_list$N_zero <- colSums(data_list$Y==0) + data_list$N_nonzero <- colSums(data_list$Y>0) + data_list$Sum_nonzero <- sum(data_list$N_nonzero) + data_list$Sum_zero <- sum(data_list$N_zero) + data_list$Y_nz <- c(as.matrix(data_list$Y))[c(as.matrix(data_list$Y))>0] + data_list$nn <- rep(1:data_list$N,data_list$S)[c(data_list$Y>0)] + data_list$ss <- rep(1:data_list$S,each=data_list$N)[c(data_list$Y>0)] + data_list$nz <- rep(1:data_list$N,data_list$S)[c(data_list$Y==0)] + data_list$sz <- rep(1:data_list$S,each=data_list$N)[c(data_list$Y==0)] + } + return(data_list) } @@ -381,11 +458,29 @@ model_to_jsdmstanfit <- function(model_fit, method, data_list, species_intercept } else { list() } + fam <- list(family = family, + params = switch(family, + "gaussian" = "sigma", + "bernoulli" = character(), + "poisson" = character(), + "neg_binomial" = "kappa", + "binomial" = character(), + "zi_poisson" = "zi", + "zi_neg_binomial" = c("kappa","zi")), + params_dataresp= character(), + preds = character(), + data_list = list()) + class(fam) <- "jsdmStanFamily" + if(isTRUE(save_data) & ("zi_X" %in% names(data_list))){ + fam$params_dataresp <- "zi" + fam$preds <- colnames(data_list$zi_X) + fam$data_list <- list(zi_X = data_list$zi_X) + } model_output <- list( fit = model_fit, jsdm_type = method, - family = family, + family = fam, species = species, sites = sites, preds = preds, diff --git a/R/update.R b/R/update.R index e735119..183e3d4 100644 --- a/R/update.R +++ b/R/update.R @@ -12,6 +12,10 @@ #' @param newD New number of latent variables, by default \code{NULL} #' @param newNtrials New number of trials (binomial model only), by default #' \code{NULL} +#' @param newZi_X New predictor data for the zi parameter in zero-inflated models, +#' by default \code{NULL}. In cases where the model was originally fit with the +#' same X and zi_X data and only newX is supplied to update.jsdmStanFit the zi_X +#' data will also be set to newX. #' @param save_data Whether to save the data in the jsdmStanFit object, by default #' \code{TRUE} #' @param ... Arguments passed to [rstan::sampling()] @@ -51,7 +55,7 @@ #' gllvm_fit2 #' } update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, - newNtrials = NULL, + newNtrials = NULL, newZi_X = NULL, save_data = TRUE, ...) { if (length(object$data_list) == 0) { stop("Update requires the original data to be saved in the model object") @@ -72,7 +76,7 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, } else { Y <- newY } - family <- object$family + family <- object$family$family method <- object$jsdm_type if(!is.null(newD)){ D <- newD @@ -86,6 +90,19 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, Ntrials <- object$data_list$Ntrials } } + if ("zi" %in% object$family$params_dataresp){ + if(is.null(newZi_X)) { + if(isTRUE(all.equal(object$data_list$X, object$family$data_list$zi_X)) & !is.null(newX)){ + zi_X <- newX + } else{ + zi_X <- object$family$data_list$zi_X + } + } else { + zi_X <- newZi_X + } + } else{ + zi_X <- NULL + } species_intercept <- "(Intercept)" %in% colnames(object$data_list$X) @@ -100,7 +117,8 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL, Y = Y, X = X, species_intercept = species_intercept, D = D, site_intercept = site_intercept, site_groups = site_groups, dat_list = NULL, - family = family, method = method, Ntrials = Ntrials + family = family, method = method, Ntrials = Ntrials, + zi_X = zi_X ) # get original stan model diff --git a/README.md b/README.md index 815876c..c76b811 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,24 @@ [![R-CMD-check](https://github.com/NERC-CEH/jsdmstan/workflows/R-CMD-check/badge.svg)](https://github.com/NERC-CEH/jsdmstan/actions) [![Codecov test coverage](https://codecov.io/gh/NERC-CEH/jsdmstan/branch/main/graph/badge.svg)](https://codecov.io/gh/NERC-CEH/jsdmstan?branch=main) +[![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) This is an R package for running joint Species Distribution Models (jSDM) in [Stan](https://mc-stan.org/). jSDMs are models where multiple response variables (i.e. species) are fit at the same time, and the covariance between these species are used to inform the model results. For a review of jSDMs see Warton et al. (2015) So many variables: joint modelling in community ecology. *TREE*, 30:766-779 DOI: [10.1016/j.tree.2015.09.007](http://doi.org/10.1016/j.tree.2015.09.007). This package can fit data to a Multivariate Generalised Linear Mixed Model (MGLMM) or a Generalised Linear Latent Variable Model (GLLVM), and also provides functionality for simulating data under these scenarios and an interface to the [bayesplot](https://mc-stan.org/bayesplot/) package for a wide variety of plotting options. +# Installation + +This package can be installed using the [remotes](https://remotes.r-lib.org/index.html) package using the following code: + +``` +# install.packages("remotes") +remotes::install_github("NERC-CEH/jsdmstan") +``` + +# Using jsdmstan + Example code: ``` diff --git a/man/jsdmStanFamily.Rd b/man/jsdmStanFamily.Rd new file mode 100644 index 0000000..eaf6283 --- /dev/null +++ b/man/jsdmStanFamily.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/jsdmstan-families.R +\name{jsdmStanFamily} +\alias{jsdmStanFamily} +\alias{jsdmStanFamily_empty} +\title{jsdmStanFamily class} +\usage{ +jsdmStanFamily_empty() +} +\description{ +This is the jsdmStanFamily class, which occupies a slot within any +jsdmStanFit object. +} +\section{Elements for \code{jsdmStanFamily} objects}{ + +\describe{ +\item{\code{family}}{ +A length one character vector describing family used to fit object. Options +are \code{"gaussian"}, \code{"poisson"}, \code{"bernoulli"}, +\code{"neg_binomial"}, \code{"binomial"}, \code{"zi_poisson"}, +\code{"zi_neg_binomial"}, or \code{"multiple"}. +} +\item{\code{params}}{ +A character vector that includes all the names of the family-specific parameters. +} +\item{\code{params_dataresp}}{ +A character vector that includes any named family-specific parameters that are +modelled in response to data. +} +\item{\code{preds}}{ +A character vector of the measured predictors included if family parameters +are modelled in response to data. If family parameters are not modelled in +response to data this is left empty. +} +\item{\code{data_list}}{ +A list containing the original data used to fit the model +(empty when save_data is set to \code{FALSE} or family parameters are not +modelled in response to data). +} +} +} + diff --git a/man/jsdmStanFit.Rd b/man/jsdmStanFit.Rd index f8d2eae..0c66f0e 100644 --- a/man/jsdmStanFit.Rd +++ b/man/jsdmStanFit.Rd @@ -17,7 +17,7 @@ This is the jsdmStanFit class, which stan_gllvm and stan_mglmm both create. A length one character vector describing type of jSDM } \item{\code{family}}{ -A character vector describing response family +A jsdmStanFamily object describing characteristics of family } \item{\code{species}}{ A character vector of the species names diff --git a/man/jsdm_prior.Rd b/man/jsdm_prior.Rd index 33d21b9..d122f96 100644 --- a/man/jsdm_prior.Rd +++ b/man/jsdm_prior.Rd @@ -20,7 +20,9 @@ jsdm_prior( L = "normal(0,1)", sigma_L = "normal(0,1)", sigma = "normal(0,1)", - kappa = "normal(0,1)" + kappa = "normal(0,1)", + zi = "beta(1,1)", + zi_betas = "normal(0,1)" ) \method{print}{jsdmprior}(x, ...) @@ -69,6 +71,15 @@ to be positive (default standard normal)} \item{kappa}{For negative binomial response, the negative binomial variance parameter. Constrained to be positive (default standard normal)} +\item{zi}{For zero-inflated poisson or negative binomial with no environmental +covariate effects upon the zero-inflation, the proportion of +inflated zeros (default beta distribution with both alpha and beta parameters +set to 1).} + +\item{zi_betas}{For zero-inflated poisson or negative binomial with +environmental effects upon the zero-inflation, the covariate effects on the +zero-inflation on the logit scale} + \item{x}{Object of class \code{jsdmprior}} \item{...}{Currently unused} diff --git a/man/jsdm_sim_data.Rd b/man/jsdm_sim_data.Rd index 30fa5df..4538ed2 100644 --- a/man/jsdm_sim_data.Rd +++ b/man/jsdm_sim_data.Rd @@ -17,6 +17,8 @@ jsdm_sim_data( Ntrials = NULL, site_intercept = "none", beta_param = "unstruct", + zi_param = "constant", + zi_k = NULL, prior = jsdm_prior() ) @@ -35,7 +37,9 @@ mglmm_sim_data(...) \item{family}{is the response family, must be one of \code{"gaussian"}, \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, -or \code{"bernoulli"}. Regular expression matching is supported.} +\code{"bernoulli"}, \code{"zi_poisson"}, or +\code{"zi_neg_binomial"}. Regular expression +matching is supported.} \item{method}{is the jSDM method to use, currently either \code{"gllvm"} or \code{"mglmm"} - see details for more information.} @@ -55,6 +59,17 @@ supported currently.} \item{beta_param}{The parameterisation of the environmental covariate effects, by default \code{"unstruct"}. See details for further information.} +\item{zi_param}{For the zero-inflated families, whether the zero-inflation parameter +is a species-specific constant (default, \code{"constant"}), or varies by +environmental covariates (\code{"covariate"}).} + +\item{zi_k}{If \code{zi="covariate"}, the number of environmental covariates +that the zero-inflation parameter responds to. The default (\code{NULL}) is +that the zero-inflation parameter responds to exactly the same covariate matrix +as the mean parameter. Otherwise, a different set of random environmental +covariates are generated, plus an intercept (not included in zi_k) and used +to predict zero-inflation} + \item{prior}{Set of prior specifications from call to \code{\link[=jsdm_prior]{jsdm_prior()}}} \item{...}{Arguments passed to jsdm_sim_data} diff --git a/man/jsdm_stancode.Rd b/man/jsdm_stancode.Rd index a37f910..4ff0eec 100644 --- a/man/jsdm_stancode.Rd +++ b/man/jsdm_stancode.Rd @@ -11,7 +11,8 @@ jsdm_stancode( prior = jsdm_prior(), log_lik = TRUE, site_intercept = "none", - beta_param = "cor" + beta_param = "cor", + zi_param = "constant" ) \method{print}{jsdmstan_model}(x, ...) @@ -19,8 +20,11 @@ jsdm_stancode( \arguments{ \item{method}{The method, one of \code{"gllvm"} or \code{"mglmm"}} -\item{family}{The family, one of "\code{"gaussian"}, \code{"bernoulli"}, -\code{"poisson"} or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, \code{"zi_poisson"}, or +\code{"zi_neg_binomial"}. Regular expression +matching is supported.} \item{prior}{The prior, given as the result of a call to \code{\link[=jsdm_prior]{jsdm_prior()}}} @@ -35,6 +39,10 @@ grouping)} \item{beta_param}{The parameterisation of the environmental covariate effects, by default \code{"cor"}. See details for further information.} +\item{zi_param}{For the zero-inflated families, whether the zero-inflation parameter +is a species-specific constant (default, \code{"constant"}), or varies by +environmental covariates (\code{"covariate"}).} + \item{x}{The jsdm_stancode object} \item{...}{Currently unused} diff --git a/man/loo.jsdmStanFit.Rd b/man/loo.jsdmStanFit.Rd index 01b0f68..53c024b 100644 --- a/man/loo.jsdmStanFit.Rd +++ b/man/loo.jsdmStanFit.Rd @@ -19,5 +19,6 @@ A list with class \code{c("psis_loo","loo")}, as detailed in the \description{ This function uses the \pkg{loo} package to compute PSIS-LOO CV, efficient approximate leave-one-out (LOO) cross-validation for Bayesian models using Pareto -smoothed importance sampling (PSIS). +smoothed importance sampling (PSIS). This requires that the model was fit using +\code{log_lik = TRUE}. } diff --git a/man/posterior_linpred.jsdmStanFit.Rd b/man/posterior_linpred.jsdmStanFit.Rd index 437112b..bda65e2 100644 --- a/man/posterior_linpred.jsdmStanFit.Rd +++ b/man/posterior_linpred.jsdmStanFit.Rd @@ -11,7 +11,6 @@ newdata = NULL, ndraws = NULL, draw_ids = NULL, - newdata_type = "X", list_index = "draws", ... ) @@ -31,9 +30,6 @@ number of samples.} \item{draw_ids}{The IDs of the draws to be used, as a numeric vector} -\item{newdata_type}{What form is the new data in, at the moment only -supplying covariates is supported.} - \item{list_index}{Whether to return the output list indexed by the number of draws (default), species, or site.} @@ -46,7 +42,8 @@ the list being a site x species matrix. If the list_index is \code{"species"} th list will have length equal to the number of species with each element of the list being a draws x sites matrix. If the list_index is \code{"sites"} the list will have length equal to the number of sites with each element of the -list being a draws x species matrix. +list being a draws x species matrix. Note that in the zero-inflated case this is +only the linear predictor of the non-zero-inflated part of the model. } \description{ Extract the posterior draws of the linear predictor, possibly transformed by diff --git a/man/posterior_predict.jsdmStanFit.Rd b/man/posterior_predict.jsdmStanFit.Rd index ca501f3..2c5ec7f 100644 --- a/man/posterior_predict.jsdmStanFit.Rd +++ b/man/posterior_predict.jsdmStanFit.Rd @@ -8,11 +8,11 @@ \method{posterior_predict}{jsdmStanFit}( object, newdata = NULL, - newdata_type = "X", ndraws = NULL, draw_ids = NULL, list_index = "draws", Ntrials = NULL, + include_zi = TRUE, ... ) } @@ -21,9 +21,6 @@ \item{newdata}{New data, by default \code{NULL} and uses original data} -\item{newdata_type}{What form is the new data in, at the moment only -supplying covariates is supported.} - \item{ndraws}{Number of draws, by default the number of samples in the posterior. Will be sampled randomly from the chains if fewer than the number of samples.} @@ -33,20 +30,23 @@ number of samples.} \item{list_index}{Whether to return the output list indexed by the number of draws (default), species, or site.} -\item{Ntrials}{For the binomial distribution the number of trials, given as -either a single integer which is assumed to be constant across sites or as -a site-length vector of integers.} +\item{Ntrials}{For the binomial distribution the number of trials, given as either +a single integer which is assumed to be constant across sites or as a site-length +vector of integers.} + +\item{include_zi}{For the zero-inflated poisson distribution, whether to include +the zero-inflation in the prediction. Defaults to \code{TRUE}.} \item{...}{Currently unused} } \value{ A list of linear predictors. If list_index is \code{"draws"} (the default) -the list will have length equal to the number of draws with each element of -the list being a site x species matrix. If the list_index is \code{"species"} the -list will have length equal to the number of species with each element of -the list being a draws x sites matrix. If the list_index is \code{"sites"} the -list will have length equal to the number of sites with each element of -the list being a draws x species matrix. +the list will have length equal to the number of draws with each element of the +list being a site x species matrix. If the list_index is \code{"species"} the +list will have length equal to the number of species with each element of the +list being a draws x sites matrix. If the list_index is \code{"sites"} the list +will have length equal to the number of sites with each element of the list being +a draws x species matrix. } \description{ Draw from the posterior predictive distribution of the outcome. diff --git a/man/posterior_zipred.Rd b/man/posterior_zipred.Rd new file mode 100644 index 0000000..a65874a --- /dev/null +++ b/man/posterior_zipred.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_predict.R +\name{posterior_zipred} +\alias{posterior_zipred} +\title{Access the posterior distribution of the linear predictor for zero-inflation +parameter} +\usage{ +posterior_zipred( + object, + transform = FALSE, + newdata = NULL, + ndraws = NULL, + draw_ids = NULL, + list_index = "draws" +) +} +\arguments{ +\item{object}{The model object} + +\item{transform}{Should the linear predictor be transformed using the +inverse-link function. The default is \code{FALSE}, in which case the +untransformed linear predictor is returned.} + +\item{newdata}{New data, by default \code{NULL} and uses original data} + +\item{ndraws}{Number of draws, by default the number of samples in the +posterior. Will be sampled randomly from the chains if fewer than the +number of samples.} + +\item{draw_ids}{The IDs of the draws to be used, as a numeric vector} + +\item{list_index}{Whether to return the output list indexed by the number of +draws (default), species, or site.} +} +\value{ +A list of linear predictors. If list_index is \code{"draws"} (the +default) the list will have length equal to the number of draws with each +element of the list being a site x species matrix. If the list_index is +\code{"species"} the list will have length equal to the number of species +with each element of the list being a draws x sites matrix. If the +list_index is \code{"sites"} the list will have length equal to the number +of sites with each element of the list being a draws x species matrix. Note +that in the zero-inflated case this is only the linear predictor of the +non-zero-inflated part of the model. +} +\description{ +Extract the posterior draws of the linear predictor for the zero-inflation +parameter, possibly transformed by the inverse-link function. +} +\seealso{ +\code{\link[=posterior_predict.jsdmStanFit]{posterior_predict.jsdmStanFit()}} +} diff --git a/man/print.jsdmStanFamily.Rd b/man/print.jsdmStanFamily.Rd new file mode 100644 index 0000000..89259eb --- /dev/null +++ b/man/print.jsdmStanFamily.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/jsdmstan-families.R +\name{print.jsdmStanFamily} +\alias{print.jsdmStanFamily} +\title{Print jsdmStanFamily object} +\usage{ +\method{print}{jsdmStanFamily}(x, ...) +} +\arguments{ +\item{x}{A jsdmStanFamily object} + +\item{...}{Other arguments, not used at this stage.} +} +\description{ +Print jsdmStanFamily object +} diff --git a/man/stan_gllvm.Rd b/man/stan_gllvm.Rd index 37ede1f..9578b01 100644 --- a/man/stan_gllvm.Rd +++ b/man/stan_gllvm.Rd @@ -42,9 +42,10 @@ species, by default \code{TRUE}} Y, X, N, S, K, and site_intercept. See output of \code{\link[=jsdm_sim_data]{jsdm_sim_data()}} for an example of how this can be formatted.} -\item{family}{The response family for the model, required to be one of -\code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression +matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept diff --git a/man/stan_jsdm.Rd b/man/stan_jsdm.Rd index 478f0f1..426bf35 100644 --- a/man/stan_jsdm.Rd +++ b/man/stan_jsdm.Rd @@ -21,6 +21,8 @@ stan_jsdm(X, ...) site_groups = NULL, beta_param = "unstruct", Ntrials = NULL, + zi_param = "constant", + zi_X = NULL, save_data = TRUE, iter = 4000, log_lik = TRUE, @@ -47,9 +49,10 @@ species, by default \code{TRUE}} Y, X, N, S, K, and site_intercept. See output of \code{\link[=jsdm_sim_data]{jsdm_sim_data()}} for an example of how this can be formatted.} -\item{family}{The response family for the model, required to be one of -\code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression +matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept @@ -70,6 +73,15 @@ default \code{"unstruct"}. See details for further information.} either a single integer which is assumed to be constant across sites or as a site-length vector of integers.} +\item{zi_param}{For the zero-inflated families, whether the zero-inflation parameter +is a species-specific constant (default, \code{"constant"}), or varies by +environmental covariates (\code{"covariate"}).} + +\item{zi_X}{If \code{zi = "covariate"}, the matrix of environmental predictors +that the zero-inflation is modelled in response to. If there is not already +an intercept column (identified by all values being equal to one), one will +be added to the front of the matrix.} + \item{save_data}{If the data used to fit the model should be saved in the model object, by default TRUE.} @@ -100,6 +112,13 @@ to be constrained by a correlation matrix between the covariates. With the \code{"unstruct"} parameterisation all covariate effects are assumed to draw from a simple distribution with no correlation structure. Both parameterisations can be modified using the prior object. +Families supported are the Gaussian family, the negative binomial family, +the Poisson family, the binomial family (with number of trials specificied +using the \code{Ntrials} parameter), the Bernoulli family (the special case +of the binomial family where number of trials is equal to one), the +zero-inflated Poisson and the zero-inflated negative binomial. For both +zero-inflated families the zero-inflation is assumed to be a species-specific +constant. } \section{Methods (by class)}{ \itemize{ diff --git a/man/stan_mglmm.Rd b/man/stan_mglmm.Rd index a25bb1a..05f8413 100644 --- a/man/stan_mglmm.Rd +++ b/man/stan_mglmm.Rd @@ -39,9 +39,10 @@ species, by default \code{TRUE}} Y, X, N, S, K, and site_intercept. See output of \code{\link[=jsdm_sim_data]{jsdm_sim_data()}} for an example of how this can be formatted.} -\item{family}{The response family for the model, required to be one of -\code{"gaussian"}, \code{"bernoulli"}, \code{"poisson"}, \code{"binomial"} -or \code{"neg_binomial"}} +\item{family}{is the response family, must be one of \code{"gaussian"}, +\code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"}, +\code{"bernoulli"}, or \code{"zi_poisson"}. Regular expression +matching is supported.} \item{site_intercept}{Whether a site intercept should be included, potential values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept diff --git a/man/update.jsdmStanFit.Rd b/man/update.jsdmStanFit.Rd index c167b5f..eeb8a56 100644 --- a/man/update.jsdmStanFit.Rd +++ b/man/update.jsdmStanFit.Rd @@ -10,6 +10,7 @@ newX = NULL, newD = NULL, newNtrials = NULL, + newZi_X = NULL, save_data = TRUE, ... ) @@ -26,6 +27,11 @@ \item{newNtrials}{New number of trials (binomial model only), by default \code{NULL}} +\item{newZi_X}{New predictor data for the zi parameter in zero-inflated models, +by default \code{NULL}. In cases where the model was originally fit with the +same X and zi_X data and only newX is supplied to update.jsdmStanFit the zi_X +data will also be set to newX.} + \item{save_data}{Whether to save the data in the jsdmStanFit object, by default \code{TRUE}} diff --git a/tests/testthat/test-jsdm_stancode.R b/tests/testthat/test-jsdm_stancode.R new file mode 100644 index 0000000..39feae7 --- /dev/null +++ b/tests/testthat/test-jsdm_stancode.R @@ -0,0 +1,22 @@ +# jsdm_stancode checks +test_that("jsdm_stancode errors appropriately", { + expect_error(jsdm_stancode(method = "mglmm", family = "nothing")) + + expect_error(jsdm_stancode(method = "mglmm", family = "poisson", + prior = list("betas" = "normal(0,1)")), + "Prior must be given as a jsdmprior object") +}) + +jsdm_code <- jsdm_stancode(method = "mglmm", family = "zi_neg_binomial", + log_lik = TRUE, site_intercept = "grouped", + zi_param = "covariate") + +test_that("jsdm_stancode print works", { + expect_output(print(jsdm_code)) +}) + + +# jsdmStanFamily checks +test_that("jsdmStanFamily print works", { + expect_output(print(jsdmStanFamily_empty())) +}) diff --git a/tests/testthat/test-posterior_predict.R b/tests/testthat/test-posterior_predict.R index 6fd3697..8db03c8 100644 --- a/tests/testthat/test-posterior_predict.R +++ b/tests/testthat/test-posterior_predict.R @@ -8,11 +8,6 @@ suppressWarnings(bern_fit <- stan_gllvm( )) test_that("posterior linpred errors appropriately", { - expect_error( - posterior_linpred(bern_fit, newdata_type = "F"), - "Currently only data on covariates is supported." - ) - expect_error(posterior_linpred(bern_fit, list_index = "bored")) @@ -29,10 +24,6 @@ test_that("posterior linpred errors appropriately", { }) test_that("posterior predictive errors appropriately", { - expect_error( - posterior_predict(bern_fit, newdata_type = "F"), - "Currently only data on covariates is supported." - ) expect_error( posterior_predict(bern_fit, draw_ids = c(-5,-1,0)), @@ -128,7 +119,7 @@ bino_sim_data <- gllvm_sim_data(N = 100, S = 9, K = 2, family = "binomial", Ntrials = 20) bino_pred_data <- matrix(rnorm(100 * 2), nrow = 100) colnames(bino_pred_data) <- c("V1", "V2") -suppressWarnings(bino_fit <- stan_mglmm( +suppressWarnings(bino_fit <- stan_gllvm( dat_list = bino_sim_data, family = "binomial", refresh = 0, chains = 2, iter = 500 )) @@ -150,3 +141,60 @@ test_that("posterior_(lin)pred works with gllvm and bino", { expect_false(any(sapply(bino_pred2, function(x) x < 0))) expect_false(any(sapply(bino_pred2, function(x) x > 16))) }) + +set.seed(86738873) +zip_sim_data <- gllvm_sim_data(N = 100, S = 7, K = 2, family = "zi_poisson", + site_intercept = "ungrouped", D = 1) +zip_pred_data <- matrix(rnorm(100 * 2), nrow = 100) +colnames(zip_pred_data) <- c("V1", "V2") +suppressWarnings(zip_fit <- stan_gllvm( + dat_list = zip_sim_data, family = "zi_poisson", + refresh = 0, chains = 2, iter = 500 +)) +test_that("posterior_(lin)pred works with gllvm and zip", { + zip_pred <- posterior_predict(zip_fit, ndraws = 100) + + expect_length(zip_pred, 100) + expect_false(any(sapply(zip_pred, anyNA))) + expect_false(any(sapply(zip_pred, function(x) x < 0))) + + zip_pred2 <- posterior_predict(zip_fit, + newdata = zip_pred_data, + ndraws = 50, list_index = "species" + ) + + expect_length(zip_pred2, 7) + expect_false(any(sapply(zip_pred2, anyNA))) + expect_false(any(sapply(zip_pred2, function(x) x < 0))) +}) + +set.seed(9598098) +zinb_sim_data <- mglmm_sim_data(N = 100, S = 7, K = 2, family = "zi_neg_binomial", + site_intercept = "ungrouped", zi_param = "covariate") +zinb_pred_data <- matrix(rnorm(100 * 2), nrow = 100) +colnames(zinb_pred_data) <- c("V1", "V2") +suppressWarnings(zinb_fit <- stan_mglmm( + dat_list = zinb_sim_data, family = "zi_neg_binomial",zi_param="covariate", + refresh = 0, chains = 2, iter = 500 +)) +test_that("zi_neg_bin print works okay", { + expect_output(print(zinb_fit$family), + "is modelled in response to") +}) + +test_that("posterior_(lin)pred works with gllvm and zinb", { + zinb_pred <- posterior_predict(zinb_fit, ndraws = 100) + + expect_length(zinb_pred, 100) + expect_false(any(sapply(zinb_pred, anyNA))) + expect_false(any(sapply(zinb_pred, function(x) x < 0))) + + zinb_pred2 <- posterior_predict(zinb_fit, + newdata = zinb_pred_data, + ndraws = 50, list_index = "species" + ) + + expect_length(zinb_pred2, 7) + expect_false(any(sapply(zinb_pred2, anyNA))) + expect_false(any(sapply(zinb_pred2, function(x) x < 0))) +}) diff --git a/tests/testthat/test-sim_data_funs.R b/tests/testthat/test-sim_data_funs.R index 10793ac..f038477 100644 --- a/tests/testthat/test-sim_data_funs.R +++ b/tests/testthat/test-sim_data_funs.R @@ -1,3 +1,4 @@ +set.seed(23359) test_that("gllvm_sim_data errors with bad inputs", { expect_error( gllvm_sim_data( @@ -23,6 +24,12 @@ test_that("gllvm_sim_data errors with bad inputs", { "Ntrials must be a positive integer" ) + expect_error( + gllvm_sim_data(N = 200, S = 8, D = 2, family = "zi_poisson", + zi_param = "covariate", zi_k = -2), + "zi_k must be either NULL or a positive integer" + ) + }) test_that("gllvm_sim_data returns a list of correct length", { @@ -87,6 +94,12 @@ test_that("mglmm_sim_data returns a list of correct length", { "Y", "pars", "N", "S", "D", "K", "X", "Ntrials" )) expect_length(gllvm_sim$Ntrials, 100) + gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "zi_neg_binomial", method = "gllvm", + zi_param = "covariate") + expect_named(gllvm_sim, c( + "Y", "pars", "N", "S", "D", "K", "X", "zi_k", "zi_X" + )) + expect_equal(dim(gllvm_sim$Y),c(100,12)) }) test_that("jsdm_sim_data returns all appropriate pars", { @@ -103,7 +116,21 @@ test_that("jsdm_sim_data returns all appropriate pars", { "betas","a_bar","sigma_a","a","L","LV","sigma_L","kappa" )) + gllvm_sim2 <- jsdm_sim_data(100,12,D=2,family = "zi_poisson", method = "gllvm", + beta_param = "unstruct", + site_intercept = "ungrouped") + expect_named(gllvm_sim2$pars, c( + "betas","a_bar","sigma_a","a","L","LV","sigma_L","zi" + )) + + gllvm_sim3 <- jsdm_sim_data(100,9,K=2,D=2,family = "zi_neg_bin", method = "gllvm", + beta_param = "unstruct", + site_intercept = "ungrouped", zi_param = "covariate", + zi_k = 1) + expect_named(gllvm_sim3$pars, c( + "betas","a_bar","sigma_a","a","L","LV","sigma_L","zi_betas","kappa" + )) }) diff --git a/tests/testthat/test-stan_jsdm.R b/tests/testthat/test-stan_jsdm.R index 6dfab1f..f19764f 100644 --- a/tests/testthat/test-stan_jsdm.R +++ b/tests/testthat/test-stan_jsdm.R @@ -299,3 +299,23 @@ test_that("site intercept models run", { expect_s3_class(mglmm_fit, "jsdmStanFit") }) + + +set.seed(9598098) +zip_sim_data <- gllvm_sim_data(N = 100, S = 7, D = 2, K = 2, family = "zi_poisson", + zi_param = "covariate") + +test_that("zi_poisson works okay", { + suppressWarnings(zip_fit <- stan_gllvm( + dat_list = zip_sim_data, family = "zi_poisson",zi_param="covariate", + refresh = 0, chains = 2, iter = 200 + )) + expect_s3_class(zip_fit, "jsdmStanFit") + expect_s3_class(zip_fit$family, "jsdmStanFamily") + expect_output(print(zip_fit$family), + "is modelled in response to") + expect_named(zip_fit$family, + c("family" ,"params" ,"params_dataresp","preds","data_list")) + expect_named(zip_fit$family$data_list, "zi_X") +}) + diff --git a/tests/testthat/test-update.R b/tests/testthat/test-update.R index 1bee5e1..419b689 100644 --- a/tests/testthat/test-update.R +++ b/tests/testthat/test-update.R @@ -62,3 +62,27 @@ test_that("binomial models update", { )) expect_s3_class(gllvm_fit2, "jsdmStanFit") }) + +zip_data <- gllvm_sim_data(97,9,D = 2, family = "zi_poisson", + zi_param = "covariate") + +test_that("zi models update", { + suppressWarnings(gllvm_fit <- stan_gllvm(dat_list = zip_data, + family = "zi_poisson", + refresh = 0, chains = 1, iter = 200 + )) + expect_s3_class(gllvm_fit, "jsdmStanFit") + + expect_error(gllvm_fit2 <- update(gllvm_fit, newD = 3, + newZi_X = matrix(1:74, nrow = 74), + refresh = 0, chains = 1, iter = 200 + ), "Number of rows of zi_X") + + + suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 3, + newZi_X = matrix(rnorm(97), nrow = 97), + refresh = 0, chains = 1, iter = 200 + )) + expect_equal(ncol(gllvm_fit2$data_list$zi_X),2) + expect_s3_class(gllvm_fit2, "jsdmStanFit") +})