Skip to content

Commit

Permalink
resample
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielaCorbetta committed May 9, 2024
1 parent 614fd06 commit c2d48c2
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 22 deletions.
70 changes: 55 additions & 15 deletions R/getPredictionSets.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
#' calibration data
#' @param onto the considered section of the cell ontology as an igraph object.
#' @param alpha a number between 0 and 1 that indicates the allowed miscoverage
#' @param lambdas a vector of possible lambda values to be considered
#' @param lambdas a vector of possible lambda values to be considered. Necessary
#' only when \code{follow_ontology=TRUE}.
#' @param follow_ontology If \code{TRUE}, then the function returns hierarchical
#' prediction sets that follow the cell ontology structure. If \code{FALSE}, it
#' returns classical conformal prediction sets.
#' @param resample_cal Should the calibration dataset be resampled according to
#' the estimated relative frequencies of cell types in the query data?
#' @param labels labels of different considered cell types. Necessary if
#' \code{onto=NULL}, otherwise they are set to the leaf nodes of the provided
#' graph
Expand All @@ -36,10 +39,10 @@
#' sets.
#' @author Daniela Corbetta
#' @return The function \code{getPredictionSets} returns
#' \item{If \code{return.sc=TRUE}}{a SingleCellExperiment or SpatialExperiment
#' \item{\code{return.sc=TRUE}}{a SingleCellExperiment or SpatialExperiment
#' object with the prediction sets in the colData. The name of the variable
#' containing the prediction sets is given by the parameter \code{pr.name}}
#' \item{If \code{return.sc=FALSE}}{a list of length equal
#' \item{\code{return.sc=FALSE}}{a list of length equal
#' to the number of cells in the test data. Each element of the list contains
#' the prediction set for that cell.}
#' @references For an introduction to conformal prediction, see of
Expand All @@ -56,7 +59,9 @@

getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
lambdas = seq(0.001,0.999,length.out=100),
follow_ontology=TRUE, labels=NULL,
follow_ontology=TRUE,
resample_cal=FALSE,
labels=NULL,
return.sc=NULL, pr.name="pred.set"){
if(follow_ontology & is.null(onto)){
stop("An ontology is required for hierarchical prediction set.
Expand All @@ -80,12 +85,14 @@ getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
sc <- FALSE
else stop("Please provide as input in x.query a SpatialExperiment,
SingleCellExperiment or a matrix")
cat(sc)
# Add check to see if x.cal, x.query are SingleCell/SpatialExperiment or matrices

# Retrieve labels from the ontology (need to add retrieval from y.cal/data
# when follow_ontology=FALSE)
labels <- V(onto)$name[degree(onto, mode="out")==0]
if(is.null(labels))
labels <- V(onto)$name[degree(onto, mode="out")==0]
K <- length(labels)

# If input is not a matrix, retrieve prediction matrix from colData
if(!is.matrix(x.query)){
n.query <- ncol(x.query)
p.query <- matrix(NA, nrow=n.query, ncol=K)
Expand All @@ -106,15 +113,48 @@ getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
}
else p.cal <- x.cal

if (follow_ontology){
pred.sets <- .getHierarchicalPredSets(p.cal=p.cal, p.test=p.query,
y.cal=y.cal, onto=onto,
alpha=alpha,
lambdas=lambdas)$sets.test
if(!resample_cal){
if (follow_ontology){
pred.sets <- .getHierarchicalPredSets(p.cal=p.cal, p.test=p.query,
y.cal=y.cal, onto=onto,
alpha=alpha,
lambdas=lambdas)$sets.test
}
else
pred.sets <- .getConformalPredSets(p.cal=p.cal, p.test=p.query,
y.cal=y.cal, alpha=alpha)
}
else
pred.sets <- .getConformalPredSets(p.cal=p.cal, p.test=p.query,
y.cal=y.cal, alpha=alpha)

if(resample_cal){
    # Split the query data in two halves. For each half, resample.two() draws
    # a calibration set whose label frequencies follow the predicted-class
    # frequencies of that half. The calibration set derived from one half is
    # then used to build the prediction sets of the *other* half.
    data <- resample.two(p.cal=p.cal, p.test=p.query, y.cal=y.cal,
                         labels=labels)
    if (follow_ontology){
        pred.sets1 <- .getHierarchicalPredSets(p.cal=data$p.cal1,
                                               p.test=data$p.test2,
                                               y.cal=data$y.cal1,
                                               onto=onto,
                                               alpha=alpha,
                                               lambdas=lambdas)$sets.test
        pred.sets2 <- .getHierarchicalPredSets(p.cal=data$p.cal2,
                                               p.test=data$p.test1,
                                               y.cal=data$y.cal2,
                                               onto=onto,
                                               alpha=alpha,
                                               lambdas=lambdas)$sets.test
    }
    else {
        # Braces are required here: without them only the first assignment
        # belongs to the else branch and the remaining ones would overwrite
        # the hierarchical prediction sets computed above.
        pred.sets1 <- .getConformalPredSets(p.cal=data$p.cal1,
                                            p.test=data$p.test2,
                                            y.cal=data$y.cal1,
                                            alpha=alpha)
        pred.sets2 <- .getConformalPredSets(p.cal=data$p.cal2,
                                            p.test=data$p.test1,
                                            y.cal=data$y.cal2,
                                            alpha=alpha)
    }
    pred.sets <- c(pred.sets1, pred.sets2)
}
# Problem: these prediction sets are not in the original order of the query
# cells (pred.sets1 refers to data$p.test2, pred.sets2 to data$p.test1); they
# must be reordered before assigning them to the colData

# if not specified, return a sc object if the input was a sc object,
# a matrix if the input was a matrix
Expand Down
10 changes: 6 additions & 4 deletions R/predSets.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,25 @@
anc <- .ancestors(node = pred_class, onto = onto, include_self = TRUE)

# Compute scores for all the ancestor of the predicted class
s <- sapply(as.character(anc), function(i) .scores(pred=pred, int_node=i, onto=onto))
s <- sapply(as.character(anc), function(i) .scores(pred=pred, int_node=i,
onto=onto))
names(s) <- anc

# Sort them by score and if there are ties by distance to the predicted class
#Sort them by score and if there are ties by distance to the predicted class
## compute distance from predicted class
pos <- distances(onto, v=anc, to=pred_class, mode="out")
tie_breaker <- as.vector(t(pos))
names(tie_breaker) <- colnames(t(pos))
sorted_indices <- order(s, tie_breaker, decreasing=F)
sorted_indices <- order(s, tie_breaker, decreasing=FALSE)
sorted_scores <- s[sorted_indices]

# Select the first score that is geq than lambda
sel_node <- names(sorted_scores)[round(sorted_scores, 15) >= lambda][1]


# Add also the subgraphs we would have obtained with smaller lambda
selected <- c(lapply(anc[round(s, 15) < lambda], function(x) .children(node = x, onto = onto)),
selected <- c(lapply(anc[round(s, 15) < lambda], function(x)
.children(node = x, onto = onto)),
list(.children(sel_node, onto)))

return(Reduce(union, selected))
Expand Down
37 changes: 37 additions & 0 deletions R/resample.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

#' Split the query data in two halves and resample the calibration data
#'
#' The rows of \code{p.test} are randomly split in two halves. For each half
#' the predicted class of every row (column with the largest value) is
#' computed, and the relative frequencies of the predicted classes are used to
#' draw, with replacement, a calibration set of (about) \code{length(y.cal)}
#' observations having the same class proportions.
#'
#' @param p.cal matrix with one row per calibration observation.
#' @param p.test matrix with one row per query observation and one named
#' column per class.
#' @param y.cal vector of labels of the calibration observations, one per row
#' of \code{p.cal}.
#' @param labels labels of the classes to be resampled.
#' @return A list with elements \code{p.cal1}, \code{p.cal2} (resampled
#' calibration matrices based on the first/second half of the query data),
#' \code{p.test1}, \code{p.test2} (the two halves of \code{p.test}),
#' \code{y.cal1}, \code{y.cal2} (labels of the resampled calibration rows) and
#' \code{idx.test1}, \code{idx.test2} (row indices of \code{p.test} that ended
#' up in each half, useful to restore the original order).
#' @noRd
resample.two <- function(p.cal, p.test, y.cal, labels){
    if (is.null(colnames(p.test)))
        stop("p.test must have column names corresponding to the class labels")

    n.test <- nrow(p.test)
    # Random split of the query rows. setdiff() is used instead of a negative
    # index because p.test[-integer(0), ] would select no rows at all.
    s <- sample.int(n.test, round(n.test/2))
    s.compl <- setdiff(seq_len(n.test), s)
    test1 <- p.test[s, , drop=FALSE]
    test2 <- p.test[s.compl, , drop=FALSE]

    # Predicted class = name of the column with the largest value in each row
    pred.class <- function(p){
        vapply(seq_len(nrow(p)), function(r){
            j <- which.max(p[r, ])
            if (length(j) == 1L) colnames(p)[j] else NA_character_
        }, character(1))
    }

    # Desired absolute frequency of every predicted class in a calibration
    # set of size length(y.cal), as a plain named vector (NA for absent names)
    desired.freq <- function(pr.class){
        freq <- prop.table(table(pr.class))
        structure(as.vector(round(freq * length(y.cal))), names=names(freq))
    }

    # Draw n calibration indices (with replacement) for one class. Indexing
    # via sample.int() avoids the sample(x) pitfall when length(x) == 1.
    draw <- function(pool, n, label){
        if (is.na(n) || n == 0)
            return(integer(0))
        if (length(pool) == 0L){
            warning("No calibration observations with label '", label,
                    "': it cannot be resampled", call.=FALSE)
            return(integer(0))
        }
        pool[sample.int(length(pool), size=n, replace=TRUE)]
    }

    des_freq1 <- desired.freq(pred.class(test1))
    des_freq2 <- desired.freq(pred.class(test2))

    labels <- as.character(labels)
    idx1 <- idx2 <- vector("list", length(labels))
    for (k in seq_along(labels)) {
        lab <- labels[k]
        pool <- which(y.cal == lab)
        idx1[[k]] <- draw(pool, des_freq1[lab], lab)
        idx2[[k]] <- draw(pool, des_freq2[lab], lab)
    }
    idx1 <- as.integer(unlist(idx1))
    idx2 <- as.integer(unlist(idx2))

    return(list(p.cal1=p.cal[idx1, , drop=FALSE],
                p.cal2=p.cal[idx2, , drop=FALSE],
                p.test1=test1,
                p.test2=test2,
                y.cal1=y.cal[idx1],
                y.cal2=y.cal[idx2],
                idx.test1=s,
                idx.test2=s.compl)
    )
}

11 changes: 8 additions & 3 deletions man/getPredictionSets.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c2d48c2

Please sign in to comment.