Skip to content

Commit

Permalink
python interface and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
rabbull committed Jan 20, 2025
1 parent ad8898c commit cb227db
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 42 deletions.
20 changes: 19 additions & 1 deletion src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use pyo3::{
PyTupleMethods, PyType, PyTypeMethods,
},
wrap_pyfunction, Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr,
PyObject, PyRef, PyResult, PyTypeInfo, Python,
PyObject, PyRef, PyRefMut, PyResult, PyTypeInfo, Python,
};
use pyo3::{pyclass, types::PyModuleMethods};
use rug::Complete;
Expand Down Expand Up @@ -10741,6 +10741,24 @@ impl PythonMatrix {
}
}

/// Permutes the rows of the matrix based on the provided permutation vector.
pub fn permute_rows(&self, permutation_vector: Vec<u32>) -> PyResult<Self> {
let permuted = self
.matrix
.permute_rows(&permutation_vector)
.map_err(|e| exceptions::PyValueError::new_err(format!("{}", e)))?;
Ok(Self { matrix: permuted })
}

/// Perform LU decomposition over the matrix.
pub fn lu_decompose(&self) -> PyResult<(Vec<u32>, Self, Self)> {
let (pv, l, u) = self
.matrix
.lu_decompose()
.map_err(|e| exceptions::PyValueError::new_err(format!("{}", e)))?;
Ok((pv, Self { matrix: l }, Self { matrix: u }))
}

/// Apply a function `f` to every entry of the matrix.
pub fn map(&self, f: PyObject) -> PyResult<PythonMatrix> {
let data = self
Expand Down
6 changes: 5 additions & 1 deletion src/poly/gcd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,11 @@ impl<F: Field, E: PositiveExponent> MultivariatePolynomial<F, E> {
return Err(GCDError::BadOriginalImage);
}
Err(
MatrixError::NotSquare
MatrixError::IndexOutOfBounds {
indices: _,
shape: _,
}
| MatrixError::NotSquare
| MatrixError::ShapeMismatch
| MatrixError::RightHandSideIsNotVector
| MatrixError::Singular
Expand Down
78 changes: 38 additions & 40 deletions src/tensors/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,30 +971,30 @@ impl<F: Ring> Matrix<F> {
}

/// Permutes the rows of the matrix based on the provided permutation vector.
pub fn permute_rows(&self, pv: &[u32]) -> Self {
assert_eq!(
self.nrows as usize,
pv.len(),
"Permutation vector length must equal the number of rows."
);
pub fn permute_rows(&self, pv: &[u32]) -> Result<Self, MatrixError<F>> {
if self.nrows as usize != pv.len() {
return Err(MatrixError::ShapeMismatch);
}

let mut data = Vec::with_capacity(self.data.len());
for row_index in pv {
assert!(
row_index.lt(&Integer::from(self.nrows)),
"Row index out of bounds in permutation vector."
);
if row_index.ge(&Integer::from(self.nrows)) {
return Err(MatrixError::IndexOutOfBounds {
indices: vec![row_index.clone()],
shape: vec![self.nrows],
});
}
let start = row_index * self.ncols;
let end = &start + self.ncols;
data.extend_from_slice(&self.data[start.to_u64() as usize..end.to_u64() as usize]);
}

Matrix {
Ok(Matrix {
ncols: self.ncols,
nrows: self.nrows,
data: data,
field: self.field.clone(),
}
})
}
}

Expand Down Expand Up @@ -1271,6 +1271,10 @@ pub enum MatrixError<F: Ring> {
ShapeMismatch,
RightHandSideIsNotVector,
ResultNotInDomain,
IndexOutOfBounds {
indices: Vec<u32>,
shape: Vec<u32>,
},
}

impl<F: Ring> std::fmt::Display for MatrixError<F> {
Expand Down Expand Up @@ -1298,6 +1302,12 @@ impl<F: Ring> std::fmt::Display for MatrixError<F> {
f,
"The result does not belong to the same domain as the matrix."
),
MatrixError::IndexOutOfBounds { indices, shape } => write!(
f,
"Index out of bounds: tried to access element at {:?}, but the matrix has shape {:?}.",
indices,
shape
),
}
}
}
Expand Down Expand Up @@ -2072,38 +2082,26 @@ mod test {
fn test_matrix_permutation() {
let a = Matrix::from_linear(
vec![
1.1.into(),
1.2.into(),
1.3.into(),
2.1.into(),
2.2.into(),
2.3.into(),
3.1.into(),
3.2.into(),
3.3.into(),
11.into(),
12.into(),
13.into(),
21.into(),
22.into(),
23.into(),
31.into(),
32.into(),
33.into(),
],
3,
3,
Q,
Z,
)
.unwrap();

let pv = vec![2 as u32, 0 as u32, 1 as u32];
let permuted = a.permute_rows(&pv);
assert_eq!(
permuted.data,
[
3.1.into(),
3.2.into(),
3.3.into(),
1.1.into(),
1.2.into(),
1.3.into(),
2.1.into(),
2.2.into(),
2.3.into()
]
);
let permuted = a.permute_rows(&pv).unwrap();
println!("{}", permuted);
assert_eq!(permuted.data, [31, 32, 33, 11, 12, 13, 21, 22, 23]);
}

#[test]
Expand Down Expand Up @@ -2142,7 +2140,7 @@ mod test {
)
.unwrap();
let (res_pv, res_l, res_u) = m2.lu_decompose().unwrap();
assert_eq!(res_l.mul(&res_u), m2.permute_rows(&res_pv));
assert_eq!(res_l.mul(&res_u), m2.permute_rows(&res_pv).unwrap());

let field = AtomField {
cancel_check_on_division: true,
Expand All @@ -2168,7 +2166,7 @@ mod test {
let (res_pv, res_l, res_u) = m3.lu_decompose().unwrap();

let prod = res_l.mul(&res_u);
let perm = m3.permute_rows(&res_pv);
let perm = m3.permute_rows(&res_pv).unwrap();
for i in 0..3 {
for j in 0..3 {
let lhs = &prod[(i, j)];
Expand All @@ -2187,6 +2185,6 @@ mod test {
)
.unwrap();
let (res_pv, res_l, res_u) = m4.lu_decompose().unwrap();
assert_eq!(res_l.mul(&res_u), m4.permute_rows(&res_pv));
assert_eq!(res_l.mul(&res_u), m4.permute_rows(&res_pv).unwrap());
}
}

0 comments on commit cb227db

Please sign in to comment.