Skip to content

Commit

Permalink
solved bug order resample
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielaCorbetta committed May 10, 2024
1 parent c2d48c2 commit d395259
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 67 deletions.
Binary file modified .DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Package: ConfCell
Package: scConform
Type: Package
Title: Conformal Inference For Cell Type Annotation
Version: 0.99.0
Expand Down
92 changes: 48 additions & 44 deletions R/getPredictionSets.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
#' @param onto the considered section of the cell ontology as an igraph object.
#' @param alpha a number between 0 and 1 that indicates the allowed miscoverage
#' @param lambdas a vector of possible lambda values to be considered. Necessary
#' only when \code{follow_ontology=TRUE}.
#' @param follow_ontology If \code{TRUE}, then the function returns hierarchical
#' only when \code{follow.ontology=TRUE}.
#' @param follow.ontology If \code{TRUE}, then the function returns hierarchical
#' prediction sets that follow the cell ontology structure. If \code{FALSE}, it
#' returns classical conformal prediction sets.
#' @param resample_cal Should the calibration dataset be resampled according to
#' @param resample.cal Should the calibration dataset be resampled according to
#' the estimated relative frequencies of cell types in the query data?
#' @param labels labels of different considered cell types. Necessary if
#' \code{onto=NULL}, otherwise they are set to the leaf nodes of the provided
Expand Down Expand Up @@ -59,14 +59,14 @@

getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
lambdas = seq(0.001,0.999,length.out=100),
follow_ontology=TRUE,
resample_cal=FALSE,
follow.ontology=TRUE,
resample.cal=FALSE,
labels=NULL,
return.sc=NULL, pr.name="pred.set"){
if(follow_ontology & is.null(onto)){
if(follow.ontology & is.null(onto)){
stop("An ontology is required for hierarchical prediction set.
Please provide one or ask for conformal prediction set
(follow_ontology=FALSE)")
(follow.ontology=FALSE)")
}
if(is.null(onto) & is.null(labels)){
stop("Please provide cell labels with the labels parameter")
Expand All @@ -87,7 +87,7 @@ getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
SingleCellExperiment or a matrix")

# Retrieve labels from the ontology (need to add retrieval from y.cal/data
# when follow_ontology=FALSE)
# when follow.ontology=FALSE)
if(is.null(labels))
labels <- V(onto)$name[degree(onto, mode="out")==0]
K <- length(labels)
Expand All @@ -98,23 +98,23 @@ getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
p.query <- matrix(NA, nrow=n.query, ncol=K)
colnames(p.query) <- labels
for(i in labels){
p.query[,i] <- colData(x.query)[[i]]
p.query[,i] <- colData(x.query)[[i]]
}
}
else p.query <- x.query

if(!is.matrix(x.cal)){
n.cal <- ncol(x.cal)
p.cal <- matrix(NA, nrow=n.cal, ncol=K)
colnames(p.cal) <- labels
for(i in labels){
p.cal[,i] <- colData(x.cal)[[i]]
}
n.cal <- ncol(x.cal)
p.cal <- matrix(NA, nrow=n.cal, ncol=K)
colnames(p.cal) <- labels
for(i in labels){
p.cal[,i] <- colData(x.cal)[[i]]
}
}
else p.cal <- x.cal

if(!resample_cal){
if (follow_ontology){
if(!resample.cal){
if (follow.ontology){
pred.sets <- .getHierarchicalPredSets(p.cal=p.cal, p.test=p.query,
y.cal=y.cal, onto=onto,
alpha=alpha,
Expand All @@ -125,36 +125,40 @@ getPredictionSets <- function(x.query, x.cal, y.cal, onto=NULL, alpha = 0.1,
y.cal=y.cal, alpha=alpha)
}

if(resample_cal){
if(resample.cal){
data <- resample.two(p.cal=p.cal, p.test=p.query, y.cal=y.cal,
labels=labels)
if (follow_ontology){
pred.sets1 <- .getHierarchicalPredSets(p.cal=data$p.cal1,
p.test=data$p.test2,
y.cal=data$y.cal1,
onto=onto,
alpha=alpha,
lambdas=lambdas)$sets.test
pred.sets2 <- .getHierarchicalPredSets(p.cal=data$p.cal2,
p.test=data$p.test1,
y.cal=data$y.cal2,
onto=onto,
alpha=alpha,
lambdas=lambdas)$sets.test
pred.sets <- c(pred.sets1, pred_sets2)
if (follow.ontology){
pred.sets1 <- .getHierarchicalPredSets(p.cal=data$p.cal2,
p.test=data$p.test1,
y.cal=data$y.cal2,
onto=onto,
alpha=alpha,
lambdas=lambdas)$sets.test
pred.sets2 <- .getHierarchicalPredSets(p.cal=data$p.cal1,
p.test=data$p.test2,
y.cal=data$y.cal1,
onto=onto,
alpha=alpha,
lambdas=lambdas)$sets.test
pred.sets <- c(pred.sets1, pred.sets2)
}
else
pred.sets1 <- .getConformalPredSets(p.cal=data$p.cal1,
p.test=data$p.test2,
y.cal=data$y.cal1,
alpha=alpha)
pred.sets2 <- .getConformalPredSets(p.cal=data$p.cal2,
p.test=data$p.test1,
y.cal=data$y.cal2,
alpha=alpha)
pred.sets <- c(pred.sets1, pred_sets2)
} # Problem: this prediction sets are not ordered, must be ordered before
# assigning them to the colData
else {
pred.sets1 <- .getConformalPredSets(p.cal=data$p.cal2,
p.test=data$p.test1,
y.cal=data$y.cal2,
alpha=alpha)
pred.sets2 <- .getConformalPredSets(p.cal=data$p.cal1,
p.test=data$p.test2,
y.cal=data$y.cal1,
alpha=alpha)
pred.sets <- c(pred.sets1, pred.sets2)
}
# Order the prediction set
pred.sets <- pred.sets[order(data$idx)]
}



# if not specified, return a sc object if the input was a sc object,
# a matrix if the input was a matrix
Expand Down
31 changes: 16 additions & 15 deletions R/resample.R
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@

resample.two <- function(p.cal, p.test, y.cal, labels){
s <- sample(1:nrow(test), round(nrow(test)/2))
s <- sample(1:nrow(p.test), round(nrow(p.test)/2))
test1 <- p.test[s,]
test2 <- p.test[-s,]
# cal_freq <- prop.table(table(y.cal))

# Compute predicted class
pr.class1 <- apply(test1, 1, function(row) colnames(test1)[which.max(row)])
pr.class2 <- apply(test2, 1, function(row) colnames(test2)[which.max(row)])
test_freq1 <- prop.table(table(pr.class1))
test_freq2 <- prop.table(table(pr.class2))
test.freq1 <- prop.table(table(pr.class1))
test.freq2 <- prop.table(table(pr.class2))
# Transform to absolute frequencies
des_freq1 <- round(test_freq1 * length(y.cal))
des_freq2 <- round(test_freq2 * length(y.cal))
des.freq1 <- round(test.freq1 * length(y.cal))
des.freq2 <- round(test.freq2 * length(y.cal))

idx1 <- idx2 <- NULL
for (i in labels) {
cat <- which(y.cal == i)
if(!is.na(des_freq1[i])){
idx_cat1 <- sample(cat, size = des_freq1[i], replace = TRUE)
idx1 <- c(idx1, idx_cat1)
if(!is.na(des.freq1[i])){
idx.cat1 <- sample(cat, size = des.freq1[i], replace = TRUE)
idx1 <- c(idx1, idx.cat1)
}
if(!is.na(des_freq2[i])){
idx_cat2 <- sample(cat, size = des_freq2[i], replace = TRUE)
idx2 <- c(idx2, idx_cat2)
if(!is.na(des.freq2[i])){
idx.cat2 <- sample(cat, size = des.freq2[i], replace = TRUE)
idx2 <- c(idx2, idx.cat2)
}
}

return(list(p.cal1=p.cal[idx1,],
p.cal2=p.cal[idx2,],
p.test1=p.test1,
p.test2=p.test2,
p.test1=test1,
p.test2=test2,
y.cal1=y.cal[idx1],
y.cal2=y.cal[idx2])
y.cal2=y.cal[idx2],
idx=c(s, setdiff(1:nrow(p.test), s)))#index in the original data
)
}

10 changes: 5 additions & 5 deletions man/getPredictionSets.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions vignettes/vignette.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ author:
- name: Daniela Corbetta
affiliation: Department of Statistical Sciences, University of Padova
email: [email protected]
package: scConform
output:
BiocStyle::html_document:
toc: true
toc_float: true
vignette: >
%\VignetteIndexEntry{my-vignette}
%\VignetteIndexEntry{scConform}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
Expand All @@ -22,5 +23,5 @@ knitr::opts_chunk$set(
```

```{r setup}
library(ConfCell)
library(scConform)
```

0 comments on commit d395259

Please sign in to comment.