Skip to content

Commit

Permalink
Merge pull request #72 from NIEHS/main-sciome
Browse files Browse the repository at this point in the history
Faster imputation
  • Loading branch information
ericbair-sciome authored Aug 24, 2024
2 parents b814826 + 2ce73ce commit f9712d6
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 42 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: PrestoGP
Type: Package
Title: Penalized Regression for Spatio-Temporal Outcomes via Gaussian Processes
Version: 0.2.0.9033
Version: 0.2.0.9034
Authors@R: c(
person(given = "Eric",
family = "Bair",
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import(MASS)
import(Matrix)
import(fields)
import(glmnet)
importFrom(RANN,nn2)
importFrom(aod,wald.test)
importFrom(dplyr,"%>%")
importFrom(foreach,"%dopar%")
Expand All @@ -39,6 +40,7 @@ importFrom(rlang,enquos)
importFrom(stats,coef)
importFrom(stats,optim)
importFrom(stats,predict)
importFrom(stats,rnorm)
importFrom(stats,var)
importFrom(tmvtnorm,rtmvnorm)
importFrom(truncnorm,etruncnorm)
Expand Down
3 changes: 2 additions & 1 deletion R/PrestoGP-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
#' @importFrom methods validObject
#' @importFrom stats coef
#' @importFrom utils getFromNamespace
#' @importFrom stats optim predict var
#' @importFrom stats optim predict var rnorm
#' @importFrom aod wald.test
#' @importFrom dplyr %>%
#' @importFrom truncnorm rtruncnorm
#' @importFrom truncnorm etruncnorm
#' @importFrom tmvtnorm rtmvnorm
#' @importFrom RANN nn2
## usethis namespace: end
NULL
2 changes: 1 addition & 1 deletion R/PrestoGP_CreateU_Multivariate.R
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ knn_indices <- function(ordered_locs, query, n_neighbors, dist_func, dist_func_c
"distances" = dists[nearest_neighbors]
))
} else {
cur.nn <- RANN::nn2(ordered_locs, query, n_neighbors)
cur.nn <- nn2(ordered_locs, query, n_neighbors)
return(list("indices" = cur.nn$nn.idx, "distances" = cur.nn$nn.dists))
}
}
Expand Down
182 changes: 162 additions & 20 deletions R/PrestoGP_Multivariate_Vecchia.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ setMethod("impute_y_lod", "MultivariateVecchiaModel", function(model, lod,
vecchia.approx <- model@vecchia_approx
params <- model@covparams
param.seq <- model@param_sequence
P <- length(model@locs_train)
if (!model@apanasovich) {
olocs.scaled <- vecchia.approx$locsord
for (i in 1:vecchia.approx$P) {
Expand All @@ -343,17 +344,53 @@ setMethod("impute_y_lod", "MultivariateVecchiaModel", function(model, lod,
vecchia.approx$locsord <- olocs.scaled
params <- c(params[1:param.seq[1, 2]], rep(1, vecchia.approx$P),
params[param.seq[3, 1]:param.seq[5, 2]])
param.seq <- create.param.sequence(vecchia.approx$P)
}

U.obj <- createUMultivariate(vecchia.approx, params)
Sigma.hat <- solve(U.obj$U %*% t(U.obj$U))
Sigma.hat <- Sigma.hat[U.obj$latent, U.obj$latent]
Sigma.hat <- Sigma.hat[order(U.obj$ord), order(U.obj$ord)]
locs.scaled <- scale_locs(model, model@locs_train)
locs.all <- NULL
nugget.seq <- NULL
for (i in 1:vecchia.approx$P) {
ndx <- list()
for (i in seq_len(P)) {
locs.all <- rbind(locs.all, locs.scaled[[i]])
nugget.seq <- c(nugget.seq, rep(params[param.seq[4, 1] + i - 1],
nrow(model@locs_train[[i]])))
ndx[[i]] <- seq_len(nrow(locs.scaled[[i]]))
if (i > 1) {
last.ndx <- ndx[[i - 1]]
ndx[[i]] <- ndx[[i]] + last.ndx[length(last.ndx)]
}
}
locs.nn <- nn2(locs.all, k = model@n_neighbors)$nn.idx

sigma <- params[param.seq[1, 1]:param.seq[1, 2]]
rangep <- params[param.seq[2, 1]:param.seq[2, 2]]
smoothness <- params[param.seq[3, 1]:param.seq[3, 2]]

rho <- params[param.seq[5, 1]:param.seq[5, 2]]
rho.mat <- matrix(0, nrow = P, ncol = P)
rho.mat[upper.tri(rho.mat, diag = FALSE)] <- rho
rho.mat <- rho.mat + t(rho.mat)
diag(rho.mat) <- 1

Sigma.hat <- matrix(nrow = nrow(locs.all), ncol = nrow(locs.all))
for (i in seq_len(P)) {
for (j in i:P) {
smooth.ii <- smoothness[i]
smooth.jj <- smoothness[j]
smooth.ij <- (smooth.ii + smooth.jj) / 2
alpha.ii <- 1 / rangep[i]
alpha.jj <- 1 / rangep[j]
alpha.ij <- sqrt((alpha.ii^2 + alpha.jj^2) / 2)
Sigma.hat[ndx[[i]], ndx[[j]]] <- rho.mat[i, j] * sqrt(sigma[i]) *
sqrt(sigma[j]) * alpha.ii^smooth.ii * alpha.jj^smooth.jj *
gamma(smooth.ij) / (alpha.ij^(2 * smooth.ij) *
sqrt(gamma(smooth.ii) * gamma(smooth.jj))) *
Matern(rdist(locs.scaled[[i]], locs.scaled[[j]]),
alpha = alpha.ij, smoothness = smooth.ij)
if (i != j) {
Sigma.hat[ndx[[j]], ndx[[i]]] <- t(Sigma.hat[ndx[[i]], ndx[[j]]])
}
}
}
Sigma.hat <- Sigma.hat + nugget.seq * diag(nrow = nrow(Sigma.hat))

Expand All @@ -366,13 +403,65 @@ setMethod("impute_y_lod", "MultivariateVecchiaModel", function(model, lod,
yhat.ni <- X %*% cur.coef[-1]
yhat.ni <- yhat.ni + mean(y[!miss]) - mean(yhat.ni[!miss])

mu.miss <- as.vector(yhat.ni[miss] + Sigma.hat[miss, !miss] %*%
solve(Sigma.hat[!miss, !miss], y[!miss] - yhat.ni[!miss]))
Sigma.miss <- Sigma.hat[miss, miss] - Sigma.hat[miss, !miss] %*%
solve(Sigma.hat[!miss, !miss], Sigma.hat[!miss, miss])

y.na.mat <- rtmvnorm(n.mi, mean = mu.miss, sigma = Sigma.miss,
upper = lod[miss], algorithm = "gibbs")
if (parallel) {
y.na.mat <- foreach(i = which(miss), .combine = cbind) %dopar% {
cur.nn <- locs.nn[i, ]
cur.miss <- miss[cur.nn]
cur.y <- y[cur.nn]
cur.yhat <- yhat.ni[cur.nn]
cur.Sigma <- Sigma.hat[cur.nn, cur.nn]
if (sum(cur.miss) == length(cur.miss)) {
mu.miss <- cur.yhat
Sigma.miss <- cur.Sigma
} else {
mu.miss <- as.vector(cur.yhat[cur.miss] +
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.y[!cur.miss] - cur.yhat[!cur.miss]))
Sigma.miss <- cur.Sigma[cur.miss, cur.miss] -
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.Sigma[!cur.miss, cur.miss, drop = FALSE])
}
if (length(mu.miss) > 1) {
cur.out <- rtmvnorm(10, mean = mu.miss, sigma = Sigma.miss,
upper = lod[cur.nn][cur.miss], algorithm = "gibbs")[, 1]
} else {
cur.out <- rnorm(10, mu.miss, as.numeric(Sigma.miss))
}
cur.out
}
} else {
y.na.mat <- matrix(nrow = 10, ncol = sum(miss))

for (i in seq_len(ncol(y.na.mat))) {
cur.nn <- locs.nn[which(miss)[i], ]
cur.miss <- miss[cur.nn]
cur.y <- y[cur.nn]
cur.yhat <- yhat.ni[cur.nn]
cur.Sigma <- Sigma.hat[cur.nn, cur.nn]
if (sum(cur.miss) == length(cur.miss)) {
mu.miss <- cur.yhat
Sigma.miss <- cur.Sigma
} else {
mu.miss <- as.vector(cur.yhat[cur.miss] +
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.y[!cur.miss] - cur.yhat[!cur.miss]))
Sigma.miss <- cur.Sigma[cur.miss, cur.miss] -
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.Sigma[!cur.miss, cur.miss, drop = FALSE])
}
if (length(mu.miss) > 1) {
cur.out <- rtmvnorm(10, mean = mu.miss, sigma = Sigma.miss,
upper = lod[cur.nn][cur.miss], algorithm = "gibbs")[, 1]
} else {
cur.out <- rnorm(10, mu.miss, as.numeric(Sigma.miss))
}
y.na.mat[, i] <- cur.out
}
}

coef.mat <- matrix(nrow = n.mi, ncol = (ncol(X) + 1))
for (i in 1:n.mi) {
Expand All @@ -384,21 +473,74 @@ setMethod("impute_y_lod", "MultivariateVecchiaModel", function(model, lod,
cur.glmnet <- cv.glmnet(as.matrix(Xt), as.matrix(yt),
alpha = model@alpha, family = family, nfolds = nfolds, foldid = foldid,
parallel = parallel)
coef.mat[i, ] <- as.matrix(coef(cur.glmnet, cur.glmnet$lambda.min))
coef.mat[i, ] <- as.matrix(coef(cur.glmnet, s = "lambda.min"))
}
last.coef <- cur.coef
cur.coef <- colMeans(coef.mat)
}
yhat.ni <- X %*% cur.coef[-1]
yhat.ni <- yhat.ni + mean(y[!miss]) - mean(yhat.ni[!miss])

mu.miss <- as.vector(yhat.ni[miss] + Sigma.hat[miss, !miss] %*%
solve(Sigma.hat[!miss, !miss], y[!miss] - yhat.ni[!miss]))
Sigma.miss <- Sigma.hat[miss, miss] - Sigma.hat[miss, !miss] %*%
solve(Sigma.hat[!miss, !miss], Sigma.hat[!miss, miss])
if (parallel) {
y.na.mat <- foreach(i = which(miss), .combine = cbind) %dopar% {
cur.nn <- locs.nn[i, ]
cur.miss <- miss[cur.nn]
cur.y <- y[cur.nn]
cur.yhat <- yhat.ni[cur.nn]
cur.Sigma <- Sigma.hat[cur.nn, cur.nn]
if (sum(cur.miss) == length(cur.miss)) {
mu.miss <- cur.yhat
Sigma.miss <- cur.Sigma
} else {
mu.miss <- as.vector(cur.yhat[cur.miss] +
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.y[!cur.miss] - cur.yhat[!cur.miss]))
Sigma.miss <- cur.Sigma[cur.miss, cur.miss] -
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.Sigma[!cur.miss, cur.miss, drop = FALSE])
}
if (length(mu.miss) > 1) {
cur.out <- rtmvnorm(100, mean = mu.miss, sigma = Sigma.miss,
upper = lod[cur.nn][cur.miss], algorithm = "gibbs")[, 1]
} else {
cur.out <- rnorm(100, mu.miss, as.numeric(Sigma.miss))
}
cur.out
}
} else {
y.na.mat <- matrix(nrow = 100, ncol = sum(miss))

for (i in seq_len(ncol(y.na.mat))) {
cur.nn <- locs.nn[which(miss)[i], ]
cur.miss <- miss[cur.nn]
cur.y <- y[cur.nn]
cur.yhat <- yhat.ni[cur.nn]
cur.Sigma <- Sigma.hat[cur.nn, cur.nn]
if (sum(cur.miss) == length(cur.miss)) {
mu.miss <- cur.yhat
Sigma.miss <- cur.Sigma
} else {
mu.miss <- as.vector(cur.yhat[cur.miss] +
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.y[!cur.miss] - cur.yhat[!cur.miss]))
Sigma.miss <- cur.Sigma[cur.miss, cur.miss] -
cur.Sigma[cur.miss, !cur.miss, drop = FALSE] %*%
solve(cur.Sigma[!cur.miss, !cur.miss],
cur.Sigma[!cur.miss, cur.miss, drop = FALSE])
}
if (length(mu.miss) > 1) {
cur.out <- rtmvnorm(100, mean = mu.miss, sigma = Sigma.miss,
upper = lod[cur.nn][cur.miss], algorithm = "gibbs")[, 1]
} else {
cur.out <- rnorm(100, mu.miss, as.numeric(Sigma.miss))
}
y.na.mat[, i] <- cur.out
}
}

y.na.mat <- rtmvnorm(100, mean = mu.miss, sigma = Sigma.miss,
upper = lod[miss], algorithm = "gibbs")
model@Y_train[miss] <- colMeans(y.na.mat)
invisible(model)
})
Expand Down
Loading

0 comments on commit f9712d6

Please sign in to comment.