Skip to content

Commit

Permalink
docs: update book to new api (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins authored Oct 22, 2024
1 parent d5fee3a commit 16a2f4b
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 38 deletions.
40 changes: 32 additions & 8 deletions book/src/choosing_a_solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ Each of these solvers has a number of generic arguments, for example the `Bdf` s
In normal use cases, Rust can infer these from your code so you don't need to specify these explicitly. The `Bdf` solver implements the `Default` trait so can be easily created using:

```rust
# use diffsol::{OdeBuilder, OdeSolverState};
# use diffsol::OdeBuilder;
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::Bdf;
use diffsol::{Bdf, OdeSolverState, OdeSolverMethod};
# fn main() {
#
# let problem = OdeBuilder::new()
Expand All @@ -26,18 +26,19 @@ use diffsol::Bdf;
# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]),
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let solver = Bdf::default();
# let _state = OdeSolverState::new(&problem, &solver).unwrap();
let mut solver = Bdf::default();
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
# }
```

The `Sdirk` solver requires a tableu to be specified so you can use its `new` method to create a new solver, for example using the `tr_bdf2` tableau:

```rust
# use diffsol::{OdeBuilder, OdeSolverState};
# use diffsol::{OdeBuilder};
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::{Sdirk, Tableau, NalgebraLU};
use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod};
# fn main() {
# let problem = OdeBuilder::new()
# .p(vec![1.0, 10.0])
Expand All @@ -46,8 +47,31 @@ use diffsol::{Sdirk, Tableau, NalgebraLU};
# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]),
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let solver = Sdirk::new(Tableau::<M>::tr_bdf2(), NalgebraLU::default());
# let _state = OdeSolverState::new(&problem, &solver).unwrap();
let mut solver = Sdirk::new(Tableau::<M>::tr_bdf2(), NalgebraLU::default());
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
# }
```

You can also use one of the helper functions to create a SDIRK solver with a pre-defined tableau, which will create it with the default linear solver:

```rust
# use diffsol::{OdeBuilder};
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod};
# fn main() {
# let problem = OdeBuilder::new()
# .p(vec![1.0, 10.0])
# .build_ode::<M, _, _, _>(
# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]),
# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]),
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let mut solver = Sdirk::tr_bdf2();
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
# }
```


6 changes: 4 additions & 2 deletions book/src/non_linear_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ impl Op for MyProblem {
# }
```

Next we implement the `NonLinearOp` trait for our struct. This trait specifies the functions that will be used to evaluate the rhs function and the jacobian multiplied by a vector.
Next we implement the `NonLinearOp` and `NonLinearOpJacobian` trait for our struct. This trait specifies the functions that will be used to evaluate the rhs function and the jacobian multiplied by a vector.

```rust
# fn main() {
# use std::rc::Rc;
use diffsol::{
NonLinearOp, OdeSolverEquations, OdeSolverProblem,
NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem,
Op, UnitCallable, ConstantClosure
};

Expand Down Expand Up @@ -96,6 +96,8 @@ impl NonLinearOp for MyProblem {
fn call_inplace(&self, x: &V, _t: T, y: &mut V) {
y[0] = self.p[0] * x[0] * (1.0 - x[0] / self.p[1]);
}
}
impl NonLinearOpJacobian for MyProblem {
fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) {
y[0] = self.p[0] * v[0] * (1.0 - 2.0 * x[0] / self.p[1]);
}
Expand Down
18 changes: 6 additions & 12 deletions book/src/putting_it_all_together.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ let out: Option<Rc<UnitCallable<M>>> = None;
## Creating the equations

Now we have variables `rhs` and `init` that are structs implementing the required traits, and `mass`, `root`, and `out` set to `None`. Using these, we can create the `OdeSolverEquations` struct,
and then provide it to the `OdeSolverProblem` struct to create the problem.
and then provide it to the `OdeBuilder` struct to create the problem.

```rust
# fn main() {
# use std::rc::Rc;
# use diffsol::{NonLinearOp, OdeSolverProblem, Op, UnitCallable, ConstantClosure};
use diffsol::OdeSolverEquations;
# use diffsol::{NonLinearOp, NonLinearOpJacobian, OdeSolverProblem, Op, UnitCallable, ConstantClosure};
use diffsol::{OdeSolverEquations, OdeBuilder};

# type T = f64;
# type V = nalgebra::DVector<T>;
Expand Down Expand Up @@ -71,6 +71,8 @@ use diffsol::OdeSolverEquations;
# fn call_inplace(&self, x: &V, _t: T, y: &mut V) {
# y[0] = self.p[0] * x[0] * (1.0 - x[0] / self.p[1]);
# }
# }
# impl NonLinearOpJacobian for MyProblem {
# fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) {
# y[0] = self.p[0] * v[0] * (1.0 - 2.0 * x[0] / self.p[1]);
# }
Expand All @@ -92,15 +94,7 @@ use diffsol::OdeSolverEquations;
#
# let p = Rc::new(V::zeros(0));
let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone());
let rtol = 1e-6;
let atol = V::from_element(1, 1e-6);
let t0 = 0.0;
let h0 = 1.0;
let with_sensitivity = false;
let sens_error_control = false;
let _problem = OdeSolverProblem::new(
eqn, rtol, atol, t0, h0, with_sensitivity, sens_error_control
).unwrap();
let _problem = OdeBuilder::new().build_from_eqn(eqn).unwrap();
# }
```

Expand Down
14 changes: 8 additions & 6 deletions book/src/solving_the_problem.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ let _soln = &solver.state().unwrap().y;
# }
```

DiffSol also has two convenience functions `solve` and `solve_dense` on the `OdeSolverMethod` trait. `solve` will initialise the problem and solve the problem up to a specified time, returning the solution at all the
DiffSol also has two convenience functions `solve` and `solve_dense` on the `OdeSolverMethod` trait. `solve` solve the problem from an initial state up to a specified time, returning the solution at all the
internal timesteps used by the solver. This function returns a tuple that contains a `Vec` of
the solution at each timestep, and a `Vec` of the times at each timestep.

```rust
# use diffsol::OdeBuilder;
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::{OdeSolverMethod, Bdf};
use diffsol::{OdeSolverMethod, Bdf, OdeSolverState};

# fn main() {
# let problem = OdeBuilder::new()
Expand All @@ -116,17 +116,18 @@ use diffsol::{OdeSolverMethod, Bdf};
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let mut solver = Bdf::default();
let (ys, ts) = solver.solve(&problem, 10.0).unwrap();
let state = OdeSolverState::new(&problem, &solver).unwrap();
let (ys, ts) = solver.solve(&problem, state, 10.0).unwrap();
# }
```

`solve_dense` will initialise the problem and solve the problem, returning the solution at a `Vec` of times provided by the user. This function returns a `Vec<V>`, where `V` is the vector type used to define the problem.
`solve_dense` will solve a problem from an initial state, returning the solution at a `Vec` of times provided by the user. This function returns a `Vec<V>`, where `V` is the vector type used to define the problem.

```rust
# use diffsol::OdeBuilder;
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::{OdeSolverMethod, Bdf};
use diffsol::{OdeSolverMethod, Bdf, OdeSolverState};

# fn main() {
# let problem = OdeBuilder::new()
Expand All @@ -138,6 +139,7 @@ use diffsol::{OdeSolverMethod, Bdf};
# ).unwrap();
let mut solver = Bdf::default();
let times = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
let _soln = solver.solve_dense(&problem, &times).unwrap();
let state = OdeSolverState::new(&problem, &solver).unwrap();
let _soln = solver.solve_dense(&problem, state, &times).unwrap();
# }
```
16 changes: 8 additions & 8 deletions book/src/sparse_problems.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ To illustrate this, we can calculate the jacobian matrix from the `rhs` function

```rust
# use diffsol::OdeBuilder;
use diffsol::{OdeEquations, NonLinearOp, Matrix, ConstantOp};
use diffsol::{OdeEquations, NonLinearOp, NonLinearOpJacobian, Matrix, ConstantOp};

# type M = diffsol::SparseColMat<f64>;
# type V = faer::Col<f64>;
Expand Down Expand Up @@ -101,7 +101,7 @@ This is described in more detail in the ["Custom Problem Structs"](./custom_prob
# fn main() {
use std::rc::Rc;
use faer::sparse::{SparseColMat, SymbolicSparseColMatRef};
use diffsol::{NonLinearOp, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure};
use diffsol::{NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure, OdeBuilder};

type T = f64;
type V = faer::Col<T>;
Expand Down Expand Up @@ -144,7 +144,10 @@ impl NonLinearOp for MyProblem {
y[i] = self.p[0] * x[i] * (1.0 - x[i] / self.p[1]);
}
}
fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) {

}
impl NonLinearOpJacobian for MyProblem {
fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) {
for i in 0..10 {
y[i] = self.p[0] * v[i] * (1.0 - 2.0 * x[i] / self.p[1]);
}
Expand All @@ -155,6 +158,7 @@ impl NonLinearOp for MyProblem {
y.faer_mut().values_mut()[i] = self.p[0] * (1.0 - 2.0 * x[row] / self.p[1]);
}
}

}

let p = [1.0, 10.0];
Expand All @@ -173,11 +177,7 @@ let out: Option<Rc<UnitCallable<M>>> = None;

let p = Rc::new(V::zeros(0));
let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone());
let rtol = 1e-6;
let atol = V::from_fn(10, |_| 1e-6);
let t0 = 0.0;
let h0 = 1.0;
let _problem = OdeSolverProblem::new(eqn, rtol, atol, t0, h0, false, false).unwrap();
let _problem = OdeBuilder::new().build_from_eqn(eqn).unwrap();
# }
```

4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ pub use ode_solver::{
sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, sens_equations::SensEquations,
sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau,
};
use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose};
pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose};
pub use op::nonlinear_op::{
NonLinearOp, NonLinearOpAdjoint, NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint,
};
Expand Down

0 comments on commit 16a2f4b

Please sign in to comment.