-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create vignette and add wrapper function
- Loading branch information
1 parent
0922ec1
commit f593eeb
Showing
18 changed files
with
953 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
^.*\.Rproj$ | ||
^\.Rproj\.user$ | ||
^somegraphs\.R$ | ||
^\.github$ | ||
^old$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#' @title Get prediction sets
#' @description Let K be the total number of distinct cell type labels and n, m
#' the number of cells in the calibration and in the test (query) data,
#' respectively. This function takes as input two sets of estimated
#' probabilities: an \code{n x K} matrix for the calibration cells and an
#' \code{m x K} matrix for the test cells (or objects carrying those
#' probabilities in their colData).
#' It returns a list with the prediction sets for each cell in the test data.
#'
#' @param x.query query data for which we want to build prediction sets. Could be either a
#' SingleCellExperiment object with the estimated probabilities for each cell type
#' in the colData, or a named matrix of dimension \code{m x K}, where \code{m} is the number
#' of query cells and \code{K} is the number of different labels. The colnames of the matrix
#' have to correspond to the cell labels.
#' @param x.cal calibration data. Could be either a
#' SingleCellExperiment object with the estimated probabilities for each cell type
#' in the colData, or a named matrix of dimension \code{n x K}, where \code{n} is the number
#' of calibration cells and \code{K} is the number of different labels. The colnames of the matrix
#' have to correspond to the cell labels.
#' @param y.cal a vector of length n with the labels of the cells in the calibration data
#' @param onto the considered section of the cell ontology as an igraph object.
#' Required when \code{follow_ontology = TRUE}. When \code{NULL}, the labels
#' are taken from the column names of whichever of \code{x.cal} / \code{x.query}
#' is supplied as a matrix.
#' @param alpha a number between 0 and 1 that indicates the allowed miscoverage
#' @param lambdas a vector of possible lambda values to be considered
#' @param follow_ontology If \code{TRUE}, then the function returns hierarchical
#' prediction sets that follow the cell ontology structure. If \code{FALSE}, it
#' returns classical conformal prediction sets.
#' @author Daniela Corbetta
#' @return The function \code{getPredictionSets} returns a list of length equal
#' to the number of cells in the test data.
#' Each element of the list contains the prediction set for that cell.
#' @references For an introduction to conformal prediction, see
#' Angelopoulos, Anastasios N., and Stephen Bates. "A gentle introduction to
#' conformal prediction and distribution-free uncertainty quantification." arXiv preprint arXiv:2107.07511 (2021).
#' For reference on conformal risk control, see
#' Angelopoulos, Anastasios N., et al. "Conformal risk control." arXiv preprint arXiv:2208.02814 (2022).
#' @importFrom foreach %dopar%
#' @importFrom foreach foreach
#' @importFrom SummarizedExperiment colData
#' @export

getPredictionSets <- function(x.query, x.cal, y.cal, onto = NULL, alpha = 0.1,
                              lambdas = seq(0.001, 0.999, length.out = 100),
                              follow_ontology = TRUE) {
    # The hierarchical procedure cannot run without the ontology graph
    if (follow_ontology && is.null(onto)) {
        stop("'onto' must be provided when follow_ontology = TRUE",
             call. = FALSE)
    }

    # Retrieve the labels: leaves of the ontology (vertices with no outgoing
    # edge) when it is available, otherwise the column names of a matrix input
    labels <- NULL
    if (!is.null(onto)) {
        labels <- igraph::V(onto)$name[igraph::degree(onto, mode = "out") == 0]
    } else if (is.matrix(x.cal)) {
        labels <- colnames(x.cal)
    } else if (is.matrix(x.query)) {
        labels <- colnames(x.query)
    }

    # Labels are only needed to pull probability columns out of the colData
    needs.labels <- !is.matrix(x.query) || !is.matrix(x.cal)
    if (needs.labels && is.null(labels)) {
        stop("cannot determine the cell labels: provide 'onto' or pass a ",
             "matrix with column names", call. = FALSE)
    }

    p.query <- .getProbMatrix(x.query, labels)
    p.cal <- .getProbMatrix(x.cal, labels)

    if (follow_ontology) {
        pred.sets <- .getHierarchicalPredSets(p.cal = p.cal, p.test = p.query,
                                              y.cal = y.cal, onto = onto,
                                              alpha = alpha,
                                              lambdas = lambdas)$sets.test
    } else {
        pred.sets <- .getConformalPredSets(p.cal = p.cal, p.test = p.query,
                                           y.cal = y.cal, alpha = alpha)
    }
    return(pred.sets)
}

# Internal helper: return the probability matrix for `x`.
# A matrix input is returned untouched. Any other input is treated as an
# object with one cell per column: for every entry of `labels`, the
# corresponding colData column is copied into a (cells x labels) matrix
# whose column names are `labels`.
.getProbMatrix <- function(x, labels) {
    if (is.matrix(x)) {
        return(x)
    }
    p <- matrix(NA, nrow = ncol(x), ncol = length(labels))
    colnames(p) <- labels
    for (i in labels) {
        p[, i] <- colData(x)[[i]]
    }
    p
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
#### Function to get the conformal quantile
# Computes the calibration scores (1 - estimated probability of the true
# label) and their finite-sample-adjusted (1 - alpha) quantile.
#
# p      prediction matrix for data in the calibration set (ncal x num labels)
# label  true label for data in the calibration set; each entry is used as a
#        column index of p (column name or position)
# alpha  desired error rate
#
# Returns a list with
#   scores  the calibration scores, one per row of p
#   qhat    the adjusted quantile of the scores, as returned by quantile()
getConfQuant <- function(p, label, alpha) {
    n <- dim(p)[1]
    if (n == 0) {
        stop("'p' must contain at least one calibration observation",
             call. = FALSE)
    }

    # probability assigned to the true label of each calibration observation
    true <- vapply(seq_len(n), function(i) p[i, label[i]], numeric(1))
    cal.scores <- 1 - true

    # get adjusted quantile level; for small n the adjusted level
    # ceiling((n + 1) * (1 - alpha)) / n exceeds 1, which quantile() rejects,
    # so it is capped at 1 (i.e. the largest calibration score is used)
    q_level <- min(1, ceiling((n + 1) * (1 - alpha)) / n)
    qhat <- quantile(cal.scores, q_level)
    return(list(scores = cal.scores, qhat = qhat))
}
|
||
#### Function to get the prediction sets for test data
# Builds, for every row of pred, the set of labels whose estimated
# probability is at least 1 - qhat, and evaluates accuracy and coverage
# against the true labels.
#
# pred      prediction matrix for data in the test set (ntest x num labels),
#           with the labels as column names
# label     true label for data in the test set (needed to evaluate accuracy
#           and coverage)
# qhat      conformal quantile
# getClass  if TRUE the predicted class is the column with the highest value
#           in each row of pred; if FALSE the classes in acc.p are used
# acc.p     predicted classes, only used when getClass is FALSE
# summary   if TRUE, also return summary() of the prediction set sizes
#
# Returns a list with accuracy, pred.sets (list of character vectors, one per
# test observation), coverage and, when requested, summary.
getPredSets <- function(pred, label, qhat, getClass = TRUE, acc.p = NULL,
                        summary = FALSE) {
    # get predicted class from prediction matrix
    if (getClass) {
        acc.p <- apply(pred, 1, function(row) colnames(pred)[which.max(row)])
    }
    acc <- mean(acc.p == label)

    prediction_sets <- pred >= (1 - qhat)
    rs <- rowSums(prediction_sets) # to have summary on size of prediction sets

    # Get prediction set colnames (seq_len keeps this safe for zero rows)
    pr.list <- lapply(seq_len(nrow(prediction_sets)), function(i) {
        colnames(prediction_sets)[prediction_sets[i, ]]
    })

    # Get coverage: fraction of observations whose true label is in the set
    tf <- vapply(seq_along(pr.list),
                 function(i) label[i] %in% pr.list[[i]],
                 logical(1))
    coverage <- sum(tf) / length(pr.list)

    out <- list(accuracy = acc, pred.sets = pr.list, coverage = coverage)
    if (summary) {
        # `summary` the argument is a logical, so R resolves the call below
        # to the summary() function
        out$summary <- summary(rs)
    }
    return(out)
}
|
||
#### Function to get class conditional coverage
# Computes one conformal quantile per class from the calibration data, then
# builds prediction sets for the test data by comparing each class column of
# pred.test with its own class-specific threshold.
#
# classes      labels for which a class-specific quantile is computed; they
#              must be column names of pred.cal and pred.test
# pred.cal     prediction matrix for the calibration set (ncal x num labels)
# pred.test    prediction matrix for the test set (ntest x num labels)
# labels.cal   true labels of the calibration observations
# labels.test  true labels of the test observations
# alpha        desired error rate passed to getConfQuant (default 0.05)
#
# Returns a list with pred.sets (list of character vectors, one per test
# observation) and coverage (fraction of test observations whose true label
# is in their prediction set).
class_sp_conf <- function(classes, pred.cal, pred.test, labels.cal, labels.test,
                          alpha = 0.05) {
    s <- rep(NA, length(classes))
    names(s) <- classes

    for (i in seq_along(classes)) {
        # drop = FALSE keeps a one-row selection as a 1 x K matrix with its
        # column names, so the label lookup in getConfQuant still works
        p <- as.matrix(pred.cal[labels.cal == classes[i], , drop = FALSE])
        s[i] <- getConfQuant(p, rep(classes[i], nrow(p)), alpha)$qhat
    }

    # Labels without a class-specific threshold are never included in a set
    p.sets <- matrix(FALSE, nrow = nrow(pred.test), ncol = ncol(pred.test))
    colnames(p.sets) <- colnames(pred.test)
    for (i in seq_along(classes)) {
        p.sets[, classes[i]] <- pred.test[, classes[i]] >= (1 - s[classes[i]])
    }

    pr.list <- lapply(seq_len(nrow(p.sets)), function(i) {
        colnames(p.sets)[p.sets[i, ]]
    })

    # Get coverage
    tf <- vapply(seq_along(pr.list),
                 function(i) labels.test[i] %in% pr.list[[i]],
                 logical(1))
    return(list(pred.sets = pr.list, coverage = mean(tf)))
}
|
||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.