diff --git a/.DS_Store b/.DS_Store index cbe5a01..9d4009f 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.Rbuildignore b/.Rbuildignore index 91114bf..c56f671 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -1,2 +1,5 @@ ^.*\.Rproj$ ^\.Rproj\.user$ +^somegraphs\.R$ +^\.github$ +^old$ diff --git a/DESCRIPTION b/DESCRIPTION index a4f5a38..6f504cf 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -13,14 +13,15 @@ LazyData: true Depends: R (>= 4.2.0) Suggests: - scRNAseq, knitr, Matrix, - rmarkdown + rmarkdown, + BiocStyle Imports: foreach, igraph, - stats + stats, + SummarizedExperiment VignetteBuilder: knitr RoxygenNote: 7.3.1 Encoding: UTF-8 diff --git a/NAMESPACE b/NAMESPACE index a99151e..27fe9ef 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +export(getPredictionSets) +importFrom(SummarizedExperiment,colData) importFrom(foreach,"%dopar%") importFrom(foreach,foreach) importFrom(igraph,V) diff --git a/R/.DS_Store b/R/.DS_Store new file mode 100644 index 0000000..f34df6c Binary files /dev/null and b/R/.DS_Store differ diff --git a/R/children.R b/R/children.R index c72b74c..e425946 100644 --- a/R/children.R +++ b/R/children.R @@ -14,5 +14,5 @@ if(leaf) return(V(onto)$name[is.finite(distances(onto, node, mode="out")) & degree(onto, mode="out")==0]) else - return(V(onto)$name[is.finite(distances(onto, node, mode="out")) & V(onto)$name!=node]) + return(V(onto)$name[is.finite(distances(onto, node, mode="out"))]) } diff --git a/R/getConformalPredSets.R b/R/getConformalPredSets.R index b9925c0..c468e77 100644 --- a/R/getConformalPredSets.R +++ b/R/getConformalPredSets.R @@ -20,7 +20,10 @@ .getConformalPredSets <- function(p.cal, p.test, y.cal, alpha){ # Get calibration scores (1-predicted probability for the true class) - s <- 1 - apply(p.cal, 1, function(row) row[y.cal]) + true <- rep(NA, dim(p.cal)[1]) + for (i in 1:dim(p.cal)[1]) + true[i] <- p.cal[i, y.cal[i]] + s <- 1-true # Get adjusted quantile n 
<- nrow(p.cal) diff --git a/R/getHierarchicalPredSets.R b/R/getHierarchicalPredSets.R index efbad6d..d3867f9 100644 --- a/R/getHierarchicalPredSets.R +++ b/R/getHierarchicalPredSets.R @@ -21,12 +21,12 @@ #' @importFrom foreach %dopar% #' @importFrom foreach foreach - .getHierarchicalPredSets <- function(p.cal, p.test, y.cal, onto, alpha, lambdas){ y.cal <- as.character(y.cal) # Get prediction sets for each value of lambda for all the calibration data j <- NULL - sets <- foreach(j = lambdas) %dopar% { + exportedFn = c(".predSets", ".scores", ".children", ".ancestors") + sets <- foreach(j = lambdas, .export=exportedFn) %dopar% { lapply(1:nrow(p.cal), function(i) .predSets(lambda=j, pred=p.cal[i, ], onto=onto)) } @@ -44,10 +44,11 @@ lhat_idx <- min(which(((n/(n+1)) * rhat + 1/(n+1) ) <= alpha)) lhat <- lambdas[lhat_idx] + # Get prediction sets for test data sets.test <- apply(p.test, 1, function(x) .predSets(lambda=lhat, pred=x, onto=onto)) - return(sets.test) + return(list(sets.test=sets.test, lhat=lhat)) } diff --git a/R/getPredictionSets.R b/R/getPredictionSets.R new file mode 100644 index 0000000..5274154 --- /dev/null +++ b/R/getPredictionSets.R @@ -0,0 +1,76 @@ +#' @title Get prediction sets +#' @description Let K be the total number of distinct cell type labels and n, m +#' the number of cells in the calibration and in the test data, respectively. +#' This function takes as input two matrices: a matrix \code{n x K} and +#' a matrix \code{m x K} with the estimated +#' probabilities for each cell in the calibration and in the test data, respectively. +#' It returns a list with the prediction sets for each cell in the test data. +#' +#' @param x.query query data for which we want to build prediction sets. Could be either a +#' SingleCellExperiment object with the estimated probabilities for each cell type +#' in the colData, or a matrix of dimension \code{n x K}, where \code{n} is the number +#' of cells and \code{K} is the number of different labels. 
The colnames of the matrix +#' have to correspond to the cell labels. +#' @param x.cal calibration data. Could be either a +#' SingleCellExperiment object with the estimated probabilities for each cell type +#' in the colData, or a named matrix of dimension \code{n x K}, where \code{n} is the number +#' of cells and \code{K} is the number of different labels. The colnames of the matrix +#' have to correspond to the cell labels. +#' @param y.cal a vector of length \code{n} with the labels of the cells in the calibration data +#' @param onto the considered section of the cell ontology as an igraph object. +#' @param alpha a number between 0 and 1 that indicates the allowed miscoverage +#' @param lambdas a vector of possible lambda values to be considered +#' @param follow_ontology If \code{TRUE}, then the function returns hierarchical +#' prediction sets that follow the cell ontology structure. If \code{FALSE}, it +#' returns classical conformal prediction sets. +#' @author Daniela Corbetta +#' @return The function \code{getPredictionSets} returns a list of length equal +#' to the number of cells in the test data. +#' Each element of the list contains the prediction set for that cell. +#' @references For an introduction to conformal prediction, see +#' Angelopoulos, Anastasios N., and Stephen Bates. "A gentle introduction to +#' conformal prediction and distribution-free uncertainty quantification." arXiv preprint arXiv:2107.07511 (2021). +#' For reference on conformal risk control, see +#' Angelopoulos, Anastasios N., et al. "Conformal risk control." arXiv preprint arXiv:2208.02814 (2022). 
+#' @importFrom foreach %dopar% +#' @importFrom foreach foreach +#' @importFrom SummarizedExperiment colData +#' @export + +getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1, + lambdas = lambdas <- seq(0.001,0.999,length.out=100), + follow_ontology=TRUE){ + # TODO: check that x.cal, x.query are SingleCell/SpatialExperiment or matrices + if(is.null(onto) && (follow_ontology || !is.matrix(x.query) || !is.matrix(x.cal))) stop("'onto' is required unless follow_ontology=FALSE and x.query, x.cal are matrices", call.=FALSE) + labels <- if(is.null(onto)) NULL else V(onto)$name[degree(onto, mode="out")==0] # leaf labels of the ontology; not needed for plain conformal sets on matrices + K <- length(labels) + if(!is.matrix(x.query)){ + n.query <- ncol(x.query) + p.query <- matrix(NA, nrow=n.query, ncol=K) + colnames(p.query) <- labels + for(i in labels){ + p.query[,i] <- colData(x.query)[[i]] + } + } + else p.query <- x.query + + if(!is.matrix(x.cal)){ + n.cal <- ncol(x.cal) + p.cal <- matrix(NA, nrow=n.cal, ncol=K) + colnames(p.cal) <- labels + for(i in labels){ + p.cal[,i] <- colData(x.cal)[[i]] + } + } + else p.cal <- x.cal + + if (follow_ontology){ + pred.sets <- .getHierarchicalPredSets(p.cal=p.cal, p.test=p.query, + y.cal=y.cal, onto=onto, alpha=alpha, + lambdas=lambdas)$sets.test + } + else + pred.sets <- .getConformalPredSets(p.cal=p.cal, p.test=p.query, + y.cal=y.cal, alpha=alpha) + return(pred.sets) +} diff --git a/R/old/conformal.R b/R/old/conformal.R new file mode 100644 index 0000000..58b553b --- /dev/null +++ b/R/old/conformal.R @@ -0,0 +1,80 @@ +#### Function to get the conformal quantile +getConfQuant <- function(p, label, alpha){ + #p prediction matrix for data in the calibration set (ncal x num labels) + #label true label for data in the calibration set + #alpha desired error rate + true <- rep(NA, dim(p)[1]) + for (i in 1:dim(p)[1]) + true[i] <- p[i, label[i]] + + cal.scores <- 1-true + + # get adjusted quantile + n <- dim(p)[1] + q_level = ceiling((n+1)*(1-alpha))/n + qhat = quantile(cal.scores, q_level) + return(list(scores=cal.scores, qhat=qhat)) +} + +#### Function to get the prediction sets for test data +getPredSets <- function(pred, label, qhat, getClass=T, acc.p=NULL, 
summary=F){ + #pred prediction matrix for data in the test set (ntest x num labels) + #label true label for data in the test set (needed to evaluate accuracy and coverage) + #qhat conformal quantile + + #get predicted class from prediction matrix + if (getClass) + acc.p <- apply(pred, 1, function(row) colnames(pred)[which.max(row)]) + acc <- mean(acc.p==label) + prediction_sets <- pred >=(1-qhat) + rs <- rowSums(prediction_sets) # to have summary on size of prediction sets + + # Get prediction set colnames + pr.list <- lapply(1:nrow(prediction_sets), function(i) { + colnames(prediction_sets)[prediction_sets[i, ]] + }) + + # Get coverage + tf <- rep(NA, length(pr.list)) + for (i in 1:length(pr.list)) + tf[i] <- label[i] %in% pr.list[[i]] + coverage <- sum(tf)/length(pr.list) + + if (summary) + return(list(accuracy=acc, pred.sets=pr.list, coverage=coverage, summary=summary(rs))) + else + return(list(accuracy=acc, pred.sets=pr.list, coverage=coverage)) +} + +#### Function to get class conditional coverage +class_sp_conf <- function(classes, pred.cal, pred.test, labels.cal, labels.test){ + s <- rep(NA, length(classes)) + names(s) <- classes + + for (i in 1:length(classes)){ + p <- as.matrix(pred.cal[labels.cal==classes[i],]) + s[i] <- getConfQuant(p, rep(classes[i], nrow(p)), 0.05)$qhat + } + + acc.p <- apply(pred.test, 1, function(row) colnames(pred.test)[which.max(row)]) + p.sets <- matrix(NA, nrow=nrow(pred.test), ncol=ncol(pred.test)) + colnames(p.sets) <- colnames(pred.test) + for(i in 1:length(classes)) + p.sets[,classes[i]] <- pred.test[,classes[i]] >= (1-s[classes[i]]) + + pr.list <- lapply(1:nrow(p.sets), function(i) { + colnames(p.sets)[p.sets[i, ]] + }) + + # Get coverage + tf <- rep(NA, length(pr.list)) + for (i in 1:length(pr.list)) + tf[i] <- labels.test[i] %in% pr.list[[i]] + return(list(pred.sets = pr.list, coverage=mean(tf))) +} + + + + + + diff --git a/R/old/somegraphs.R b/R/old/somegraphs.R new file mode 100644 index 0000000..e9079a5 --- /dev/null 
+++ b/R/old/somegraphs.R @@ -0,0 +1,174 @@ +############################################################################### +################### Test repository ########################################### +############################################################################### + +##################### +# Build example graph +##################### +library(igraph) +t <- graph_from_literal(animale-+cane:gatto:topo, gatto-+british:persiano, + gatto-+retriever, + cane-+cocker:retriever, retriever-+golden:labrador) +plot(t, layout=layout_as_tree(t, root="animale")) +p <- c(0.2, 0.3, 0.2, 0.10, 0.15,0.05) +names(p) <- c("cocker", "golden", "labrador", "british", "persiano", "topo") +p + + +###################### +### Load old functions +###################### +source("/Users/daniela/Documents/GitHub/ConfCell/R/old/conformal.R") +source("/Users/daniela/Documents/GitHub/ConfCell/R/old/utils_Hier.R") + +# Check that scores are the same +nam <- V(t)$name +s <- sapply(nam, function(x) g(p, x, t)) +s1 <- sapply(nam, function(x) .scores(p, x, t)) +sum(s!=s1) # ok + +# Check ancestors +anc <- sapply(nam, function(x) ancestors(x, graph = t, include_self = T)) +anc1 <- sapply(nam, function(x) .ancestors(x, onto = t, include_self = T)) +tf <- rep(NA, length(anc)) +for(i in 1:length(anc)){ + tf[i] <- mean(anc[[i]]==anc1[[i]]) +} +tf + +anc <- sapply(nam, function(x) ancestors(x, graph = t, include_self = F)) +anc1 <- sapply(nam, function(x) .ancestors(x, onto = t, include_self = F)) +tf <- rep(NA, length(anc)) +for(i in 1:length(anc)){ + tf[i] <- mean(anc[[i]]==anc1[[i]]) +} +tf + +# Check children +child <- sapply(nam, function(x) children(x, graph = t, leaf = T)) +child1 <- sapply(nam, function(x) .children(x, onto = t, leaf = T)) +tf <- rep(NA, length(child)) +for(i in 1:length(child)){ + tf[i] <- mean(child[[i]]==child1[[i]]) +} +tf + +child <- sapply(nam, function(x) children(x, graph = t, leaf = F)) +child1 <- sapply(nam, function(x) .children(x, onto = t, leaf = F)) 
+tf <- rep(NA, length(child)) +for(i in 1:length(child)){ + tf[i] <- mean(child[[i]]==child1[[i]]) +} +tf + +# .predSets +pred_sets1(0.5, p, t) +.predSets(lambda=0.5, pred=p, onto=t) + +pred_sets1(0.7, p, t) +.predSets(lambda=0.7, pred=p, onto=t) + +pred_sets1(0.75, p, t) +.predSets(lambda=0.75, pred=p, onto=t) + +pred_sets1(0.8, p, t) #should be all +.predSets(lambda=0.8, pred=p, onto=t) + +pred_sets1(1, p, t) +.predSets(lambda=1, pred=p, onto=t) + +pred_sets1(0.6, p, t) # cane (golden, labrador, cocker) +.predSets(lambda=0.6, pred=p, onto=t) + +p1 <- c(0.3, 0.4, 0.3, 0, 0, 0) +names(p1) <- c("cocker", "golden", "labrador", "british", "persiano", "topo") +pred_sets1(0.7, p1, t) +.predSets(lambda=0.7, pred=p1, onto=t) + +######################################### +### Create a random dataset +######################################### +set.seed(1010) +n <- 2100 + +leaves <- c(4,6,7,8,9) +nclasses <- length(leaves) +lab <- sample(leaves, n, replace=T) +x <- (runif(n,0,3.1) + lab) +y <- rep(NA, length(lab)) +y[lab==4] <- "cocker" +y[lab==6] <- "golden" +y[lab==7] <- "labrador" +y[lab==8] <- "british" +y[lab==9] <- "persiano" +xydf <- data.frame(y=as.factor(y), x=x) +head(xydf) + +# Divide in train, cal and test and fit model +library(VGAM) +train <- xydf[1:1000,] +cal <- xydf[1001:1100,] +test <- xydf[1101:2100,] + + +fit <- vglm(y~x, family = multinomial(refLevel = "cocker"), + data = train) +lambdas <- seq(0.001,0.999,length.out=100) +# Compute predicted probabilities for calibration points +p.cal <- predict(fit, newdata=cal, type="response") +library(doParallel) +library(foreach) +# +num_cores <- 4 +cl <- makeCluster(num_cores) +registerDoParallel(cl) + +t3 <- graph_from_literal(animale-+cane:gatto, gatto-+british:persiano, + gatto-+retriever, + cane-+cocker:retriever, retriever-+golden:labrador) +plot(t3, layout=layout_as_tree(t3, root="animale")) +exportedFn = c(".predSets", ".scores", ".children", ".ancestors") +system.time(sets1 <- foreach(lambda = lambdas, 
.packages = c("ConfCell", "igraph")) %dopar% { + library(igraph) + lapply(1:nrow(p.cal), function(i) pred_sets1(lambda, p.cal[i, ], t3)) +} +) + +# +system.time(l1 <- get_loss_table_mis(lambdas, sets1, as.character(cal$y), t3)) +Rhat <- colMeans(l1) +plot(Rhat) +all(diff(Rhat) <= 0) + +lhat <- get_lhat(l1,lambdas,0.1) +p.test <- predict(fit, newdata=test, type="response") +sets.test <- apply(p.test, 1, function(x) pred_sets1(lambda=lhat, x, t3)) + +t <- .getHierarchicalPredSets(p.cal, p.test, cal$y, onto = t3, alpha = 0.1, lambdas = lambdas) +t1 <- .getConformalPredSets(p.cal, p.test, as.character(cal$y), 0.1) +length(t1) +length(t$sets.test) + +l.std <- sapply(t1, length) +l.crc <- sapply(t$sets.test, length) +barplot(table(l.std)) +barplot(table(l.crc)) +cvg <- cvg1 <- rep(NA, length(t)) +for(i in 1:length(t1)) + cvg[i] <- test$y[i] %in% t1[[i]] + +for(i in 1:length(t$sets.test)) + cvg1[i] <- test$y[i] %in% t$sets.test[[i]] +mean(cvg) +mean(cvg1) + + + + + +sets <- foreach(j = lambdas, .export=exportedFn) %dopar% { + library(igraph) + lapply(1:nrow(p.cal), + function(i) .predSets(lambda=j, pred=p.cal[i, ], onto=t3)) +} + diff --git a/R/old/utils_Hier.R b/R/old/utils_Hier.R new file mode 100644 index 0000000..c02a84b --- /dev/null +++ b/R/old/utils_Hier.R @@ -0,0 +1,120 @@ +################################################################################ +###### File with utils functions for conformal risk control on graph structure +################################################################################ + + +# Conformal risk control +# Function to retrieve children nodes of a given interior node. 
+# If leaf=T gives leaf nodes, else gives all descendants +children <- function(node, graph, leaf=T){ + require(igraph) + if(leaf) + return(V(graph)$name[is.finite(distances(graph, node, mode="out")) & igraph::degree(graph, mode="out")==0]) + else + return(V(graph)$name[is.finite(distances(graph, node, mode="out"))]) +} + +# Function to retrieve ancestors of a given node +ancestors <- function(node, graph, include_self=T){ + require(igraph) + if(include_self) + return(V(graph)$name[is.finite(distances(graph, node, mode="in"))]) + else + return(V(graph)$name[is.finite(distances(graph, node, mode="in")) & V(graph)$name!=node]) +} + +# Define function to compute g(a,x), where a is a node +# (g(a,x)=sum_{y in P(a)}(pi(y))) +# @param pred named vector with predicted probabilities for one observation to be in one of the leaves class + +g <- function(pred, int_node, graph){ + c <- children(node = int_node, graph = graph, leaf = T) + return(sum(pred[c])) +} + +# Function to construct confidence sets (for one obs for now) +# @param lambda probability bound +# @param p predicted probabilities for leaf nodes +# @param graph graph + +pred_sets1 <- function(lambda, p, graph){ + pred_class <- names(p)[which.max(p)] + anc <- ancestors(node = pred_class, graph = graph, include_self = TRUE) + + # Obtain scores for ancestors of the predicted class + scores <- sapply(as.character(anc), function(i) g(p, i, graph)) + names(scores) <- anc + + # Sort them by score and if there are ties by distance to the predicted class + ## compute distance from predicted class + pos <- distances(graph, v=anc, to=pred_class, mode="out") + tie_breaker <- as.vector(t(pos)) + names(tie_breaker) <- colnames(t(pos)) + sorted_indices <- order(scores, tie_breaker, decreasing=F) + sorted_scores <- scores[sorted_indices] + + # Select the first score that is geq than lambda + sel_node <- names(sorted_scores)[round(sorted_scores, 15) >= lambda][1] + + # c(lapply(anc[round(scores, 15) <= lambda] + selected <- 
c(lapply(anc[round(scores, 15) < lambda], function(x) children(node = x, graph = graph)), + list(children(sel_node, graph))) + + return(Reduce(union, selected)) +} + +# Do not consider this +pred_sets2 <- function(lambda, p, graph){ + pred_class <- names(p)[which.max(p)] + anc <- ancestors(node = pred_class, graph = graph, include_self = TRUE) + + scores <- sapply(anc, function(i) g(p, i, graph)) + names(scores) <- anc + + selected <- lapply(anc[round(scores, 15) <= lambda], function(x) children(node = x, graph = graph)) + if(length(selected)==0) selected <- pred_class + + return(Reduce(union, selected)) +} + +# Loss table with miscoverage +get_loss_table_mis <- function(lambdas, prediction, ycal, graph) { + loss <- sapply(1:length(lambdas), function(lambda) { + sapply(seq_along(ycal), function(i) { + !(ycal[i] %in% prediction[[lambda]][[i]]) + }) + }) + + return(loss) +} + +# Hierarchical loss proposed by Angelopoulus and Bates +hier_loss <- function(set, true_class, graph){ + anc <- ancestors(true_class, graph) + set_distances <- distances(graph, v = anc, to = set) + + vroot <- V(graph)[degree(graph, mode = "in") == 0] + depth <- max(distances(graph, vroot, mode = "out")) + + return(min(set_distances) / depth) +} + +# Loss table with hierarchical loss +get_loss_table <- function(lambdas, prediction, ycal, graph) { + loss <- sapply(1:length(lambdas), function(lambda) { + sapply(seq_along(ycal), function(i) { + hier_loss(prediction[[lambda]][[i]], true_class = ycal[i], graph = graph) + }) + }) + + return(loss) +} + +# Find lambda hat on a grid of lambda values +get_lhat <- function(calib_loss_table, lambdas, alpha, B=1) { + n <- nrow(calib_loss_table) + rhat <- colMeans(calib_loss_table) + lhat_idx <- min(which(((n/(n+1)) * rhat + B/(n+1) ) <= alpha)) + # Return the corresponding lambda value at lhat_idx + return(lambdas[lhat_idx]) +} diff --git a/R/old/utils_Resample.R b/R/old/utils_Resample.R new file mode 100644 index 0000000..d975aa7 --- /dev/null +++ 
b/R/old/utils_Resample.R @@ -0,0 +1,177 @@ +########## Utils for resampling strategy + +resample.two <- function(cal, test, cal.pred, test.pred, seed=NA){ + if(!is.na(seed)) set.seed(seed) + s <- sample(1:nrow(test), round(nrow(test)/2)) + test1 <- test[s,] + test2 <- test[-s,] + p.test1 <- test.pred[s,] + p.test2 <- test.pred[-s,] + cal_freq <- prop.table(table(cal$Y)) + pr.class1 <- apply(p.test1, 1, function(row) colnames(p.test1)[which.max(row)]) + pr.class2 <- apply(p.test2, 1, function(row) colnames(p.test2)[which.max(row)]) + test_freq1 <- prop.table(table(pr.class1)) + test_freq2 <- prop.table(table(pr.class2)) + des_freq1 <- round(test_freq1 * length(cal$Y)) + des_freq2 <- round(test_freq2 * length(cal$Y)) + + idx1 <- idx2 <- NULL + for (i in cl.types) { + cat <- which(cal$Y == i) + if(!is.na(des_freq1[i])){ + idx_cat1 <- sample(cat, size = des_freq1[i], replace = TRUE) + idx1 <- c(idx1, idx_cat1) + } + if(!is.na(des_freq2[i])){ + idx_cat2 <- sample(cat, size = des_freq2[i], replace = TRUE) + idx2 <- c(idx2, idx_cat2) + } + } + + return(list(cal1=cal[idx1,], cal2=cal[idx2,], p.cal1=cal.pred[idx1,], + p.cal2=cal.pred[idx2,], test1=test1, test2=test2, + p.test1=p.test1, p.test2=p.test2)) +} + +# resample.oracle <- function(cal, test, cal.pred, test.pred, seed=NA){ +# if(!is.na(seed)) set.seed(seed) +# s <- sample(1:nrow(test), round(nrow(test)/2)) +# test1 <- test[s,] +# test2 <- test[-s,] +# p.test1 <- test.pred[s,] +# p.test2 <- test.pred[-s,] +# cal_freq <- prop.table(table(cal$Y)) +# test_freq1 <- prop.table(table(test1$Y)) +# test_freq2 <- prop.table(table(test2$Y)) +# des_freq1 <- round(test_freq1 * length(cal$Y)) +# des_freq2 <- round(test_freq2 * length(cal$Y)) +# +# idx1 <- idx2 <- NULL +# for (i in cl.types) { +# cat <- which(cal$Y == i) +# if(!is.na(des_freq1[i])){ +# idx_cat1 <- sample(cat, size = des_freq1[i], replace = TRUE) +# idx1 <- c(idx1, idx_cat1) +# } +# if(!is.na(des_freq2[i])){ +# idx_cat2 <- sample(cat, size = des_freq2[i], replace = 
TRUE) +# idx2 <- c(idx2, idx_cat2) +# } +# } +# +# return(list(cal1=cal[idx1,], cal2=cal[idx2,], p.cal1=cal.pred[idx1,], +# p.cal2=cal.pred[idx2,], test1=test1, test2=test2, +# p.test1=p.test1, p.test2=p.test2)) +# } + +resample.oracle <- function(cal, test, cal.pred, test.pred, seed=NA){ + if(!is.na(seed)) set.seed(seed) + + cal_freq <- prop.table(table(cal$Y)) + test_freq <- prop.table(table(test$Y)) + des_freq <- round(test_freq * length(cal$Y)) + + idx <- NULL + for (i in cl.types) { + cat <- which(cal$Y == i) + if(!is.na(des_freq[i])){ + idx_cat <- sample(cat, size = des_freq[i], replace = TRUE) + idx <- c(idx, idx_cat) + } + } + + return(list(cal=cal[idx,], p.cal=cal.pred[idx,])) +} + +resample.freq <- function(cal, test, cal.pred, test.pred, freqs, seed=NA){ + if(!is.na(seed)) set.seed(seed) + + cal_freq <- prop.table(table(cal$Y)) + test_freq <- freqs + des_freq <- round(test_freq * length(cal$Y)) + + idx <- NULL + for (i in cl.types) { + cat <- which(cal$Y == i) + if(!is.na(des_freq[i])){ + idx_cat <- sample(cat, size = des_freq[i], replace = TRUE) + idx <- c(idx, idx_cat) + } + } + + return(list(cal=cal[idx,], p.cal=cal.pred[idx,])) +} + + + + + +exportedFn <- c("pred_sets1", "children", "ancestors", "g") +conformal <- function(cal, test, p.cal, p.test, alpha=0.1, graph=opi2, + lambdas=seq(0.001,0.999,length.out=100), + expFn=exportedFn){ + conf.mln <- getConfQuant(p.cal, cal$Y, alpha) + pred.mln <- getPredSets(p.test, test$Y, conf.mln$qhat, + summary=T) + tf.conf <- rep(NA, length(pred.mln$pred.sets)) + for (i in 1:length(pred.mln$pred.sets)) + tf.conf[i] <- test$Y[i] %in% pred.mln$pred.sets[[i]] + + sets <- foreach(lambda = lambdas, .export = expFn) %dopar% { + lapply(1:nrow(p.cal), + function(i) pred_sets1(lambda, p.cal[i, ], graph))} + l <- get_loss_table_mis(lambdas, sets, as.character(cal$Y), graph) + lhat <- get_lhat(l, lambdas, alpha) + sets.test <- apply(p.test, 1, function(x) pred_sets1(lhat, x, graph)) + tf.crc <- rep(NA, nrow(test)) + for (j 
in 1:nrow(test)) + tf.crc[j] <- as.character(test$Y)[j] %in% sets.test[[j]] + return(list(tf.conf=tf.conf, tf.crc=tf.crc, sets.conf=pred.mln$pred.sets, sets.crc=sets.test)) + +} + +conf.loo <- function(p.cal, p.test, Ycal, Ytest, alpha){ + ntest <- nrow(p.test) + emp.cov <- rep(NA, ntest) + test.sets <- list() + for(i in 1:ntest){ + p.test.loo <- p.test[-i,] + cal_freq <- prop.table(table(Ycal)) + test_prfreq <- apply(p.test.loo, 1, function(row) colnames(p.test.loo)[which.max(row)]) + test_freq <- prop.table(table(test_prfreq)) + des_freq <- round(test_freq * length(Ycal)) + + # Resample + idx <- NULL + for (j in cl.types) { + cat <- which(Ycal == j) + if(!is.na(des_freq[j])){ + idx_cat <- sample(cat, size = des_freq[j], replace = TRUE) + idx <- c(idx, idx_cat) + } + } + p.cal.loo <- p.cal[idx,] + Ycal.loo <- Ycal[idx] + + # Conformal prediction for i obs + q <- getConfQuant(p.cal.loo, Ycal.loo, alpha) + p <- getPredSets(t(as.matrix(p.test[i,])), Ytest[i], q$qhat) + emp.cov[i] <- p$coverage + test.sets[i] <- p$pred.sets + #if(i%%100 == 0) cat(i) + } + return(list(emp.cov=mean(emp.cov), sets=test.sets)) +} + + + + + + + + + + + + + diff --git a/R/predSets.R b/R/predSets.R index 7d382bc..4581bbc 100644 --- a/R/predSets.R +++ b/R/predSets.R @@ -12,15 +12,23 @@ anc <- .ancestors(node = pred_class, onto = onto, include_self = TRUE) # Compute scores for all the ancestor of the predicted class - scores <- sapply(as.character(anc), function(i) scores(pred=pred, int_node=i, onto=onto)) - names(scores) <- anc + s <- sapply(as.character(anc), function(i) .scores(pred=pred, int_node=i, onto=onto)) + names(s) <- anc + + # Sort them by score and if there are ties by distance to the predicted class + ## compute distance from predicted class + pos <- distances(onto, v=anc, to=pred_class, mode="out") + tie_breaker <- as.vector(t(pos)) + names(tie_breaker) <- colnames(t(pos)) + sorted_indices <- order(s, tie_breaker, decreasing=F) + sorted_scores <- s[sorted_indices] + + # Select 
the first score that is geq than lambda + sel_node <- names(sorted_scores)[round(sorted_scores, 15) >= lambda][1] + - # Select nodes with scores higher than lambda - sel_scores <- scores[round(scores, 15) >= lambda] - # Select the most external node (i.e. the one with smallest score) - sel_node <- names(sel_scores)[length(sel_scores)] # Add also the subgraphs we would have obtained with smaller lambda - selected <- c(lapply(anc[round(scores, 15) <= lambda], function(x) .children(node = x, onto = onto)), + selected <- c(lapply(anc[round(s, 15) < lambda], function(x) .children(node = x, onto = onto)), list(.children(sel_node, onto))) return(Reduce(union, selected)) diff --git a/man/.DS_Store b/man/.DS_Store new file mode 100644 index 0000000..7171fac Binary files /dev/null and b/man/.DS_Store differ diff --git a/man/getPredictionSets.Rd b/man/getPredictionSets.Rd new file mode 100644 index 0000000..dde443c --- /dev/null +++ b/man/getPredictionSets.Rd @@ -0,0 +1,64 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/getPredictionSets.R +\name{getPredictionSets} +\alias{getPredictionSets} +\title{Get prediction sets} +\usage{ +getPredictionSets( + x.query, + x.cal, + y.cal, + onto = NULL, + alpha = 0.1, + lambdas = lambdas <- seq(0.001, 0.999, length.out = 100), + follow_ontology = TRUE +) +} +\arguments{ +\item{x.query}{query data for which we want to build prediction sets. Could be either a +SingleCellExperiment object with the estimated probabilities for each cell type +in the colData, or a matrix of dimension \code{n x K}, where \code{n} is the number +of cells and \code{K} is the number of different labels. The colnames of the matrix +have to correspond to the cell labels.} + +\item{x.cal}{calibration data. 
Could be either a +SingleCellExperiment object with the estimated probabilities for each cell type +in the colData, or a named matrix of dimension \code{m x K}, where \code{m} is the number +of cells and \code{K} is the number of different labels. The colnames of the matrix +have to correspond to the cell labels.} + +\item{y.cal}{a vector of length n with the labels of the cells in the calibration data} + +\item{onto}{the considered section of the cell ontology as an igraph object.} + +\item{alpha}{a number between 0 and 1 that indicates the allowed miscoverage} + +\item{lambdas}{a vector of possible lambda values to be considered} + +\item{follow_ontology}{If \code{TRUE}, then the function returns hierarchical +prediction sets that follow the cell ontology structure. If \code{FALSE}, it +returns classical conformal prediction sets.} +} +\value{ +The function \code{getPredictionSets} returns a list of length equal +to the number of cells in the test data. +Each element of the list contains the prediction set for that cell. +} +\description{ +Let K be the total number of distinct cell type labels and n, m +the number of cells in the calibration and in the test data, respectively. +This function takes as input two matrices: a matrix \code{n x K} and +a matrix \code{m x K} with the estimated +probabilities for each cell in the calibration and in the test data, respectively. +It returns a list with the prediction sets for each cell in the test data. +} +\references{ +For an introduction to conformal prediction, see of +Angelopoulos, Anastasios N., and Stephen Bates. "A gentle introduction to +conformal prediction and distribution-free uncertainty quantification." arXiv preprint arXiv:2107.07511 (2021). +For reference on conformal risk control, see +Angelopoulos, Anastasios N., et al. "Conformal risk control." arXiv preprint arXiv:2208.02814 (2022). 
+} +\author{ +Daniela Corbetta +} diff --git a/vignettes/.DS_Store b/vignettes/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/vignettes/.DS_Store differ diff --git a/vignettes/vignette.Rmd b/vignettes/vignette.Rmd index 8b13789..63ce3cb 100644 --- a/vignettes/vignette.Rmd +++ b/vignettes/vignette.Rmd @@ -1 +1,230 @@ +--- +title: "Conformal Prediction for cell type annotation" +author: + - name: Daniela Corbetta + affiliation: Department of Statistical Sciences, University of Padova + email: daniela.corbetta@phd.unipd.it +package: ConfCell +output: + BiocStyle::html_document: + toc: true + toc_float: true +vignette: > + %\VignetteIndexEntry{ConfCell} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r setup, include=FALSE} +knitr::opts_chunk$set(echo = TRUE) +``` + +# Load useful packages + +```{r, message=FALSE} +library(ConfCell) +library(SingleCellExperiment) +library(scater) +library(scran) +library(VGAM) +library(foreach) +library(ontoProc) +library(MerfishData) +library(doParallel) +library(igraph) +`%notin%` = Negate(`%in%`) + +num_cores <- 4 +cl1 <- makeCluster(num_cores) +registerDoParallel(cl1) +``` + +# Preliminaries + +## Load data and cell ontology + +As an example, we'll load the mouse ileum Merfish data from the MerfishData Bioconductor package. 
We also load the cell ontology through the ontoProc Bioconductor package. +```{r} +spe.baysor <- MouseIleumPetukhov2021(segmentation = "baysor") +cl <- getOnto("cellOnto", "2023") +``` + +## Build the ontology + +To restrict the cell ontology to the interesting part that is related to the cell types +that are present in our data, we need to find the corresponding tags: +```{r} +tags <- c("CL:0009022", + "CL:0000236", + "CL:0009080", + "CL:1000411", + "CL:1000335", + "CL:1000326", + "CL:0002088", + "CL:0009007", + "CL:1000343", + "CL:0000669", + "CL:1000278", + "CL:0009017", + "CL:0000492", + "CL:0000625", + "CL:0017004") +opi <- graph_from_graphnel(onto_plot2(cl, tags)) +``` + +In the `opi` object, there are also instances from other ontologies (CARO and BFO). +Remove them and rename the vertices so that they match the annotations. + +```{r} +sel_ver <- V(opi)$name[c(grep("CARO", V(opi)$name), grep("BFO", V(opi)$name))] +opi1 <- opi - sel_ver + +## Rename vertex to match annotations +V(opi1)$name[V(opi1)$name=="B\ncell\nCL:0000236"] <- "B cell" +V(opi1)$name[V(opi1)$name=="endothelial\ncell of Peyer's patch\nCL:1000411"] <- "Endothelial" +V(opi1)$name[V(opi1)$name=="enterocyte\nof epithelium of intestinal villus\nCL:1000335"] <- "Enterocyte" +V(opi1)$name[V(opi1)$name=="ileal\ngoblet cell\nCL:1000326"] <- "Goblet" +V(opi1)$name[V(opi1)$name=="interstitial\ncell of Cajal\nCL:0002088"] <- "ICC" +V(opi1)$name[V(opi1)$name=="gastrointestinal\ntract (lamina propria) macrophage of small intestine\nCL:0009007"] <- "Macrophage + DC" +V(opi1)$name[V(opi1)$name=="paneth\ncell of epithelium of small intestine\nCL:1000343"] <- "Paneth" +V(opi1)$name[V(opi1)$name=="pericyte\nCL:0000669"] <- "Pericyte" +V(opi1)$name[V(opi1)$name=="smooth\nmuscle fiber of ileum\nCL:1000278"] <- "Smooth Muscle" +V(opi1)$name[V(opi1)$name=="intestinal\ncrypt stem cell of small intestine\nCL:0009017"] <- "Stem + TA" +V(opi1)$name[V(opi1)$name=="stromal\ncell of lamina propria of small 
intestine\nCL:0009022"] <- "Stromal" +V(opi1)$name[V(opi1)$name=="CD4-positive\nhelper T cell\nCL:0000492"] <- "T (CD4+)" +V(opi1)$name[V(opi1)$name=="CD8-positive,\nalpha-beta T cell\nCL:0000625"] <- "T (CD8+)" +V(opi1)$name[V(opi1)$name=="telocyte\nCL:0017004"] <- "Telocyte" +V(opi1)$name[V(opi1)$name=="tuft\ncell of small intestine\nCL:0009080"] <- "Tuft" + +## Add the edge from connective tissue cell and telocyte and delete useless nodes +opi1 <- add_edges(opi1, c("connective\ntissue cell\nCL:0002320", "Telocyte")) +gr <- as_graphnel(opi1) +opi2 <- opi1 - c("somatic\ncell\nCL:0002371", "contractile\ncell\nCL:0000183", + "native\ncell\nCL:0000003") +gr1 <- as_graphnel(opi2) +plot(gr1, attrs=list(node=list(fontsize=27))) +``` + +## Preprocess data + +Modify the colData variable "leiden_final" to unify B cells and enterocytes + +```{r} +spe.baysor$cell_type <- spe.baysor$leiden_final +spe.baysor$cell_type[spe.baysor$cell_type %in% c("B (Follicular, Circulating)", "B (Plasma)")] <- "B cell" +spe.baysor$cell_type[grep("Enterocyte", spe.baysor$cell_type)] <- "Enterocyte" +spe.baysor <- spe.baysor[,spe.baysor$cell_type %notin% c("Removed", "Myenteric Plexus")] +spe.baysor +``` + + +## Fit model +Fit a multinomial model with a random sample of 500 cells using the 50 HVGs. +```{r, warning=FALSE} +# get. 
HVGs +spe.baysor <- logNormCounts(spe.baysor) +v <- modelGeneVar(spe.baysor) +hvg <- getTopHVGs(v, n=50) + +# Extract counts and construct df +df <- as.data.frame(t(as.matrix(counts(spe.baysor[hvg,])))) +df$Y <- spe.baysor$cell_type +set.seed(1703) +train <- sample(1:nrow(df), 500) +df.train <- df[train,] +table(df.train$Y) + +# Fit model +fit <- vglm(Y ~ ., family = multinomial(refLevel = "B cell"), + data = df.train) +df.other <- df[-train,] +spe.other <- spe.baysor[,-train] +``` + +# Prediction sets + +Split the data randomly into calibration and query data +```{r} +# split data +set.seed(1237) +cal <- sample(1:nrow(df.other), 1000) +df.cal <- df.other[cal,] +df.test <- df.other[-cal,] +``` + + +## Obtain prediction matrices + +The first step to build conformal prediction sets is to obtain prediction matrices +for the cells in the calibration data and in the query data. Each row of the matrix +corresponds to a particular cell, while each column corresponds to a different cell type. The entry +$p_{i,j}$ of the matrix indicates the estimated probability that the cell $i$ is +of type $j$. +```{r} +# Prediction matrix for calibration data +p.cal <- predict(fit, newdata=df.cal, type="response") +# Prediction matrix for query data +p.test <- predict(fit, newdata=df.test, type="response") +``` + + +## Obtain conformal prediction sets + +We can now directly call the `getPredictionSets` function by using as input the +prediction matrices for the calibration and the query dataset. In this case, the +output of the function will be a list whose elements are the prediction sets for each query cell. +By setting `follow_ontology=FALSE`, we are asking the function to return +prediction sets obtained via standard conformal inference. + +```{r} +sets <- getPredictionSets(x.query=p.test, x.cal=p.cal, y.cal = df.cal$Y, + onto = opi2, alpha = 0.1, + follow_ontology = FALSE) +``` + +As an alternative, we can provide as input a SingleCellExperiment object.
In this case, +its colData must contain the estimated probabilities for each cell type. +The names of these colData columns have to correspond to the names of the leaf nodes in the +ontology. Let's create the two separate SingleCellExperiment objects and add the predictions to the colData. + +```{r} +# Create the SingleCellExperiment objects +spe.cal <- spe.other[,cal] +spe.test <- spe.other[,-cal] + +# Retrieve labels as leaf nodes of the ontology +labels <- V(opi2)$name[degree(opi2, mode="out")==0] + +# Create corresponding colData +for(i in labels){ + colData(spe.cal)[[i]] <- p.cal[,i] + colData(spe.test)[[i]] <- p.test[,i] +} + +# Create prediction sets +sets.bis <- getPredictionSets(x.query=spe.test, x.cal=spe.cal, y.cal = df.cal$Y, + onto = opi2, alpha = 0.1, + follow_ontology = FALSE) + +l <- sapply(sets, length) +l.bis <- sapply(sets.bis, length) +mean(l) +barplot(table(l.bis)) + +# l.onto <- sapply(sets.onto, length) +# mean(l.onto) +``` + + + + +```{r} +sets.onto <- getPredictionSets(x.query=p.test, x.cal=p.cal, y.cal = df.cal$Y, + onto = opi2, alpha = 0.1, + lambdas = seq(0.001,0.999,length.out=100), + follow_ontology = TRUE) +``` + + +