Skip to content

Commit

Permalink
create vignette and add wrapper function
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielaCorbetta committed May 2, 2024
1 parent 0922ec1 commit f593eeb
Show file tree
Hide file tree
Showing 18 changed files with 953 additions and 15 deletions.
Binary file modified .DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
^.*\.Rproj$
^\.Rproj\.user$
^somegraphs\.R$
^\.github$
^old$
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ LazyData: true
Depends:
R (>= 4.2.0)
Suggests:
scRNAseq,
knitr,
Matrix,
rmarkdown
rmarkdown,
BiocStyle
Imports:
foreach,
igraph,
stats
stats,
SummarizedExperiment
VignetteBuilder: knitr
RoxygenNote: 7.3.1
Encoding: UTF-8
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

export(getPredictionSets)
importFrom(SummarizedExperiment,colData)
importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(igraph,V)
Expand Down
Binary file added R/.DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion R/children.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
if(leaf)
return(V(onto)$name[is.finite(distances(onto, node, mode="out")) & degree(onto, mode="out")==0])
else
return(V(onto)$name[is.finite(distances(onto, node, mode="out")) & V(onto)$name!=node])
return(V(onto)$name[is.finite(distances(onto, node, mode="out"))])
}
5 changes: 4 additions & 1 deletion R/getConformalPredSets.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

.getConformalPredSets <- function(p.cal, p.test, y.cal, alpha){
# Get calibration scores (1-predicted probability for the true class)
s <- 1 - apply(p.cal, 1, function(row) row[y.cal])
true <- rep(NA, dim(p.cal)[1])
for (i in 1:dim(p.cal)[1])
true[i] <- p.cal[i, y.cal[i]]
s <- 1-true

# Get adjusted quantile
n <- nrow(p.cal)
Expand Down
7 changes: 4 additions & 3 deletions R/getHierarchicalPredSets.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
#' @importFrom foreach %dopar%
#' @importFrom foreach foreach


.getHierarchicalPredSets <- function(p.cal, p.test, y.cal, onto, alpha, lambdas){
y.cal <- as.character(y.cal)
# Get prediction sets for each value of lambda for all the calibration data
j <- NULL
sets <- foreach(j = lambdas) %dopar% {
exportedFn = c(".predSets", ".scores", ".children", ".ancestors")
sets <- foreach(j = lambdas, .export=exportedFn) %dopar% {
lapply(1:nrow(p.cal),
function(i) .predSets(lambda=j, pred=p.cal[i, ], onto=onto))
}
Expand All @@ -44,10 +44,11 @@
lhat_idx <- min(which(((n/(n+1)) * rhat + 1/(n+1) ) <= alpha))
lhat <- lambdas[lhat_idx]


# Get prediction sets for test data
sets.test <- apply(p.test, 1, function(x) .predSets(lambda=lhat, pred=x, onto=onto))

return(sets.test)
return(list(sets.test=sets.test, lhat=lhat))
}


Expand Down
76 changes: 76 additions & 0 deletions R/getPredictionSets.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#' @title Get prediction sets
#' @description Let K be the total number of distinct cell type labels and n, m
#' the number of cells in the calibration and in the query (test) data,
#' respectively. This function takes as input the estimated probabilities for
#' each cell in the calibration data (\code{n x K}) and in the query data
#' (\code{m x K}), and returns the prediction sets for each cell in the query
#' data.
#'
#' @param x.query query data for which we want to build prediction sets. Could
#' be either a SingleCellExperiment object with the estimated probabilities for
#' each cell type in the colData, or a matrix of dimension \code{m x K}, where
#' \code{m} is the number of cells and \code{K} is the number of different
#' labels. The colnames of the matrix have to correspond to the cell labels.
#' @param x.cal calibration data. Could be either a SingleCellExperiment object
#' with the estimated probabilities for each cell type in the colData, or a
#' named matrix of dimension \code{n x K}, where \code{n} is the number of
#' cells and \code{K} is the number of different labels. The colnames of the
#' matrix have to correspond to the cell labels.
#' @param y.cal a vector of length n with the labels of the cells in the
#' calibration data
#' @param onto the considered section of the cell ontology as an igraph object.
#' Required when \code{follow_ontology} is \code{TRUE}. When supplied, the cell
#' labels are the nodes of the ontology without outgoing edges. When
#' \code{NULL}, the labels needed to read probabilities from the colData are
#' taken from the column names of whichever input is a matrix.
#' @param alpha a number between 0 and 1 that indicates the allowed miscoverage
#' @param lambdas a vector of possible lambda values to be considered. Only
#' used when \code{follow_ontology} is \code{TRUE}.
#' @param follow_ontology If \code{TRUE}, then the function returns hierarchical
#' prediction sets that follow the cell ontology structure. If \code{FALSE}, it
#' returns classical conformal prediction sets.
#' @author Daniela Corbetta
#' @return The function \code{getPredictionSets} returns a list of length equal
#' to the number of cells in the test data.
#' Each element of the list contains the prediction set for that cell.
#' @references For an introduction to conformal prediction, see
#' Angelopoulos, Anastasios N., and Stephen Bates. "A gentle introduction to
#' conformal prediction and distribution-free uncertainty quantification." arXiv preprint arXiv:2107.07511 (2021).
#' For reference on conformal risk control, see
#' Angelopoulos, Anastasios N., et al. "Conformal risk control." arXiv preprint arXiv:2208.02814 (2022).
#' @importFrom foreach %dopar%
#' @importFrom foreach foreach
#' @importFrom SummarizedExperiment colData
#' @export

getPredictionSets <- function(x.query, x.cal, y.cal, onto = NULL, alpha = 0.1,
                              lambdas = seq(0.001, 0.999, length.out = 100),
                              follow_ontology = TRUE) {
  # The hierarchical procedure cannot run without the ontology: fail early
  # with an explicit message instead of an obscure igraph error.
  if (follow_ontology && is.null(onto)) {
    stop("'onto' must be provided when 'follow_ontology' is TRUE",
         call. = FALSE)
  }

  # Labels are the nodes of the ontology without outgoing edges. Without an
  # ontology, fall back on the column names of a matrix input (if any).
  if (!is.null(onto)) {
    labels <- V(onto)$name[degree(onto, mode = "out") == 0]
  } else if (is.matrix(x.cal)) {
    labels <- colnames(x.cal)
  } else if (is.matrix(x.query)) {
    labels <- colnames(x.query)
  } else {
    labels <- NULL
  }

  p.query <- .asProbMatrix(x.query, labels)
  p.cal <- .asProbMatrix(x.cal, labels)

  if (follow_ontology) {
    pred.sets <- .getHierarchicalPredSets(p.cal = p.cal, p.test = p.query,
                                          y.cal = y.cal, onto = onto,
                                          alpha = alpha,
                                          lambdas = lambdas)$sets.test
  } else {
    pred.sets <- .getConformalPredSets(p.cal = p.cal, p.test = p.query,
                                       y.cal = y.cal, alpha = alpha)
  }
  return(pred.sets)
}

# Internal helper: return `x` unchanged when it is already a matrix, otherwise
# build a (cells x labels) matrix from the columns of colData(x) named after
# each label.
.asProbMatrix <- function(x, labels) {
  if (is.matrix(x)) {
    return(x)
  }
  if (is.null(labels)) {
    stop("Cell type labels are unknown: provide 'onto' or pass at least one ",
         "input as a matrix with column names", call. = FALSE)
  }
  cell.data <- colData(x)
  p <- matrix(NA, nrow = ncol(x), ncol = length(labels),
              dimnames = list(NULL, labels))
  for (lab in labels) {
    probs <- cell.data[[lab]]
    if (is.null(probs)) {
      stop("No column named '", lab, "' found in the colData", call. = FALSE)
    }
    p[, lab] <- probs
  }
  p
}
80 changes: 80 additions & 0 deletions R/old/conformal.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#### Function to get the conformal quantile
# Computes the calibration scores (1 - probability predicted for the true
# label of each calibration row) and their adjusted quantile.
#   p      prediction matrix for data in the calibration set (ncal x num labels)
#   label  true label for data in the calibration set, one entry per row of p;
#          used to select the column of p (column name or column index)
#   alpha  desired error rate
# Returns a list with the calibration scores (`scores`) and the quantile of
# the scores at the adjusted level (`qhat`).
getConfQuant <- function(p, label, alpha) {
  # Indexing with a factor uses its integer codes, not its level names, which
  # silently selects the wrong column whenever the level order differs from
  # the column order of p. Work with the level names instead.
  if (is.factor(label)) {
    label <- as.character(label)
  }

  n <- dim(p)[1]

  # Probability assigned to the true label, row by row
  true <- vapply(seq_len(n), function(i) p[i, label[i]], numeric(1))
  cal.scores <- 1 - true

  # get adjusted quantile
  q_level <- ceiling((n + 1) * (1 - alpha)) / n
  qhat <- quantile(cal.scores, q_level)
  list(scores = cal.scores, qhat = qhat)
}

#### Function to get the prediction sets for test data
# Builds, for each test row, the set of labels whose predicted probability is
# at least 1 - qhat, and evaluates accuracy and coverage against the truth.
#   pred      prediction matrix for data in the test set (ntest x num labels)
#   label     true label for data in the test set (needed to evaluate accuracy
#             and coverage)
#   qhat      conformal quantile
#   getClass  if TRUE, the predicted class is the column of pred with the
#             highest value in each row; if FALSE, acc.p is used as given
#   acc.p     predicted classes, only used when getClass is FALSE
#   summary   if TRUE, also return a summary of the prediction set sizes
# Returns a list with `accuracy`, `pred.sets` (one character vector per test
# row), `coverage` and, when summary is TRUE, `summary`.
getPredSets <- function(pred, label, qhat, getClass = TRUE, acc.p = NULL,
                        summary = FALSE) {
  # get predicted class from prediction matrix
  if (getClass) {
    acc.p <- apply(pred, 1, function(row) colnames(pred)[which.max(row)])
  }
  acc <- mean(acc.p == label)

  prediction_sets <- pred >= (1 - qhat)
  rs <- rowSums(prediction_sets) # to have summary on size of prediction sets

  # seq_len() keeps both passes below empty when there are no test rows
  rows <- seq_len(nrow(prediction_sets))

  # Get prediction set colnames
  pr.list <- lapply(rows, function(i) {
    colnames(prediction_sets)[prediction_sets[i, ]]
  })

  # Get coverage
  tf <- vapply(rows, function(i) label[i] %in% pr.list[[i]], logical(1))
  coverage <- sum(tf) / length(pr.list)

  out <- list(accuracy = acc, pred.sets = pr.list, coverage = coverage)
  if (summary) {
    # `summary` the argument is a logical; the call resolves to the function
    out$summary <- summary(rs)
  }
  out
}

#### Function to get class conditional coverage
# Computes one conformal quantile per class, using only the calibration rows
# whose true label is that class, then builds the prediction sets of the test
# data by comparing each class column with its own threshold.
#   classes      labels for which a class-specific quantile is computed;
#                they must be column names of pred.cal and pred.test
#   pred.cal     prediction matrix for the calibration set
#   pred.test    prediction matrix for the test set
#   labels.cal   true labels of the calibration set
#   labels.test  true labels of the test set
#   alpha        error rate passed to getConfQuant (default 0.05)
# Returns a list with `pred.sets` (one character vector per test row) and
# `coverage` (fraction of test rows whose true label is in their set).
class_sp_conf <- function(classes, pred.cal, pred.test, labels.cal, labels.test,
                          alpha = 0.05) {
  s <- rep(NA, length(classes))
  names(s) <- classes

  for (i in seq_along(classes)) {
    # drop = FALSE keeps a 1 x K matrix when a class has a single calibration
    # row; otherwise the row collapses to a vector and becomes a K x 1 matrix.
    p <- as.matrix(pred.cal[labels.cal == classes[i], , drop = FALSE])
    s[i] <- getConfQuant(p, rep(classes[i], nrow(p)), alpha)$qhat
  }

  p.sets <- matrix(NA, nrow = nrow(pred.test), ncol = ncol(pred.test))
  colnames(p.sets) <- colnames(pred.test)
  for (i in seq_along(classes)) {
    p.sets[, classes[i]] <- pred.test[, classes[i]] >= (1 - s[classes[i]])
  }

  pr.list <- lapply(seq_len(nrow(p.sets)), function(i) {
    colnames(p.sets)[p.sets[i, ]]
  })

  # Get coverage
  tf <- vapply(seq_along(pr.list),
               function(i) labels.test[i] %in% pr.list[[i]],
               logical(1))
  list(pred.sets = pr.list, coverage = mean(tf))
}






Loading

0 comments on commit f593eeb

Please sign in to comment.