Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of .lvCompare() #425

Merged
merged 15 commits into from
Nov 4, 2024
Merged
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* import `lifecycle`, `purrr`, `withr`
* suppressed "using discrete variable for alpha is not recommended" warning in alluvialClones unit tests.
* Fixed issue with ```clonalCluster()``` and exportGraph = TRUE
* improve performance of ```combineBCR()``` by a constant factor with C++

# scRepertoire VERSION 2.0.7

Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ rcppConstructConDfAndParseTCR <- function(data2, uniqueData2Barcodes) {
.Call(`_scRepertoire_rcppConstructConDfAndParseTCR`, data2, uniqueData2Barcodes)
}

# R-level entry point for the compiled routine
# `_scRepertoire_rcppGetSigSequenceEditDistEdgeListDf`: forwards `sequences`
# and `threshold` unchanged through `.Call()` and returns whatever the native
# code returns.
# NOTE(review): this looks like Rcpp-generated glue (RcppExports.R); if so,
# hand-written comments here are lost on regeneration -- confirm.
rcppGetSigSequenceEditDistEdgeListDf <- function(sequences, threshold) {
.Call(`_scRepertoire_rcppGetSigSequenceEditDistEdgeListDf`, sequences, threshold)
}

# R-level entry point for the compiled routine
# `_scRepertoire_rcppGenerateUniqueNtMotifs`: forwards `k` unchanged through
# `.Call()` and returns whatever the native code returns.
# NOTE(review): presumably `k` is a motif (k-mer) length, judging by the
# name only -- verify against the C++ implementation.
rcppGenerateUniqueNtMotifs <- function(k) {
.Call(`_scRepertoire_rcppGenerateUniqueNtMotifs`, k)
}
Expand Down
4 changes: 2 additions & 2 deletions R/clonalCluster.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ clonalCluster <- function(input.data,
}
colnames(output.list[[x]]) <- c(paste0(chain, "_cluster"), ref2)
output.list[[x]]

})
cluster <- bind_rows(cluster.list)
cluster <- bind_rows(cluster.list) # the TRA_cluster isnt assigned in the failing test
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ncborcherding Could you maybe describe what exactly this intermediate variable is supposed to be? Struggling to debug downstream atm.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Qile0317 Here the cluster.list is the cluster assignments from the .lv.compare() function that are grouped into list elements, bind_rows() is just forming a data frame from the list.


#Merging with contig info
tmp <- bound
Expand Down
83 changes: 42 additions & 41 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -587,50 +587,15 @@ is_df_or_list_of_df <- function(x) {
dictionary[dictionary == "None"] <- NA
dictionary$v.gene <- stringr::str_split(dictionary[,gene], "[.]", simplify = TRUE)[,1]
tmp <- na.omit(unique(dictionary[,c(chain, "v.gene")]))

#chunking the sequences for distance by v.gene

edge.list <- lapply(str_sort(na.omit(unique(dictionary$v.gene)), numeric = T), function(v_gene) {
filtered_df <- dplyr::filter(dictionary, v.gene == v_gene)
nucleotides <- filtered_df[[chain]]
nucleotides <- sort(unique(str_split(nucleotides, ";", simplify = TRUE)[,1]))

if (length(nucleotides) <= 1) return(NULL)

results <- list()
# Only iterate until the second last element to avoid the issue
for (i in 1:(length(nucleotides) - 1)) {
for (j in (i + 1):length(nucleotides)) {
# Check based on length difference feasibility
if (abs(nchar(nucleotides[i]) - nchar(nucleotides[j])) > max(nchar(nucleotides[i]), nchar(nucleotides[j])) * (1 - threshold)) {
next
}

distance <- stringdist::stringdist(nucleotides[i], nucleotides[j], method = "lv")
normalized_distance <- 1 - distance / mean(c(nchar(nucleotides[i]), nchar(nucleotides[j])))

if (normalized_distance >= threshold) {
results[[length(results) + 1]] <- data.frame(
from = nucleotides[i],
to = nucleotides[j],
distance = normalized_distance
)
}
}
}

do.call(rbind, results)
})

edge.list <- do.call(rbind, edge.list)

edge.list <- .createBcrEdgeListDf(dictionary, chain, threshold)

if(exportGraph) {
if(!is.null(edge.list)) {
graph <- graph_from_edgelist(as.matrix(edge.list)[,c(1,2)])

} else {
graph <- NULL
if (is.null(edge.list)) {
return(NULL)
}
return(graph)
return(graph_from_edgelist(as.matrix(edge.list)[, c(1, 2)]))
}

if (!is.null(dim(edge.list))) {
Expand All @@ -657,3 +622,39 @@ is_df_or_list_of_df <- function(x) {
return(output)
}

#' Build an edge list of similar sequences, one V gene at a time
#'
#' Helper for the internal [.lvCompare()]. The rows of `dictionary` are
#' grouped by their `v.gene` value (genes visited in natural-sorted order).
#' Within each group, the sequences in the `chain` column are cut at the first
#' `;`, de-duplicated and sorted, and then compared pairwise by
#' `rcppGetSigSequenceEditDistEdgeListDf()`. Groups with at most one unique
#' sequence contribute nothing. The per-gene results are row-bound together.
#'
#' @param dictionary data.frame with a `v.gene` column and a column named by
#' `chain` that holds the sequences.
#' @param chain string. Name of the sequence column in `dictionary`.
#' @param threshold single numeric between 0 and 1, passed on unchanged to
#' `rcppGetSigSequenceEditDistEdgeListDf()`.
#'
#' @return A data.frame with the columns `from`, `to`, `distance`, one row per
#' retained pair of sequences sharing a V gene, or `NULL` when no V gene
#' yields a comparison or the combined result has zero rows.
#'
#' @keywords internal
#' @noRd
.createBcrEdgeListDf <- function(dictionary, chain, threshold) {

    sortedVGenes <- str_sort(na.omit(unique(dictionary$v.gene)), numeric = TRUE)

    perGeneEdges <- lapply(sortedVGenes, function(v_gene) {
        geneRows <- dplyr::filter(dictionary, v.gene == v_gene)
        # Only the part before the first ";" of each entry is compared.
        leadingSeqs <- str_split(geneRows[[chain]], ";", simplify = TRUE)[, 1]
        candidateSeqs <- sort(unique(leadingSeqs))
        if (length(candidateSeqs) > 1) {
            rcppGetSigSequenceEditDistEdgeListDf(candidateSeqs, threshold)
        } else {
            # A lone sequence has nothing to be compared against.
            NULL
        }
    })

    combinedEdges <- do.call(rbind, perGeneEdges)

    # A zero-row result is reported as NULL, the same as "no result at all".
    if (is.null(combinedEdges) || nrow(combinedEdges) > 0) {
        return(combinedEdges)
    }
    NULL
}
13 changes: 13 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// rcppGetSigSequenceEditDistEdgeListDf
// C-callable shim for rcppGetSigSequenceEditDistEdgeListDf(): converts the two
// incoming SEXP arguments to their C++ parameter types, calls the function,
// and hands the result back wrapped as an R object.
// NOTE(review): this looks like Rcpp-generated glue (RcppExports.cpp); if so,
// hand-written comments here are lost on regeneration -- confirm.
Rcpp::DataFrame rcppGetSigSequenceEditDistEdgeListDf(const std::vector<std::string> sequences, const double threshold);
RcppExport SEXP _scRepertoire_rcppGetSigSequenceEditDistEdgeListDf(SEXP sequencesSEXP, SEXP thresholdSEXP) {
// BEGIN_RCPP / END_RCPP are Rcpp macros bracketing the body; presumably they
// turn C++ exceptions into R errors -- see the Rcpp documentation.
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
// SEXP -> C++ conversions for the two arguments.
Rcpp::traits::input_parameter< const std::vector<std::string> >::type sequences(sequencesSEXP);
Rcpp::traits::input_parameter< const double >::type threshold(thresholdSEXP);
rcpp_result_gen = Rcpp::wrap(rcppGetSigSequenceEditDistEdgeListDf(sequences, threshold));
return rcpp_result_gen;
END_RCPP
}
// rcppGenerateUniqueNtMotifs
Rcpp::CharacterVector rcppGenerateUniqueNtMotifs(int k);
RcppExport SEXP _scRepertoire_rcppGenerateUniqueNtMotifs(SEXP kSEXP) {
Expand Down Expand Up @@ -75,6 +87,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_scRepertoire_rcppGetAaKmerPercent", (DL_FUNC) &_scRepertoire_rcppGetAaKmerPercent, 3},
{"_scRepertoire_rcppConstructBarcodeIndex", (DL_FUNC) &_scRepertoire_rcppConstructBarcodeIndex, 2},
{"_scRepertoire_rcppConstructConDfAndParseTCR", (DL_FUNC) &_scRepertoire_rcppConstructConDfAndParseTCR, 2},
{"_scRepertoire_rcppGetSigSequenceEditDistEdgeListDf", (DL_FUNC) &_scRepertoire_rcppGetSigSequenceEditDistEdgeListDf, 2},
{"_scRepertoire_rcppGenerateUniqueNtMotifs", (DL_FUNC) &_scRepertoire_rcppGenerateUniqueNtMotifs, 1},
{"_scRepertoire_rcppGetNtKmerPercent", (DL_FUNC) &_scRepertoire_rcppGetNtKmerPercent, 2},
{NULL, NULL, 0}
Expand Down
86 changes: 86 additions & 0 deletions src/lvCompare.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include <Rcpp.h>
#include <string>
#include <vector>

// Smallest of three values of the same type.
template <typename T>
T min(T a, T b, T c) {
    const T smallerOfLastTwo = std::min(b, c);
    return std::min(a, smallerOfLastTwo);
}

// Levenshtein (edit) distance between s1 and s2, returned as a double.
//
// Classic dynamic-programming formulation with unit costs, keeping only two
// rows of the table at a time: `above` is the row for the previous character
// of s1, `row` is the one currently being filled in.
double editDist(const std::string& s1, const std::string& s2) {

    const int rows = s1.size();
    const int cols = s2.size();

    // Against an empty string the distance is the other string's length.
    if (rows == 0) return static_cast<double>(cols);
    if (cols == 0) return static_cast<double>(rows);

    std::vector<int> above(cols + 1), row(cols + 1);

    // Row 0: turning the empty prefix of s1 into the first c chars of s2.
    for (int c = 0; c <= cols; c++) {
        above[c] = c;
    }

    for (int r = 1; r <= rows; r++) {

        const char s1Char = s1[r - 1];
        row[0] = r;

        for (int c = 1; c <= cols; c++) {
            const int fromLeft = row[c - 1] + 1;
            const int fromAbove = above[c] + 1;
            const int fromDiagonal = above[c - 1] + ((s1Char == s2[c - 1]) ? 0 : 1);
            row[c] = min(fromLeft, fromAbove, fromDiagonal);
        }

        // The row just completed becomes the "previous" row of the next pass.
        std::swap(above, row);
    }

    return static_cast<double>(above[cols]);
}

// Cheap pre-filter applied before computing an edit distance: true when the
// absolute length difference is at most (1 - threshold) times the longer of
// the two lengths.
bool lenDiffWithinThreshold(const int len1, const int len2, const double threshold) {
    const bool firstIsShorter = len1 < len2;
    const int longer = firstIsShorter ? len2 : len1;
    const int shorter = firstIsShorter ? len1 : len2;

    const double gap = static_cast<double>(longer - shorter);
    const double allowedGap = static_cast<double>(longer) * (1 - threshold);
    return gap <= allowedGap;
}

// Build an edge list of all sufficiently similar pairs in `sequences`.
//
// Every unordered pair (i < j) is scored with
//     1 - editDist(seq_i, seq_j) / mean(length_i, length_j)
// and kept when that score is at least `threshold`. Pairs whose lengths
// differ too much (see lenDiffWithinThreshold) are skipped before the edit
// distance is computed.
//
// Returns a data frame with one row per kept pair and the columns `from`
// (seq_i), `to` (seq_j) and `distance`. Despite its name, `distance` holds
// the score above, so larger means more similar. Fewer than two input
// sequences yield a zero-row data frame.
// [[Rcpp::export]]
Rcpp::DataFrame rcppGetSigSequenceEditDistEdgeListDf(
    const std::vector<std::string> sequences, const double threshold
) {

    std::vector<std::string> from, to;
    std::vector<double> distances;

    const size_t numSequences = sequences.size();

    // `i + 1 < numSequences` instead of `i < numSequences - 1`: size_t is
    // unsigned, so the subtraction wraps around to a huge value for an empty
    // input and the loop would effectively never terminate.
    for (size_t i = 0; i + 1 < numSequences; i++) {

        const int len1 = sequences[i].size();

        for (size_t j = i + 1; j < numSequences; j++) {

            const int len2 = sequences[j].size();

            if (!lenDiffWithinThreshold(len1, len2, threshold)) {
                continue;
            }

            const double meanLen = static_cast<double>(len1 + len2) * 0.5;
            const double normalizedDistance = 1 - (editDist(sequences[i], sequences[j]) / meanLen);

            if (normalizedDistance < threshold) {
                continue;
            }

            from.push_back(sequences[i]);
            to.push_back(sequences[j]);
            distances.push_back(normalizedDistance);

        }
    }

    return Rcpp::DataFrame::create(
        Rcpp::Named("from") = from,
        Rcpp::Named("to") = to,
        Rcpp::Named("distance") = distances
    );

}
38 changes: 32 additions & 6 deletions tests/testthat/test-clonalCluster.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# test script for clonalCluster.R - testcases are NOT comprehensive!

# Test fixture: attaches the combined clonal data to the example object and
# adds two metadata columns cut out of `orig.ident` -- `Patient` from
# characters 1-3 and `Type` from character 4.
getTestCombinedSeuratObj <- function() {
    seuratObj <- combineExpression(getCombined(), get(data("scRep_example")))
    sampleIds <- seuratObj$orig.ident
    seuratObj$Patient <- substr(sampleIds, 1, 3)
    seuratObj$Type <- substr(sampleIds, 4, 4)
    seuratObj
}

test_that("clonalCluster works", {

data("scRep_example")
test_obj <- getTestCombinedSeuratObj()
combined <- getCombined()
test_obj <- combineExpression(combined, scRep_example)
test_obj$Patient <- substr(test_obj$orig.ident,1,3)
test_obj$Type <- substr(test_obj$orig.ident,4,4)

set.seed(42)
withr::local_seed(42)

expect_equal(
clonalCluster(combined[[1]],
chain = "TRB",
Expand Down Expand Up @@ -46,4 +50,26 @@ test_that("clonalCluster works", {
)

})

# test_that("clonalCluster works with custom threshold", { # fails atm

# test_obj <- getTestCombinedSeuratObj()
# withr::local_seed(42)

# colorblind_vector <- hcl.colors(n=7, palette = "inferno", fixup = TRUE)

# test_obj <- clonalCluster(test_obj,
# chain = "TRA",
# sequence = "aa",
# threshold = 0.85,
# group.by = "Patient")

# Seurat::DimPlot(scRep_example, group.by = "TRA_cluster") +
# scale_color_manual(values = hcl.colors(n=length(unique(scRep_example@meta.data[,"TRA_cluster"])), "inferno")) +
# Seurat::NoLegend() +
# theme(plot.title = element_blank()) %>%
# print()

# })

#TODO Add exportgraph test
Loading