Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to the latest tatami interface for the M/Cov validity check. #143

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ Imports:
Biostrings,
utils,
HDF5Array (>= 1.19.11),
rhdf5
rhdf5,
beachmat (>= 2.23.2)
Suggests:
testthat,
bsseqData,
Expand All @@ -49,7 +50,6 @@ Suggests:
doParallel,
rtracklayer,
BSgenome.Hsapiens.UCSC.hg38,
beachmat (>= 1.5.2),
batchtools
Collate:
utils.R
Expand Down Expand Up @@ -82,6 +82,10 @@ VignetteBuilder: knitr
URL: https://github.com/kasperdanielhansen/bsseq
BugReports: https://github.com/kasperdanielhansen/bsseq/issues
biocViews: DNAMethylation
LinkingTo: Rcpp, beachmat
LinkingTo:
Rcpp,
beachmat,
assorthead (>= 1.1.4)
SystemRequirements: C++17
NeedsCompilation: yes
RoxygenNote: 7.1.0
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ importMethodsFrom(GenomeInfoDb, "seqlengths", "seqlengths<-", "seqinfo",
"seqlevels<-", "sortSeqlevels")
importFrom(GenomeInfoDb, "Seqinfo")
importFrom(gtools, "combinations")
importFrom(beachmat, initializeCpp)
importFrom(Rcpp, sourceCpp)

# NOTE: data.table has some NAMESPACE clashes with functions in Bioconductor,
Expand Down
2 changes: 1 addition & 1 deletion R/BSseq-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

.checkMandCov <- function(M, Cov) {
msg <- NULL
validMsg(msg, .Call(cxx_check_M_and_Cov, M, Cov))
validMsg(msg, .Call(cxx_check_M_and_Cov, initializeCpp(M), initializeCpp(Cov), 1))
}

# TODO: Benchmark validity method
Expand Down
2 changes: 1 addition & 1 deletion src/BSseq.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ extern "C" {

// Validity checking.

SEXP check_M_and_Cov(SEXP, SEXP);
SEXP check_M_and_Cov(SEXP, SEXP, SEXP);
}

#endif
135 changes: 61 additions & 74 deletions src/check_M_and_Cov.cpp
Original file line number Diff line number Diff line change
@@ -1,107 +1,94 @@
#include "BSseq.h"

#include "beachmat/integer_matrix.h"
#include "beachmat/numeric_matrix.h"
#include "Rtatami.h"

#include "utils.h"

#include <string>
#include <vector>

// NOTE: Returning Rcpp::CharacterVector rather than throwing an error because
// this function is used within a validity method.

template <class M_column_class, class Cov_column_class, class M_class,
class Cov_class>
Rcpp::RObject check_M_and_Cov_internal(M_class M_bm, Cov_class Cov_bm) {
SEXP check_M_and_Cov(SEXP M, SEXP Cov, SEXP nt) {
BEGIN_RCPP

Rtatami::BoundNumericPointer M_bound(M);
const auto& M_bm = *(M_bound->ptr);
Rtatami::BoundNumericPointer Cov_bound(Cov);
const auto& Cov_bm = *(Cov_bound->ptr);

// Get the dimensions of 'M' and 'Cov' and check these are compatible.
const size_t M_nrow = M_bm->get_nrow();
const size_t Cov_nrow = Cov_bm->get_nrow();
const int M_nrow = M_bm.nrow();
const int Cov_nrow = Cov_bm.nrow();
if (M_nrow != Cov_nrow) {
return Rcpp::CharacterVector(
"'M' and 'Cov' must have the same number of rows.");
}
const size_t M_ncol = M_bm->get_ncol();
const size_t Cov_ncol = Cov_bm->get_ncol();
const int M_ncol = M_bm.ncol();
const int Cov_ncol = Cov_bm.ncol();
if (M_ncol != Cov_ncol) {
return Rcpp::CharacterVector(
"'M' and 'Cov' must have the same number of columns.");
}

Rcpp::IntegerVector raw_nt(nt);
if (raw_nt.size() != 1 || raw_nt[0] <= 0) {
return Rcpp::CharacterVector(
"Number of threads should be a positive integer.");
}
int nthreads = raw_nt[0];

// Simultaneously loop over columns of 'M' and 'Cov', checking that
// `all(0 <= M <= Cov) && !anyNA(M) && !anyNA(Cov)` && all(is.finite(Cov)).
M_column_class M_column(M_nrow);
Cov_column_class Cov_column(Cov_nrow);
for (size_t j = 0; j < M_ncol; ++j) {
// Copy the j-th column of M to M_column and the j-th column of Cov to
// Cov_column
M_bm->get_col(j, M_column.begin());
Cov_bm->get_col(j, Cov_column.begin());
// Construct iterators
// NOTE: Iterators constructed outside of loop because they may be of
// different type, which is not supported within a for loop
// constructor.
auto M_column_it = M_column.begin();
auto Cov_column_it = Cov_column.begin();
for (M_column_it = M_column.begin(), Cov_column_it = Cov_column.begin();
M_column_it != M_column.end();
++M_column_it, ++Cov_column_it) {
if (isNA(*M_column_it)) {
return Rcpp::CharacterVector("'M' must not contain NAs.");
}
if (isNA(*Cov_column_it)) {
return Rcpp::CharacterVector("'Cov' must not contain NAs.");
}
if (*M_column_it < 0) {
return Rcpp::CharacterVector(
"'M' must not contain negative values.");
}
if (*M_column_it > *Cov_column_it) {
return Rcpp::CharacterVector(
"All values of 'M' must be less than or equal to the corresponding value of 'Cov'.");
}
if (!R_FINITE(*Cov_column_it)) {
return Rcpp::CharacterVector("All values of 'Cov' must be finite.");
std::vector<std::string> errors(nthreads);
tatami::parallelize([&](int tid, int start, int length) {
std::vector<double> M_buffer(M_nrow), Cov_buffer(Cov_nrow);
auto M_ext = tatami::consecutive_extractor<false>(&M_bm, false, start, length);
auto Cov_ext = tatami::consecutive_extractor<false>(&Cov_bm, false, start, length);

for (int c = start, cend = start + length; c < cend; ++c) {
auto M_ptr = M_ext->fetch(M_buffer.data());
auto Cov_ptr = Cov_ext->fetch(Cov_buffer.data());

for (int r = 0; r < M_nrow; ++r) {
auto M_current = M_ptr[r];
auto Cov_current = Cov_ptr[r];

if (isNA(M_current)) {
errors[tid] = "'M' must not contain NAs.";
return;
}
if (isNA(Cov_current)) {
errors[tid] = "'Cov' must not contain NAs.";
return;
}
if (M_current < 0) {
errors[tid] = "'M' must not contain negative values.";
return;
}
if (M_current > Cov_current) {
errors[tid] = "All values of 'M' must be less than or equal to the corresponding value of 'Cov'.";
return;
}
if (!R_FINITE(Cov_current)) {
errors[tid] = "All values of 'Cov' must be finite.";
return;
}
}
}
}, M_ncol, nthreads);

for (const auto& msg : errors) {
if (!msg.empty()) {
return Rcpp::CharacterVector(msg.c_str());
}
}

return R_NilValue;
END_RCPP
}

SEXP check_M_and_Cov(SEXP M, SEXP Cov) {
BEGIN_RCPP

// Get the type of 'M' and 'Cov',
int M_type = beachmat::find_sexp_type(M);
int Cov_type = beachmat::find_sexp_type(Cov);
if (M_type == INTSXP && Cov_type == INTSXP) {
auto M_bm = beachmat::create_integer_matrix(M);
auto Cov_bm = beachmat::create_integer_matrix(Cov);
return check_M_and_Cov_internal<
Rcpp::IntegerVector, Rcpp::IntegerVector>(M_bm.get(), Cov_bm.get());
} else if (M_type == REALSXP && Cov_type == REALSXP) {
auto M_bm = beachmat::create_numeric_matrix(M);
auto Cov_bm = beachmat::create_numeric_matrix(Cov);
return check_M_and_Cov_internal<
Rcpp::NumericVector, Rcpp::NumericVector>(M_bm.get(), Cov_bm.get());
} else if (M_type == INTSXP && Cov_type == REALSXP) {
auto M_bm = beachmat::create_integer_matrix(M);
auto Cov_bm = beachmat::create_numeric_matrix(Cov);
return check_M_and_Cov_internal<
Rcpp::IntegerVector, Rcpp::NumericVector>(M_bm.get(), Cov_bm.get());
} else if (M_type == REALSXP && Cov_type == INTSXP) {
auto M_bm = beachmat::create_numeric_matrix(M);
auto Cov_bm = beachmat::create_integer_matrix(Cov);
return check_M_and_Cov_internal<
Rcpp::NumericVector, Rcpp::IntegerVector>(M_bm.get(), Cov_bm.get());
}
else {
return Rcpp::CharacterVector(
"'M' and 'Cov' must contain integer or numeric values.");
}
END_RCPP
}
// TODOs -----------------------------------------------------------------------

// TODO: Add code path to process ordinary R vectors (for use within
Expand Down
2 changes: 1 addition & 1 deletion src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ extern "C" {

static const R_CallMethodDef all_call_entries[] = {
// Validity checking.
REGISTER(check_M_and_Cov, 2),
REGISTER(check_M_and_Cov, 3),
{NULL, NULL, 0}
};

Expand Down