diff --git a/book/src/choosing_a_solver.md b/book/src/choosing_a_solver.md index b3499f9..c3f598a 100644 --- a/book/src/choosing_a_solver.md +++ b/book/src/choosing_a_solver.md @@ -13,10 +13,10 @@ Each of these solvers has a number of generic arguments, for example the `Bdf` s In normal use cases, Rust can infer these from your code so you don't need to specify these explicitly. The `Bdf` solver implements the `Default` trait so can be easily created using: ```rust -# use diffsol::{OdeBuilder, OdeSolverState}; +# use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::Bdf; +use diffsol::{Bdf, OdeSolverState, OdeSolverMethod}; # fn main() { # # let problem = OdeBuilder::new() @@ -26,18 +26,19 @@ use diffsol::Bdf; # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), # |_p, _t| DVector::from_element(1, 0.1), # ).unwrap(); -let solver = Bdf::default(); -# let _state = OdeSolverState::new(&problem, &solver).unwrap(); +let mut solver = Bdf::default(); +let state = OdeSolverState::new(&problem, &solver).unwrap(); +solver.set_problem(state, &problem); # } ``` The `Sdirk` solver requires a tableu to be specified so you can use its `new` method to create a new solver, for example using the `tr_bdf2` tableau: ```rust -# use diffsol::{OdeBuilder, OdeSolverState}; +# use diffsol::{OdeBuilder}; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{Sdirk, Tableau, NalgebraLU}; +use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod}; # fn main() { # let problem = OdeBuilder::new() # .p(vec![1.0, 10.0]) @@ -46,8 +47,31 @@ use diffsol::{Sdirk, Tableau, NalgebraLU}; # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), # |_p, _t| DVector::from_element(1, 0.1), # ).unwrap(); -let solver = Sdirk::new(Tableau::::tr_bdf2(), NalgebraLU::default()); -# let _state = OdeSolverState::new(&problem, &solver).unwrap(); +let mut solver = Sdirk::new(Tableau::::tr_bdf2(), NalgebraLU::default()); +let state = OdeSolverState::new(&problem, &solver).unwrap(); +solver.set_problem(state, &problem); # } ``` +You can also use one of the helper functions to create a SDIRK solver with a pre-defined tableau, which will create it with the default linear solver: + +```rust +# use diffsol::{OdeBuilder}; +# use nalgebra::DVector; +# type M = nalgebra::DMatrix; +use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod}; +# fn main() { +# let problem = OdeBuilder::new() +# .p(vec![1.0, 10.0]) +# .build_ode::( +# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), +# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), +# |_p, _t| DVector::from_element(1, 0.1), +# ).unwrap(); +let mut solver = Sdirk::tr_bdf2(); +let state = OdeSolverState::new(&problem, &solver).unwrap(); +solver.set_problem(state, &problem); +# } +``` + + diff --git a/book/src/non_linear_functions.md b/book/src/non_linear_functions.md index 5f0aadb..d9086ea 100644 --- a/book/src/non_linear_functions.md +++ b/book/src/non_linear_functions.md @@ -56,13 +56,13 @@ impl Op for MyProblem { # } ``` -Next we implement the `NonLinearOp` trait for our struct. This trait specifies the functions that will be used to evaluate the rhs function and the jacobian multiplied by a vector. +Next we implement the `NonLinearOp` and `NonLinearOpJacobian` trait for our struct. This trait specifies the functions that will be used to evaluate the rhs function and the jacobian multiplied by a vector. ```rust # fn main() { # use std::rc::Rc; use diffsol::{ - NonLinearOp, OdeSolverEquations, OdeSolverProblem, + NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure }; @@ -96,6 +96,8 @@ impl NonLinearOp for MyProblem { fn call_inplace(&self, x: &V, _t: T, y: &mut V) { y[0] = self.p[0] * x[0] * (1.0 - x[0] / self.p[1]); } +} +impl NonLinearOpJacobian for MyProblem { fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { y[0] = self.p[0] * v[0] * (1.0 - 2.0 * x[0] / self.p[1]); } diff --git a/book/src/putting_it_all_together.md b/book/src/putting_it_all_together.md index 4f22bac..cdc3e1a 100644 --- a/book/src/putting_it_all_together.md +++ b/book/src/putting_it_all_together.md @@ -33,13 +33,13 @@ let out: Option>> = None; ## Creating the equations Now we have variables `rhs` and `init` that are structs implementing the required traits, and `mass`, `root`, and `out` set to `None`. Using these, we can create the `OdeSolverEquations` struct, -and then provide it to the `OdeSolverProblem` struct to create the problem. +and then provide it to the `OdeBuilder` struct to create the problem. ```rust # fn main() { # use std::rc::Rc; -# use diffsol::{NonLinearOp, OdeSolverProblem, Op, UnitCallable, ConstantClosure}; -use diffsol::OdeSolverEquations; +# use diffsol::{NonLinearOp, NonLinearOpJacobian, OdeSolverProblem, Op, UnitCallable, ConstantClosure}; +use diffsol::{OdeSolverEquations, OdeBuilder}; # type T = f64; # type V = nalgebra::DVector; @@ -71,6 +71,8 @@ use diffsol::OdeSolverEquations; # fn call_inplace(&self, x: &V, _t: T, y: &mut V) { # y[0] = self.p[0] * x[0] * (1.0 - x[0] / self.p[1]); # } +# } +# impl NonLinearOpJacobian for MyProblem { # fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { # y[0] = self.p[0] * v[0] * (1.0 - 2.0 * x[0] / self.p[1]); # } @@ -92,15 +94,7 @@ use diffsol::OdeSolverEquations; # # let p = Rc::new(V::zeros(0)); let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); -let rtol = 1e-6; -let atol = V::from_element(1, 1e-6); -let t0 = 0.0; -let h0 = 1.0; -let with_sensitivity = false; -let sens_error_control = false; -let _problem = OdeSolverProblem::new( - eqn, rtol, atol, t0, h0, with_sensitivity, sens_error_control -).unwrap(); +let _problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); # } ``` diff --git a/book/src/solving_the_problem.md b/book/src/solving_the_problem.md index c317d43..603314c 100644 --- a/book/src/solving_the_problem.md +++ b/book/src/solving_the_problem.md @@ -97,7 +97,7 @@ let _soln = &solver.state().unwrap().y; # } ``` -DiffSol also has two convenience functions `solve` and `solve_dense` on the `OdeSolverMethod` trait. `solve` will initialise the problem and solve the problem up to a specified time, returning the solution at all the +DiffSol also has two convenience functions `solve` and `solve_dense` on the `OdeSolverMethod` trait. `solve` solve the problem from an initial state up to a specified time, returning the solution at all the internal timesteps used by the solver. This function returns a tuple that contains a `Vec` of the solution at each timestep, and a `Vec` of the times at each timestep. @@ -105,7 +105,7 @@ the solution at each timestep, and a `Vec` of the times at each timestep. # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, Bdf}; +use diffsol::{OdeSolverMethod, Bdf, OdeSolverState}; # fn main() { # let problem = OdeBuilder::new() @@ -116,17 +116,18 @@ use diffsol::{OdeSolverMethod, Bdf}; # |_p, _t| DVector::from_element(1, 0.1), # ).unwrap(); let mut solver = Bdf::default(); -let (ys, ts) = solver.solve(&problem, 10.0).unwrap(); +let state = OdeSolverState::new(&problem, &solver).unwrap(); +let (ys, ts) = solver.solve(&problem, state, 10.0).unwrap(); # } ``` -`solve_dense` will initialise the problem and solve the problem, returning the solution at a `Vec` of times provided by the user. This function returns a `Vec`, where `V` is the vector type used to define the problem. +`solve_dense` will solve a problem from an initial state, returning the solution at a `Vec` of times provided by the user. This function returns a `Vec`, where `V` is the vector type used to define the problem. ```rust # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, Bdf}; +use diffsol::{OdeSolverMethod, Bdf, OdeSolverState}; # fn main() { # let problem = OdeBuilder::new() @@ -138,6 +139,7 @@ use diffsol::{OdeSolverMethod, Bdf}; # ).unwrap(); let mut solver = Bdf::default(); let times = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; -let _soln = solver.solve_dense(&problem, ×).unwrap(); +let state = OdeSolverState::new(&problem, &solver).unwrap(); +let _soln = solver.solve_dense(&problem, state, ×).unwrap(); # } ``` \ No newline at end of file diff --git a/book/src/sparse_problems.md b/book/src/sparse_problems.md index 540441f..d476b24 100644 --- a/book/src/sparse_problems.md +++ b/book/src/sparse_problems.md @@ -40,7 +40,7 @@ To illustrate this, we can calculate the jacobian matrix from the `rhs` function ```rust # use diffsol::OdeBuilder; -use diffsol::{OdeEquations, NonLinearOp, Matrix, ConstantOp}; +use diffsol::{OdeEquations, NonLinearOp, NonLinearOpJacobian, Matrix, ConstantOp}; # type M = diffsol::SparseColMat; # type V = faer::Col; @@ -101,7 +101,7 @@ This is described in more detail in the ["Custom Problem Structs"](./custom_prob # fn main() { use std::rc::Rc; use faer::sparse::{SparseColMat, SymbolicSparseColMatRef}; -use diffsol::{NonLinearOp, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure}; +use diffsol::{NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure, OdeBuilder}; type T = f64; type V = faer::Col; @@ -144,7 +144,10 @@ impl NonLinearOp for MyProblem { y[i] = self.p[0] * x[i] * (1.0 - x[i] / self.p[1]); } } - fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { + +} +impl NonLinearOpJacobian for MyProblem { + fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { for i in 0..10 { y[i] = self.p[0] * v[i] * (1.0 - 2.0 * x[i] / self.p[1]); } @@ -155,6 +158,7 @@ impl NonLinearOp for MyProblem { y.faer_mut().values_mut()[i] = self.p[0] * (1.0 - 2.0 * x[row] / self.p[1]); } } + } let p = [1.0, 10.0]; @@ -173,11 +177,7 @@ let out: Option>> = None; let p = Rc::new(V::zeros(0)); let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); -let rtol = 1e-6; -let atol = V::from_fn(10, |_| 1e-6); -let t0 = 0.0; -let h0 = 1.0; -let _problem = OdeSolverProblem::new(eqn, rtol, atol, t0, h0, false, false).unwrap(); +let _problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); # } ``` diff --git a/src/lib.rs b/src/lib.rs index cf844e9..43fc089 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -202,8 +202,8 @@ pub use ode_solver::{ sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, sens_equations::SensEquations, sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau, }; -use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint}; -use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose}; +pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint}; +pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose}; pub use op::nonlinear_op::{ NonLinearOp, NonLinearOpAdjoint, NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, };