Skip to content

Commit

Permalink
Update lmutils::combine_vectors to support matrices and update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mrvillage committed Nov 17, 2024
1 parent a78af49 commit e9acf0c
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 31 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export(internal_lmutils_file_into_fd)
export(linear_regression)
export(load)
export(load_matrix)
export(logistic_regression)
export(map_from_pairs)
export(match_rows)
export(match_rows_dir)
Expand Down
14 changes: 11 additions & 3 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,24 @@ calculate_r2 <- function(data, outcomes) .Call(wrap__calculate_r2, data, outcome
#' @export
column_p_values <- function(data, outcomes) .Call(wrap__column_p_values, data, outcomes)

#' Compute a linear regression between each matrix in a list and a each column in another matrix.
#' Compute a linear regression between each matrix in a list and each column in another matrix.
#' `data` is a list of matrix convertable objects.
#' `outcomes` is a matrix convertable object.
#' Returns a data frame with columns `slopes`, `intercept`, `predicted` (if enabled), `r2`,
#' `adj_r2`, `data`, `outcome`, `n`, and `m`.
#' @export
linear_regression <- function(data, outcomes) .Call(wrap__linear_regression, data, outcomes)

#' Combine a list of double vectors into a matrix.
#' `data` is a list of double vectors.
#' Compute a logistic regression between each matrix in a list and each column in another matrix.
#' `data` is a list of matrix convertable objects.
#' `outcomes` is a matrix convertable object.
#' Returns a data frame with columns `slopes`, `intercept`, `predicted` (if enabled), `r2`,
#' `adj_r2`, `data`, `outcome`, `n`, and `m`.
#' @export
logistic_regression <- function(data, outcomes) .Call(wrap__logistic_regression, data, outcomes)

#' Combine a list of double vectors or matrices into a matrix.
#' `data` is a list of double vectors or matrices.
#' `out` is an output file name or `NULL` to return the matrix.
#' @export
combine_vectors <- function(data, out) .Call(wrap__combine_vectors, data, out)
Expand Down
8 changes: 4 additions & 4 deletions man/combine_vectors.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/linear_regression.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions man/logistic_regression.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

71 changes: 49 additions & 22 deletions src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,36 +766,63 @@ pub fn logistic_regression(data: Robj, outcomes: Robj) -> Result<Robj> {
Ok(df.into_robj())
}

/// Combine a list of double vectors into a matrix.
/// `data` is a list of double vectors.
/// Combine a list of double vectors or matrices into a matrix.
/// `data` is a list of double vectors or matrices.
/// `out` is an output file name or `NULL` to return the matrix.
/// @export
#[extendr]
pub fn combine_vectors(data: List, out: Nullable<&str>) -> Result<Nullable<RMatrix<f64>>> {
init();

let ncols = data.len();
let nrows = data
.iter()
.map(|(_, v)| {
v.as_real_slice()
.expect("all vectors must be doubles")
.len()
})
.next()
.unwrap_or(0);
let data = data.iter().map(|(_, v)| Par(v)).collect::<Vec<_>>();
let mut mat = vec![MaybeUninit::uninit(); ncols * nrows];
mat.par_chunks_exact_mut(nrows)
.zip(data.into_par_iter())
.for_each(|(data, v)| {
let v = v.as_real_slice().expect("all vectors must be doubles");
if v.len() != nrows {
let mut nrows = 0;
let mut chunks = Vec::with_capacity(data.len());
let mut ncols = 0;
for (_, v) in data {
if v.is_matrix() {
let mat = v.as_matrix::<f64>().unwrap();
if nrows == 0 {
nrows = mat.nrows();
} else if mat.nrows() != nrows {
panic!("all matrices must have the same number of rows");
}
chunks.push((ncols, mat.nrows(), Par(v)));
ncols += mat.ncols();
} else if v.is_real() {
if nrows == 0 {
nrows = v.len();
} else if v.as_real_slice().unwrap().len() != nrows {
panic!("all vectors must have the same length");
}
let v: &[MaybeUninit<f64>] = unsafe { std::mem::transmute(v) };
data.copy_from_slice(v);
});
chunks.push((ncols, 1, Par(v)));
ncols += 1;
} else {
panic!("data must be a list of double vectors or matrices");
}
}
if nrows == 0 {
panic!("data must contain at least one vector or matrix");
}
if chunks.is_empty() {
panic!("data must contain at least one vector or matrix");
}

let mat = vec![MaybeUninit::uninit(); ncols * nrows];
chunks.into_par_iter().for_each(|(start, ncols, v)| {
if v.is_matrix() {
let v = v.as_matrix::<f64>().unwrap();
let v = v.data();
let slice =
unsafe { std::slice::from_raw_parts_mut(mat.as_ptr().cast_mut(), mat.len()) };
slice[start * nrows..(start + ncols) * nrows]
.copy_from_slice(unsafe { std::mem::transmute::<&[f64], &[MaybeUninit<f64>]>(v) });
} else if v.is_real() {
let v = v.as_real_slice().unwrap();
let slice =
unsafe { std::slice::from_raw_parts_mut(mat.as_ptr().cast_mut(), mat.len()) };
slice[start * nrows..(start + ncols) * nrows]
.copy_from_slice(unsafe { std::mem::transmute::<&[f64], &[MaybeUninit<f64>]>(v) });
}
});

let mut mat = Matrix::Owned(OwnedMatrix::new(
nrows,
Expand Down

0 comments on commit e9acf0c

Please sign in to comment.