Skip to content

Commit

Permalink
Add upper triangular solve
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKnodt committed Nov 27, 2023
1 parent b03a14b commit 5841cb0
Showing 1 changed file with 60 additions and 19 deletions.
79 changes: 60 additions & 19 deletions nalgebra-sparse/src/csc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,55 @@ impl<T> CscMatrix<T> {
/// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector
/// sparse.
/// sparsity_idxs should be precomputed using the sparse_lower_triangle.
/// Assumes that the diagonal of the sparse matrix is all 1.
pub fn sparse_upper_triangular_solve_sorted(
&self,
b_idxs: &[usize],
b: &[T],

out_sparsity_pattern: &[usize],
out: &mut [T],
) where
T: RealField + Copy,
{
assert_eq!(self.nrows(), self.ncols());
assert_eq!(b_idxs.len(), b.len());
assert!(b_idxs.iter().all(|&bi| bi < self.ncols()));

assert_eq!(out_sparsity_pattern.len(), out.len());
assert!(out_sparsity_pattern.iter().all(|&bi| bi < self.ncols()));

// initialize out with b
out.fill(T::zero());
for (&bv, &bi) in b.iter().zip(b_idxs.iter()) {
let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap();
out[out_pos] = bv;
}

for (i, &row) in out_sparsity_pattern.iter().enumerate().rev() {
let col = self.col(row);
let mut iter = col
.row_indices()
.iter()
.zip(col.values().iter())
.rev()
.peekable();
let mul = out[i];
for (ni, &nrow) in out_sparsity_pattern.iter().enumerate().rev().skip(i + 1) {
assert!(nrow < row);
while iter.next_if(|n| *n.0 > nrow).is_some() {}
let l_val = match iter.peek() {
Some((&r, &l_val)) if r == nrow => l_val,
_ => continue,
};
out[ni] -= l_val * mul;
}
}
}

/// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector
/// sparse.
/// sparsity_idxs should be precomputed using the sparse_lower_triangle.
/// Assumes that the diagonal of the sparse matrix is all 1 if `assume_unit` is true.
pub fn sparse_lower_triangular_solve(
&self,
b_idxs: &[usize],
Expand Down Expand Up @@ -621,6 +669,7 @@ impl<T> CscMatrix<T> {
}

// initialize out with b
out.fill(T::zero());
for (&bv, &bi) in b.iter().zip(b_idxs.iter()) {
let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap();
out[out_pos] = bv;
Expand Down Expand Up @@ -654,8 +703,11 @@ impl<T> CscMatrix<T> {
}
/// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector
/// sparse.
/// sparsity_idxs should be precomputed using the sparse_lower_triangle.
/// Assumes that the diagonal of the sparse matrix is all 1.
/// sparsity_idxs should be precomputed using the sparse_lower_triangle pattern.
///
/// `out_sparsity_pattern` must also be pre-sorted.
///
/// Assumes that the diagonal of the sparse matrix is all 1 if `assume_unit` is true.
pub fn sparse_lower_triangular_solve_sorted(
&self,
// input vector idxs & values
Expand All @@ -679,6 +731,7 @@ impl<T> CscMatrix<T> {

// initialize out with b
// TODO can make this more efficient by keeping two iterators in sorted order
out.fill(T::zero());
for (&bv, &bi) in b.iter().zip(b_idxs.iter()) {
let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap();
out[out_pos] = bv;
Expand All @@ -691,13 +744,7 @@ impl<T> CscMatrix<T> {
let col = self.col(row);
let mut iter = col.row_indices().iter().zip(col.values().iter()).peekable();
if !assume_unit {
while let Some(n) = iter.peek() {
if *n.0 < row {
iter.next();
} else {
break;
}
}
while iter.next_if(|n| *n.0 < row).is_some() {}
match iter.peek() {
Some((&r, &l_val)) if r == row => out[i] /= l_val,
// here it now becomes implicitly 0,
Expand All @@ -706,20 +753,14 @@ impl<T> CscMatrix<T> {
}
}
let mul = out[i];
for (offset, &nrow) in out_sparsity_pattern[i..].iter().enumerate().skip(1) {
for (ni, &nrow) in out_sparsity_pattern.iter().enumerate().skip(i + 1) {
assert!(nrow > row);
while let Some(n) = iter.peek() {
if *n.0 < nrow {
iter.next();
} else {
break;
}
}
while iter.next_if(|n| *n.0 < nrow).is_some() {}
let l_val = match iter.peek() {
Some((&r, &l_val)) if r == nrow => l_val,
_ => continue,
};
out[i + offset] -= l_val * mul;
out[ni] -= l_val * mul;
}
}
}
Expand Down

0 comments on commit 5841cb0

Please sign in to comment.