Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docs #22

Merged
merged 5 commits into from
Apr 3, 2024
Merged

Docs #22

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ state, we can use the following code:

```rust
let mut state = OdeSolverState::new(&problem);
solver.set_problem(&mut state, problem);
solver.set_problem(&mut state, &problem);
while state.t <= t {
solver.step(&mut state).unwrap();
}
Expand Down
52 changes: 50 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,51 @@
//! # DiffSol
//!
//! DiffSol is a library for solving differential equations. It provides a simple interface to solve ODEs and semi-explicit DAEs.
//!
//! ## Getting Started
//!
//! To create a new problem, use the [OdeBuilder] struct. You can set the initial time, initial step size, relative tolerance, absolute tolerance, and parameters,
//! or leave them at their default values. Then, call the [OdeBuilder::build_ode] method with the ODE equations, or the [OdeBuilder::build_ode_with_mass] method
//! with the ODE equations and the mass matrix equations.
//!
//! You will also need to choose a matrix type to use. DiffSol can use the [nalgebra](https://nalgebra.org) `DMatrix` type, or any other type that implements the
//! [Matrix] trait. You can also use the [sundials](https://computation.llnl.gov/projects/sundials) library for the matrix and vector types (see [SundialsMatrix]).
//!
//! To solve the problem, you need to choose a solver. DiffSol provides a pure rust [Bdf] solver, or you can use the [SundialsIda] solver from the sundials library (requires the `sundials` feature).
//! See the [OdeSolverMethod] trait for a more detailed description of the available methods on the solver.
//!
//! ```rust
//! use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod};
//! type M = nalgebra::DMatrix<f64>;
//!
//! let problem = OdeBuilder::new()
//! .rtol(1e-6)
//! .p([0.1])
//! .build_ode::<M, _, _, _>(
//! // dy/dt = -ay
//! |x, p, t, y| {
//! y[0] = -p[0] * x[0];
//! },
//! // Jv = -av
//! |x, p, t, v, y| {
//! y[0] = -p[0] * v[0];
//! },
//! // y(0) = 1
//! |p, t| {
//! nalgebra::DVector::from_vec(vec![1.0])
//! },
//! ).unwrap();
//!
//! let mut solver = Bdf::default();
//! let t = 0.4;
//! let mut state = OdeSolverState::new(&problem);
//! solver.set_problem(&mut state, &problem);
//! while state.t <= t {
//! solver.step(&mut state).unwrap();
//! }
//! let y = solver.interpolate(&state, t);
//! ```

#[cfg(feature = "diffsl-llvm10")]
pub extern crate diffsl10_0 as diffsl;
#[cfg(feature = "diffsl-llvm11")]
Expand Down Expand Up @@ -56,8 +104,8 @@ use matrix::{DenseMatrix, Matrix, MatrixViewMut};
pub use nonlinear_solver::newton::NewtonNonlinearSolver;
use nonlinear_solver::NonLinearSolver;
pub use ode_solver::{
bdf::Bdf, builder::OdeBuilder, equations::OdeEquations, OdeSolverMethod, OdeSolverProblem,
OdeSolverState,
bdf::Bdf, builder::OdeBuilder, equations::OdeEquations, method::OdeSolverMethod,
method::OdeSolverState, problem::OdeSolverProblem,
};
use op::NonLinearOp;
use scalar::{IndexType, Scalar, Scale};
Expand Down
21 changes: 18 additions & 3 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ use serde::Serialize;

use crate::{
matrix::MatrixRef, op::ode::BdfCallable, scalar::scale, DenseMatrix, IndexType, MatrixViewMut,
NewtonNonlinearSolver, NonLinearSolver, Scalar, SolverProblem, Vector, VectorRef, VectorView,
VectorViewMut, LU,
NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
Scalar, SolverProblem, Vector, VectorRef, VectorView, VectorViewMut, LU,
};

use super::{equations::OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState};
use super::equations::OdeEquations;

#[derive(Clone, Debug, Serialize)]
pub struct BdfStatistics<T: Scalar> {
Expand All @@ -39,6 +39,21 @@ impl<T: Scalar> Default for BdfStatistics<T> {
}
}

/// Implements a Backward Difference formula (BDF) implicit multistep integrator.
/// The basic algorithm is derived in \[1\]. This
/// particular implementation follows that implemented in the Matlab routine ode15s
/// described in \[2\] and the SciPy implementation
/// /[3/], which features the NDF formulas for improved
/// stability with associated differences in the error constants, and calculates
/// the jacobian at J(t_{n+1}, y^0_{n+1}). This implementation was based on that
/// implemented in the SciPy library \[3\], which also mainly
/// follows \[2\] but uses the more standard Jacobian update.
///
/// # References
///
/// \[1\] Byrne, G. D., & Hindmarsh, A. C. (1975). A polyalgorithm for the numerical solution of ordinary differential equations. ACM Transactions on Mathematical Software (TOMS), 1(1), 71-96.
/// \[2\] Shampine, L. F., & Reichelt, M. W. (1997). The matlab ode suite. SIAM journal on scientific computing, 18(1), 1-22.
/// \[3\] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., ... & Van Mulbregt, P. (2020). SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature methods, 17(3), 261-272.
pub struct Bdf<M: DenseMatrix<T = Eqn::T, V = Eqn::V>, Eqn: OdeEquations> {
nonlinear_solver: Box<dyn NonLinearSolver<BdfCallable<Eqn>>>,
ode_problem: Option<OdeSolverProblem<Eqn>>,
Expand Down
133 changes: 133 additions & 0 deletions src/ode_solver/method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use anyhow::Result;
use std::rc::Rc;

use crate::{
op::{filter::FilterCallable, ode_rhs::OdeRhs},
Matrix, NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex,
};

/// Trait for ODE solver methods. This is the main user interface for the ODE solvers.
/// The solver is responsible for stepping the solution (given in the `OdeSolverState`), and interpolating the solution at a given time.
/// However, the solver does not own the state, so the user is responsible for creating and managing the state. If the user
/// wants to change the state, they should call `set_problem` again.
///
/// # Example
///
/// ```
/// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquations };
///
/// fn solve_ode<Eqn: OdeEquations>(solver: &mut impl OdeSolverMethod<Eqn>, problem: &OdeSolverProblem<Eqn>, t: Eqn::T) -> Eqn::V {
/// let mut state = OdeSolverState::new(problem);
/// solver.set_problem(&mut state, problem);
/// while state.t <= t {
/// solver.step(&mut state).unwrap();
/// }
/// solver.interpolate(&state, t)
/// }
/// ```
pub trait OdeSolverMethod<Eqn: OdeEquations> {
/// Get the current problem if it has been set
fn problem(&self) -> Option<&OdeSolverProblem<Eqn>>;

/// Set the problem to solve, this performs any initialisation required by the solver.
/// Call this before calling `step` or `solve`, and call it again if the state is changed manually (i.e. not by the solver)
fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>);

/// Step the solution forward by one step, altering the state in place
fn step(&mut self, state: &mut OdeSolverState<Eqn::M>) -> Result<()>;

/// Interpolate the solution at a given time. This time should be between the current time and the last solver time step
fn interpolate(&self, state: &OdeSolverState<Eqn::M>, t: Eqn::T) -> Eqn::V;

/// Reinitialise the solver state and solve the problem up to time `t`
fn solve(&mut self, problem: &OdeSolverProblem<Eqn>, t: Eqn::T) -> Result<Eqn::V> {
let mut state = OdeSolverState::new(problem);
self.set_problem(&mut state, problem);
while state.t <= t {
self.step(&mut state)?;
}
Ok(self.interpolate(&state, t))
}

/// Reinitialise the solver state making it consistent with the algebraic constraints and solve the problem up to time `t`
fn make_consistent_and_solve<RS: NonLinearSolver<FilterCallable<OdeRhs<Eqn>>>>(
&mut self,
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
root_solver: &mut RS,
) -> Result<Eqn::V> {
let mut state = OdeSolverState::new_consistent(problem, root_solver)?;
self.set_problem(&mut state, problem);
while state.t <= t {
self.step(&mut state)?;
}
Ok(self.interpolate(&state, t))
}
}

/// State for the ODE solver, containing the current solution `y`, the current time `t`, and the current step size `h`.
pub struct OdeSolverState<M: Matrix> {
pub y: M::V,
pub t: M::T,
pub h: M::T,
_phantom: std::marker::PhantomData<M>,
}

impl<M: Matrix> OdeSolverState<M> {
/// Create a new solver state from an ODE problem. Note that this does not make the state consistent with the algebraic constraints.
/// If you need to make the state consistent, use `new_consistent` instead.
pub fn new<Eqn>(ode_problem: &OdeSolverProblem<Eqn>) -> Self
where
Eqn: OdeEquations<M = M, T = M::T, V = M::V>,
{
let t = ode_problem.t0;
let h = ode_problem.h0;
let y = ode_problem.eqn.init(t);
Self {
y,
t,
h,
_phantom: std::marker::PhantomData,
}
}

/// Create a new solver state from an ODE problem, making the state consistent with the algebraic constraints.
pub fn new_consistent<Eqn, S>(
ode_problem: &OdeSolverProblem<Eqn>,
root_solver: &mut S,
) -> Result<Self>
where
Eqn: OdeEquations<M = M, T = M::T, V = M::V>,
S: NonLinearSolver<FilterCallable<OdeRhs<Eqn>>> + ?Sized,
{
let t = ode_problem.t0;
let h = ode_problem.h0;
let indices = ode_problem.eqn.algebraic_indices();
let mut y = ode_problem.eqn.init(t);
if indices.len() == 0 {
return Ok(Self {
y,
t,
h,
_phantom: std::marker::PhantomData,
});
}
let mut y_filtered = y.filter(&indices);
let atol = Rc::new(ode_problem.atol.as_ref().filter(&indices));
let rhs = Rc::new(OdeRhs::new(ode_problem.eqn.clone()));
let f = Rc::new(FilterCallable::new(rhs, &y, indices));
let rtol = ode_problem.rtol;
let init_problem = SolverProblem::new(f, t, atol, rtol);
root_solver.set_problem(init_problem);
root_solver.solve_in_place(&mut y_filtered)?;
let init_problem = root_solver.problem().unwrap();
let indices = init_problem.f.indices();
y.scatter_from(&y_filtered, indices);
Ok(Self {
y,
t,
h,
_phantom: std::marker::PhantomData,
})
}
}
Loading
Loading