Skip to content

Commit

Permalink
#9 more refactoring, working on diffsl now
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Apr 27, 2024
1 parent ec880a7 commit 9e99ee8
Show file tree
Hide file tree
Showing 19 changed files with 370 additions and 311 deletions.
3 changes: 1 addition & 2 deletions src/jacobian/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::vector::Vector;
use crate::Scalar;
use crate::{matrix::DenseMatrix, op::NonLinearOp, VectorViewMut, MatrixSparsity, Matrix};
use crate::op::NonLinearOp;
use num_traits::{One, Zero};

use self::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy};
Expand Down Expand Up @@ -89,7 +89,6 @@ impl JacobianColoring {
mod tests {
use std::rc::Rc;

use crate::matrix::MatrixSparsity;
use crate::op::Op;
use crate::{
jacobian::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy},
Expand Down
44 changes: 25 additions & 19 deletions src/linear_solver/faer/lu.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,58 @@
use crate::{linear_solver::LinearSolver, op::LinearOp, solver::SolverProblem, Scalar};
use crate::{
linear_solver::LinearSolver, op::linearise::LinearisedOp, solver::SolverProblem, NonLinearOp,
Op, Scalar,
};
use anyhow::Result;
use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat};
/// A [LinearSolver] that uses the LU decomposition in the [`faer`](https://github.com/sarah-ek/faer-rs) library to solve the linear system.
pub struct LU<T, C>
where
T: Scalar,
C: LinearOp<M = Mat<T>, V = Col<T>, T = T>,
C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>,
{
lu: Option<FullPivLu<T>>,
problem: Option<SolverProblem<C>>,
problem: Option<SolverProblem<LinearisedOp<C>>>,
matrix: Option<Mat<T>>,
}

impl<T, C> Default for LU<T, C>
where
T: Scalar,
C: LinearOp<M = Mat<T>, V = Col<T>, T = T>,
C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>,
{
fn default() -> Self {
Self {
lu: None,
problem: None,
matrix: None,
}
}
}

impl<T: Scalar, C: LinearOp<M = Mat<T>, V = Col<T>, T = T>> LinearSolver<C> for LU<T, C> {
fn problem(&self) -> Option<&SolverProblem<C>> {
self.problem.as_ref()
}
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>> {
self.problem.as_mut()
}
fn take_problem(&mut self) -> Option<SolverProblem<C>> {
self.lu = None;
Option::take(&mut self.problem)
impl<T: Scalar, C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>> LinearSolver<C> for LU<T, C> {
fn set_linearisation(&mut self, x: &C::V, t: C::T) {
let matrix = self.matrix.as_mut().expect("Matrix not set");
let problem = self.problem.as_ref().expect("Problem not set");
problem.f.jacobian_inplace(x, t, matrix);
self.lu = Some(matrix.full_piv_lu());
}

fn solve_in_place(&self, state: &mut C::V) -> Result<()> {
fn solve_in_place(&self, x: &mut C::V) -> Result<()> {
if self.lu.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
}
let lu = self.lu.as_ref().unwrap();
lu.solve_in_place(state);
lu.solve_in_place(x);
Ok(())
}

fn set_problem(&mut self, problem: SolverProblem<C>) {
self.lu = Some(problem.f.jacobian(problem.t).full_piv_lu());
self.problem = Some(problem);
fn set_problem(&mut self, problem: &SolverProblem<C>) {
let linearised_problem = problem.linearise();
let matrix = Mat::zeros(
linearised_problem.f.nstates(),
linearised_problem.f.nstates(),
);
self.problem = Some(linearised_problem);
self.matrix = Some(matrix);
}
}
10 changes: 2 additions & 8 deletions src/linear_solver/gmres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,16 @@ impl<C: LinearOp> LinearSolver<C> for GMRES<C>
where
for<'b> &'b C::V: VectorRef<C::V>,
{
fn problem(&self) -> Option<&SolverProblem<C>> {
todo!()
}
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>> {
todo!()
}

fn take_problem(&mut self) -> Option<SolverProblem<C>> {
fn set_linearisation(&mut self, x: &<C as crate::op::Op>::V, t: <C as crate::op::Op>::T) {
todo!()
}

fn solve_in_place(&self, _state: &mut C::V) -> Result<()> {
todo!()
}

fn set_problem(&mut self, _problem: SolverProblem<C>) {
fn set_problem(&mut self, _problem: &SolverProblem<C>) {
todo!()
}
}
38 changes: 11 additions & 27 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,17 @@ pub mod sundials;
pub use faer::lu::LU as FaerLU;
pub use nalgebra::lu::LU as NalgebraLU;

/// A solver for the linear problem `Ax = b`.
/// The solver is parameterised by the type `C` which is the type of the linear operator `A` (see the [Op] trait for more details).
/// A solver for the linear problem `Ax = b`, where `A` is a linear operator that is obtained by taking the linearisation of a nonlinear operator `C`
pub trait LinearSolver<C: Op> {
/// Set the problem to be solved, any previous problem is discarded.
/// Any internal state of the solver is reset.
fn set_problem(&mut self, problem: SolverProblem<C>);
fn set_problem(&mut self, problem: &SolverProblem<C>);

/// Get a reference to the current problem, if any.
fn problem(&self) -> Option<&SolverProblem<C>>;

/// Get a mutable reference to the current problem, if any.
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>>;

/// Take the current problem, if any, and return it.
fn take_problem(&mut self) -> Option<SolverProblem<C>>;

fn reset(&mut self) {
if let Some(problem) = self.take_problem() {
self.set_problem(problem);
}
}
// sets the point at which the linearisation of the operator is evaluated
fn set_linearisation(&mut self, x: &C::V, t: C::T);

/// Solve the problem `Ax = b` and return the solution `x`.
/// panics if [set_linearisation] has not been called previously
fn solve(&self, b: &C::V) -> Result<C::V> {
let mut b = b.clone();
self.solve_in_place(&mut b)?;
Expand Down Expand Up @@ -69,7 +57,7 @@ pub mod tests {
vector::VectorRef,
DenseMatrix, LinearSolver, SolverProblem, Vector,
};
use num_traits::{One, Zero};
use num_traits::One;

use super::LinearSolveSolution;

Expand All @@ -81,16 +69,15 @@ pub mod tests {
let jac = M::from_diagonal(&diagonal);
let p = Rc::new(M::V::zeros(0));
let op = Rc::new(LinearClosure::new(
// f = J * x
move |x, _p, _t, y| jac.gemv(M::T::one(), x, M::T::zero(), y),
// f = J * x + beta * y
move |x, _p, _t, beta, y| jac.gemv(M::T::one(), x, beta, y),
2,
2,
p,
));
let t = M::T::zero();
let rtol = M::T::from(1e-6);
let atol = Rc::new(M::V::from_vec(vec![1e-6.into(), 1e-6.into()]));
let problem = SolverProblem::new(op, t, atol, rtol);
let problem = SolverProblem::new(op, atol, rtol);
let solns = vec![LinearSolveSolution::new(
M::V::from_vec(vec![2.0.into(), 4.0.into()]),
M::V::from_vec(vec![1.0.into(), 2.0.into()]),
Expand All @@ -106,13 +93,10 @@ pub mod tests {
C: LinearOp,
for<'a> &'a C::V: VectorRef<C::V>,
{
solver.set_problem(problem);
solver.set_problem(&problem);
for soln in solns {
let x = solver.solve(&soln.b).unwrap();
let tol = {
let problem = solver.problem().unwrap();
&soln.x * scale(problem.rtol) + problem.atol.as_ref()
};
let tol = { &soln.x * scale(problem.rtol) + problem.atol.as_ref() };
x.assert_eq(&soln.x, &tol);
}
}
Expand Down
46 changes: 27 additions & 19 deletions src/linear_solver/nalgebra/lu.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,39 @@
use anyhow::Result;
use nalgebra::{DMatrix, DVector, Dyn};

use crate::{op::LinearOp, LinearSolver, Scalar, SolverProblem};
use crate::{
op::{linearise::LinearisedOp, NonLinearOp},
LinearSolver, Op, Scalar, SolverProblem,
};

/// A [LinearSolver] that uses the LU decomposition in the [`nalgebra` library](https://nalgebra.org/) to solve the linear system.
pub struct LU<T, C>
where
T: Scalar,
C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
{
matrix: Option<DMatrix<T>>,
lu: Option<nalgebra::LU<T, Dyn, Dyn>>,
problem: Option<SolverProblem<C>>,
problem: Option<SolverProblem<LinearisedOp<C>>>,
}

impl<T, C> Default for LU<T, C>
where
T: Scalar,
C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
{
fn default() -> Self {
Self {
lu: None,
problem: None,
matrix: None,
}
}
}

impl<T: Scalar, C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver<C> for LU<T, C> {
fn problem(&self) -> Option<&SolverProblem<C>> {
self.problem.as_ref()
}
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>> {
self.problem.as_mut()
}
fn take_problem(&mut self) -> Option<SolverProblem<C>> {
self.lu = None;
Option::take(&mut self.problem)
}

impl<T: Scalar, C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver<C>
for LU<T, C>
{
fn solve_in_place(&self, state: &mut C::V) -> Result<()> {
if self.lu.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
Expand All @@ -49,8 +45,20 @@ impl<T: Scalar, C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver
}
}

fn set_problem(&mut self, problem: SolverProblem<C>) {
self.lu = Some(nalgebra::LU::new(problem.f.jacobian(problem.t)));
self.problem = Some(problem);
fn set_linearisation(&mut self, x: &<C as Op>::V, t: <C as Op>::T) {
let problem = self.problem.as_ref().expect("Problem not set");
let matrix = self.matrix.as_mut().expect("Matrix not set");
problem.f.jacobian_inplace(x, t, matrix);
self.lu = Some(matrix.lu());
}

fn set_problem(&mut self, problem: &SolverProblem<C>) {
let linearised_problem = problem.linearise();
let matrix = DMatrix::zeros(
linearised_problem.f.nstates(),
linearised_problem.f.nstates(),
);
self.problem = Some(linearised_problem);
self.matrix = Some(matrix);
}
}
Loading

0 comments on commit 9e99ee8

Please sign in to comment.