Skip to content

Commit

Permalink
Merge pull request #5 from NERC-CEH/zip
Browse files Browse the repository at this point in the history
Adding zero-inflated families
  • Loading branch information
fseaton authored Aug 21, 2024
2 parents 6a829b7 + cabf9ee commit 2c5d9dd
Show file tree
Hide file tree
Showing 32 changed files with 1,108 additions and 200 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: jsdmstan
Title: Fitting jSDMs in Stan
Version: 0.3.0
Version: 0.3.0.9000
Authors@R:
person("Fiona", "Seaton", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-2022-7451"))
Expand All @@ -12,7 +12,7 @@ License: GPL (>= 3)
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Biarch: true
Depends:
R (>= 3.4.0)
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(plot,jsdmStanFit)
S3method(posterior_linpred,jsdmStanFit)
S3method(posterior_predict,jsdmStanFit)
S3method(pp_check,jsdmStanFit)
S3method(print,jsdmStanFamily)
S3method(print,jsdmStanFit)
S3method(print,jsdmprior)
S3method(print,jsdmstan_model)
Expand Down Expand Up @@ -40,6 +41,7 @@ export(nuts_params)
export(ordiplot)
export(posterior_linpred)
export(posterior_predict)
export(posterior_zipred)
export(pp_check)
export(rgampois)
export(rhat)
Expand Down
183 changes: 160 additions & 23 deletions R/jsdm_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
#' can be modified using the prior object.
#'
#' @param method The method, one of \code{"gllvm"} or \code{"mglmm"}
#' @param family The family, one of "\code{"gaussian"}, \code{"bernoulli"},
#' \code{"poisson"} or \code{"neg_binomial"}
#' @param family is the response family, must be one of \code{"gaussian"},
#' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"},
#' \code{"bernoulli"}, \code{"zi_poisson"}, or
#' \code{"zi_neg_binomial"}. Regular expression
#' matching is supported.
#' @param prior The prior, given as the result of a call to [jsdm_prior()]
#' @param log_lik Whether the log likelihood should be calculated in the generated
#' quantities (by default \code{TRUE}), required for loo
Expand All @@ -24,6 +27,9 @@
#' grouping)
#' @param beta_param The parameterisation of the environmental covariate effects, by
#' default \code{"cor"}. See details for further information.
#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter
#' is a species-specific constant (default, \code{"constant"}), or varies by
#' environmental covariates (\code{"covariate"}).
#'
#' @return A character vector of Stan code, class "jsdmstan_model"
#' @export
Expand All @@ -34,13 +40,15 @@
#'
jsdm_stancode <- function(method, family, prior = jsdm_prior(),
log_lik = TRUE, site_intercept = "none",
beta_param = "cor") {
beta_param = "cor", zi_param = "constant") {
# checks
family <- match.arg(family, c("gaussian", "bernoulli", "poisson",
"neg_binomial","binomial"))
"neg_binomial","binomial","zi_poisson",
"zi_neg_binomial"))
method <- match.arg(method, c("gllvm", "mglmm"))
beta_param <- match.arg(beta_param, c("cor","unstruct"))
site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped"))
zi_param <- match.arg(zi_param, c("constant","covariate"))
if (class(prior)[1] != "jsdmprior") {
stop("Prior must be given as a jsdmprior object")
}
Expand All @@ -49,15 +57,15 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
scode <- .modelcode(
method = method, family = family,
phylo = FALSE, prior = prior, log_lik = log_lik, site_intercept = site_intercept,
beta_param = beta_param
beta_param = beta_param, zi_param = zi_param
)
class(scode) <- c("jsdmstan_model", "character")
return(scode)
}


.modelcode <- function(method, family, phylo, prior, log_lik, site_intercept,
beta_param) {
beta_param, zi_param) {
model_functions <- "
"
data <- paste(
Expand All @@ -81,12 +89,29 @@ ifelse(site_intercept == "grouped",
"bernoulli" = "int<lower=0,upper=1>",
"neg_binomial" = "int<lower=0>",
"poisson" = "int<lower=0>",
"zi_poisson" = "int<lower=0>",
"zi_neg_binomial" = "int<lower=0>",
"binomial" = "int<lower=0>"
), "Y[N,S]; //Species matrix",
ifelse(family == "binomial",
"
int<lower=0> Ntrials[N]; // Number of trials","")
)
int<lower=0> Ntrials[N]; // Number of trials",""),
ifelse(grepl("zi_", family),"
int<lower=0> N_zero[S]; // number of zeros per species
int<lower=0> N_nonzero[S]; //number of nonzeros per species
int<lower=0> Sum_nonzero; //Total number of nonzeros across all species
int<lower=0> Sum_zero; //Total number of zeros across all species
int<lower=0> Y_nz[Sum_nonzero]; //Y values for nonzeros
int<lower=0> ss[Sum_nonzero]; //species index for Y_nz
int<lower=0> nn[Sum_nonzero]; //site index for Y_nz
int<lower=0> sz[Sum_zero]; //species index for Y_z
int<lower=0> nz[Sum_zero]; //site index for Y_z",""),
ifelse(grepl("zi_", family) & zi_param == "covariate","
int<lower=1> zi_k; //number of covariates for env effects on zi
matrix[N, zi_k] zi_X; //environmental covariate matrix for zi","")
)


transformed_data <- ifelse(method == "gllvm", "
// Ensures identifiability of the model - no rotation of factors
int<lower=1> M;
Expand Down Expand Up @@ -131,7 +156,19 @@ ifelse(site_intercept == "grouped",
"bernoulli" = "",
"neg_binomial" = "
real<lower=0> kappa[S]; // neg_binomial parameters",
"poisson" = ""
"poisson" = "",
"zi_poisson" = switch(zi_param,
"constant" = "
real<lower=0,upper=1> zi[S]; // zero-inflation parameter",
"covariate" = "
matrix[zi_k,S] zi_betas; //environmental effects for zi"),
"zi_neg_binomial" = switch(zi_param,
"constant" = "
real<lower=0> kappa[S]; // neg_binomial parameters
real<lower=0,upper=1> zi[S]; // zero-inflation parameter",
"covariate" = "
real<lower=0> kappa[S]; // neg_binomial parameters
matrix[zi_k,S] zi_betas; //environmental effects for zi")
)

pars <- paste(
Expand Down Expand Up @@ -215,10 +252,32 @@ ifelse(site_intercept == "grouped",
")
model <- paste("
matrix[N,S] mu;
", switch(method,
", ifelse(grepl("zi_",family),paste0("
real mu_nz[Sum_nonzero];
real mu_z[Sum_zero];
int pos;
int neg;",switch(zi_param,"constant" = "",
"covariate" = "
real zi_nz[Sum_nonzero];
real zi_z[Sum_zero];")),""),
switch(method,
"gllvm" = gllvm_model,
"mglmm" = mglmm_model
))
),ifelse(grepl("zi_",family),paste0(ifelse(zi_param == "covariate", "
matrix[N,S] zi = zi_X * zi_betas;",""),"
for(i in 1:Sum_nonzero){
mu_nz[i] = mu[nn[i],ss[i]];",
switch(zi_param, "constant" = "",
"covariate" = "
zi_nz[i] = zi[nn[i],ss[i]];"),"
}
for(i in 1:Sum_zero){
mu_z[i] = mu[nz[i],sz[i]];",
switch(zi_param, "constant" = "",
"covariate" = "
zi_z[i] = zi[nz[i],sz[i]];"),"
}
"),""))
model_priors <- paste(
ifelse(site_intercept %in% c("ungrouped","grouped"), paste("
// Site-level intercept priors
Expand Down Expand Up @@ -263,10 +322,26 @@ ifelse(site_intercept == "grouped",
"),
"bern" = "",
"poisson" = "",
"binomial" = ""
"binomial" = "",
"zi_poisson" = switch(zi_param,"constant" = paste("
//zero-inflation parameter
zi ~ ", prior[["zi"]], ";
"), "covariate" = paste("
//zero-inflation parameter
to_vector(zi_betas) ~ ", prior[["zi_betas"]], ";
")),
"zi_neg_binomial" = switch(zi_param, "constant" = paste("
//zero-inflation parameter
zi ~ ", prior[["zi"]], ";
kappa ~ ", prior[["kappa"]], ";
"), "covariate" = paste("
//zero-inflation parameter
to_vector(zi_betas) ~ ", prior[["zi_betas"]], ";
kappa ~ ", prior[["kappa"]], ";
")
)
)
model_pt2 <- paste(
))
model_pt2 <- if(!grepl("zi_", family)){ paste(
"
for(i in 1:N) Y[i,] ~ ",
switch(family,
Expand All @@ -276,7 +351,36 @@ ifelse(site_intercept == "grouped",
"poisson" = "poisson_log(mu[i,]);",
"binomial" = "binomial_logit(Ntrials[i], mu[i,]);"
)
)
)} else{paste("
pos = 1;
neg = 1;
for(s in 1:S){
target
+= N_zero[s]
* log_sum_exp(",
switch(zi_param,"constant" = "log(zi[s]),
log1m(zi[s])
+",
"covariate" = "bernoulli_logit_lpmf(1 | segment(zi_z, neg, N_zero[s])),
bernoulli_logit_lpmf(0 | segment(zi_z, neg, N_zero[s]))
+"),
switch(family,
"zi_poisson" = "poisson_log_lpmf(0 | segment(mu_z, neg, N_zero[s])));",
"zi_neg_binomial" = "neg_binomial_2_log_lpmf(0 | segment(mu_z, neg, N_zero[s]), kappa[s]));"),"
target += N_nonzero[s] * ",switch(zi_param,
"constant" = "log1m(zi[s]);",
"covariate" = "bernoulli_logit_lpmf(0 | segment(zi_nz, pos, N_nonzero[s]));"),"
target +=",
switch(family,
"zi_poisson" = "poisson_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) |
segment(mu_nz, pos, N_nonzero[s]));",
"zi_neg_binomial" = "neg_binomial_2_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) |
segment(mu_nz, pos, N_nonzero[s]), kappa[s]);"),"
pos = pos + N_nonzero[s];
neg = neg + N_zero[s];
}
")
}

generated_quantities <- paste(
ifelse(isTRUE(log_lik), "
Expand All @@ -298,7 +402,9 @@ ifelse(site_intercept == "grouped",
}", ""), ifelse(isTRUE(log_lik), paste(
"
{
matrix[N, S] linpred;", switch(site_intercept, "ungrouped" = paste("
matrix[N, S] linpred;",ifelse(grepl("zi", family) & zi_param == "covariate","
matrix[N,S] zi = zi_X * zi_betas;",""),
switch(site_intercept, "ungrouped" = paste("
linpred = rep_matrix(a_bar + a * sigma_a, S) + (X * betas) +",
switch(method,
"gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'",
Expand All @@ -322,14 +428,45 @@ ifelse(site_intercept == "grouped",
), ";
")),"
for(i in 1:N) {
for(j in 1:S) {
log_lik[i, j] = ",
for(j in 1:S) {",
switch(family,
"gaussian" = "normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);",
"bernoulli" = "bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);",
"neg_binomial" = "neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);",
"poisson" = "poisson_log_lpmf(Y[i, j] | linpred[i, j]);",
"binomial" = "binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);"
"gaussian" = "log_lik[i, j] = normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);",
"bernoulli" = "log_lik[i, j] = bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);",
"neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);",
"poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);",
"binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);",
"zi_poisson" = switch(zi_param,"constant" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]),
bernoulli_lpmf(0 |zi[j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]));
} else {
log_lik[i, j] = bernoulli_lpmf(0 | zi[j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]);
}",
"covariate" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]),
bernoulli_logit_lpmf(0 |zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]));
} else {
log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]);
}"),
"zi_neg_binomial" = switch(zi_param,"constant" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]),
bernoulli_lpmf(0 |zi[j])
+ neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]));
} else {
log_lik[i, j] = bernoulli_lpmf(0 | zi[j])
+ neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]);
}",
"covariate" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]),
bernoulli_logit_lpmf(0 |zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]));
} else {
log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]);
}")
),"
}
}
Expand Down
66 changes: 66 additions & 0 deletions R/jsdmstan-families.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#' jsdmStanFamily class
#'
#' This is the jsdmStanFamily class, which occupies a slot within any
#' jsdmStanFit object.
#'
#' @name jsdmStanFamily
#'
#' @section Elements for \code{jsdmStanFamily} objects:
#' \describe{
#' \item{\code{family}}{
#' A length one character vector describing family used to fit object. Options
#' are \code{"gaussian"}, \code{"poisson"}, \code{"bernoulli"},
#' \code{"neg_binomial"}, \code{"binomial"}, \code{"zi_poisson"},
#' \code{"zi_neg_binomial"}, or \code{"multiple"}.
#' }
#' \item{\code{params}}{
#' A character vector that includes all the names of the family-specific parameters.
#' }
#' \item{\code{params_dataresp}}{
#' A character vector that includes any named family-specific parameters that are
#' modelled in response to data.
#' }
#' \item{\code{preds}}{
#' A character vector of the measured predictors included if family parameters
#' are modelled in response to data. If family parameters are not modelled in
#' response to data this is left empty.
#' }
#' \item{\code{data_list}}{
#' A list containing the original data used to fit the model
#' (empty when save_data is set to \code{FALSE} or family parameters are not
#' modelled in response to data).
#' }
#' }
#'
jsdmStanFamily_empty <- function(){
res <- list(family = character(),
params = character(),
params_dataresp= character(),
preds = character(),
data_list = list())
class(res) <- "jsdmStanFamily"
return(res)
}

# jsdmStanFamily methods

#' Print jsdmStanFamily object
#'
#' @param x A jsdmStanFamily object
#' @param ... Other arguments, not used at this stage.
#'
#' @export
print.jsdmStanFamily <- function(x, ...){
cat(paste("Family:", x$family, "\n",
ifelse(length(x$params)>0,
paste("With parameters:",
paste0(x$params, sep = ", "),"\n"),
"")))
if(length(x$params_dataresp)>0){
cat(paste("Family-specific parameter",
paste0(x$params_dataresp,sep=", "),
"is modelled in response to", length(x$preds),
"predictors. These are named:",
paste0(x$preds, sep = ", ")))
}
}
Loading

0 comments on commit 2c5d9dd

Please sign in to comment.