Skip to content

Commit

Permalink
feat: add helper sensitivities and adjoint functions to method trait (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins authored Oct 21, 2024
1 parent 29a29ae commit d5fee3a
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 15 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@ DiffSol implements the following solvers:
- A variable order Backwards Difference Formulae (BDF) solver, suitable for stiff problems and singular mass matrices.
- A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver, suitable for moderately stiff problems and singular mass matrices. You can use your own butcher tableau or use one of the provided (`tr_bdf2` or `esdirk34`).

All solvers feature adaptive step-size control to given tolerances, dense output, event handling, stepping to specific times and forward sensitivity analysis.
All solvers feature:
- adaptive step-size control to given tolerances,
- dense output,
- event handling,
- stepping to specific times,
- numerical quadrature of an optional output function over time
- forward sensitivity analysis,
- backwards or adjoint sensitivity analysis,

For comparison, the BDF solvers are similar to MATLAB's `ode15s` solver, the `bdf` solver in SciPy's `solve_ivp` function, or the BDF solver in SUNDIALS.
The ESDIRK solver using the provided `tr_bdf2` tableau is similar to MATLAB's `ode23t` solver.

Expand Down
236 changes: 222 additions & 14 deletions src/ode_solver/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ where
return Err(ode_solver_error!(InvalidTEval));
}

let mut write_out = |i: usize, y: &Eqn::V, g: Option<&Eqn::V>| {
let mut write_out = |i: usize, y: Option<&Eqn::V>, g: Option<&Eqn::V>| {
let mut y_out = ret.column_mut(i);
if let Some(g) = g {
y_out.copy_from(g);
} else {
} else if let Some(y) = y {
match problem.eqn.out() {
Some(out) => y_out.copy_from(&out.call(y, t_eval[i])),
None => y_out.copy_from(y),
Expand All @@ -213,12 +213,12 @@ where
while self.state().unwrap().t < *t {
step_reason = self.step()?;
}
let y = self.interpolate(*t)?;
if problem.integrate_out {
let g = self.interpolate_out(*t)?;
write_out(i, &y, Some(&g));
write_out(i, None, Some(&g));
} else {
write_out(i, &y, None);
let y = self.interpolate(*t)?;
write_out(i, Some(&y), None);
}
}

Expand All @@ -227,16 +227,171 @@ where
step_reason = self.step()?;
}
if problem.integrate_out {
write_out(
t_eval.len() - 1,
self.state().unwrap().y,
Some(self.state().unwrap().g),
);
write_out(t_eval.len() - 1, None, Some(self.state().unwrap().g));
} else {
write_out(t_eval.len() - 1, self.state().unwrap().y, None);
write_out(t_eval.len() - 1, Some(self.state().unwrap().y), None);
}
Ok(ret)
}

/// Using the provided state, solve the forwards and adjoint problem from the current time up to `final_time`.
/// An output function must be provided and the problem must be setup to integrate this output
/// function over time. Returns a tuple of `(g, sgs)`, where `g` is the vector of the integral
/// of the output function from the current time to `final_time`, and `sgs` is a `Vec` where
/// the ith element is the sensitivities of the ith element of `g` with respect to the
/// parameters.
#[allow(clippy::type_complexity)]
fn solve_adjoint(
mut self,
problem: &OdeSolverProblem<Eqn>,
state: Self::State,
final_time: Eqn::T,
max_steps_between_checkpoints: Option<usize>,
) -> Result<(Eqn::V, Vec<Eqn::V>), DiffsolError>
where
Self: AdjointOdeSolverMethod<Eqn>,
Eqn: OdeEquationsAdjoint,
Eqn::M: DefaultSolver,
Eqn::V: DefaultDenseMatrix,
Self: Sized,
{
if problem.eqn.out().is_none() {
return Err(ode_solver_error!(
Other,
"Cannot solve adjoint without output function"
));
}
if !problem.integrate_out {
return Err(ode_solver_error!(
Other,
"Cannot solve adjoint without integrating out"
));
}
let max_steps_between_checkpoints = max_steps_between_checkpoints.unwrap_or(500);
self.set_problem(state, problem)?;
let t0 = self.state().unwrap().t;
let mut ts = vec![t0];
let mut ys = vec![self.state().unwrap().y.clone()];
let mut ydots = vec![self.state().unwrap().dy.clone()];

// do the main forward solve, saving checkpoints
self.set_stop_time(final_time)?;
let mut nsteps = 0;
let mut checkpoints = vec![self.checkpoint().unwrap()];
while self.step()? != OdeSolverStopReason::TstopReached {
ts.push(self.state().unwrap().t);
ys.push(self.state().unwrap().y.clone());
ydots.push(self.state().unwrap().dy.clone());
nsteps += 1;
if nsteps > max_steps_between_checkpoints {
checkpoints.push(self.checkpoint().unwrap());
nsteps = 0;
ts.clear();
ys.clear();
ydots.clear();
}
}
ts.push(self.state().unwrap().t);
ys.push(self.state().unwrap().y.clone());
ydots.push(self.state().unwrap().dy.clone());
checkpoints.push(self.checkpoint().unwrap());

// save integrateed out function
let g = self.state().unwrap().g.clone();

// construct the adjoint solver
let last_segment = HermiteInterpolator::new(ys, ydots, ts);
let mut adjoint_solver = self.into_adjoint_solver(checkpoints, last_segment)?;

// solve the adjoint problem
adjoint_solver.set_stop_time(t0).unwrap();
while adjoint_solver.step()? != OdeSolverStopReason::TstopReached {}

// correct the adjoint solution for the initial conditions
let mut state = adjoint_solver.take_state().unwrap();
let state_mut = state.as_mut();
adjoint_solver
.problem()
.unwrap()
.eqn
.correct_sg_for_init(t0, state_mut.s, state_mut.sg);

// return the solution
Ok((g, state_mut.sg.to_owned()))
}

/// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`
/// Returns a tuple `(y, sens)`, where `y` is a dense matrix of solution values at timepoints given by `t_eval`,
/// and `sens` is a Vec of dense matrices, the ith element of the Vec are the the sensitivities with respect to the ith parameter.
/// After the solver has finished, the internal state of the solver is at time `t_eval[t_eval.len()-1]`.
#[allow(clippy::type_complexity)]
fn solve_dense_sensitivities(
&mut self,
problem: &OdeSolverProblem<Eqn>,
state: Self::State,
t_eval: &[Eqn::T],
) -> Result<
(
<Eqn::V as DefaultDenseMatrix>::M,
Vec<<Eqn::V as DefaultDenseMatrix>::M>,
),
DiffsolError,
>
where
Self: SensitivitiesOdeSolverMethod<Eqn>,
Eqn: OdeEquationsSens,
Eqn::M: DefaultSolver,
Eqn::V: DefaultDenseMatrix,
Self: Sized,
{
if problem.integrate_out {
return Err(ode_solver_error!(
Other,
"Cannot integrate out when solving for sensitivities"
));
}
self.set_problem_with_sensitivities(state, problem)?;
let nrows = problem.eqn.rhs().nstates();
let mut ret = <<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nrows, t_eval.len());
let mut ret_sens =
vec![
<<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nrows, t_eval.len());
problem.eqn.rhs().nparams()
];

// check t_eval is increasing and all values are greater than or equal to the current time
let t0 = self.state().unwrap().t;
if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) {
return Err(ode_solver_error!(InvalidTEval));
}

// do loop
self.set_stop_time(t_eval[t_eval.len() - 1])?;
let mut step_reason = OdeSolverStopReason::InternalTimestep;
for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() {
while self.state().unwrap().t < *t {
step_reason = self.step()?;
}
let y = self.interpolate(*t)?;
ret.column_mut(i).copy_from(&y);
let s = self.interpolate_sens(*t)?;
for (j, s_j) in s.iter().enumerate() {
ret_sens[j].column_mut(i).copy_from(s_j);
}
}

// do final step
while step_reason != OdeSolverStopReason::TstopReached {
step_reason = self.step()?;
}
let y = self.state().unwrap().y;
ret.column_mut(t_eval.len() - 1).copy_from(y);
let s = self.state().unwrap().s;
for (j, s_j) in s.iter().enumerate() {
ret_sens[j].column_mut(t_eval.len() - 1).copy_from(s_j);
}
Ok((ret, ret_sens))
}
}

pub trait AugmentedOdeSolverMethod<Eqn, AugmentedEqn>: OdeSolverMethod<Eqn>
Expand Down Expand Up @@ -337,9 +492,11 @@ where
#[cfg(test)]
mod test {
use crate::{
ode_solver::test_models::exponential_decay::exponential_decay_problem,
ode_solver::test_models::exponential_decay::exponential_decay_problem_adjoint, scale, Bdf,
OdeSolverMethod, OdeSolverState, Vector,
ode_solver::test_models::exponential_decay::{
exponential_decay_problem, exponential_decay_problem_adjoint,
exponential_decay_problem_sens,
},
scale, Bdf, OdeSolverMethod, OdeSolverState, Vector,
};

#[test]
Expand Down Expand Up @@ -410,4 +567,55 @@ mod test {
y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0);
}
}

#[test]
fn test_dense_solve_sensitivities() {
let mut s = Bdf::with_sensitivities();
let (problem, soln) = exponential_decay_problem_sens::<nalgebra::DMatrix<f64>>(false);

let state = OdeSolverState::new_with_sensitivities(&problem, &s).unwrap();
let t_eval = soln.solution_points.iter().map(|p| p.t).collect::<Vec<_>>();
let (y, sens) = s
.solve_dense_sensitivities(&problem, state, t_eval.as_slice())
.unwrap();
for (i, soln_pt) in soln.solution_points.iter().enumerate() {
let y_i = y.column(i).into_owned();
y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0);
}
for (j, soln_pts) in soln.sens_solution_points.unwrap().iter().enumerate() {
for (i, soln_pt) in soln_pts.iter().enumerate() {
let sens_i = sens[j].column(i).into_owned();
sens_i.assert_eq_norm(
&soln_pt.state,
problem.sens_atol.as_ref().unwrap(),
problem.sens_rtol.unwrap(),
15.0,
);
}
}
}

#[test]
fn test_solve_adjoint() {
let s = Bdf::default();
let (problem, soln) = exponential_decay_problem_adjoint::<nalgebra::DMatrix<f64>>();

let state = OdeSolverState::new(&problem, &s).unwrap();
let final_time = soln.solution_points[soln.solution_points.len() - 1].t;
let (g, gs_adj) = s.solve_adjoint(&problem, state, final_time, None).unwrap();
g.assert_eq_norm(
&soln.solution_points[soln.solution_points.len() - 1].state,
problem.out_atol.as_ref().unwrap(),
problem.out_rtol.unwrap(),
15.0,
);
for (j, soln_pts) in soln.sens_solution_points.unwrap().iter().enumerate() {
gs_adj[j].assert_eq_norm(
&soln_pts[0].state,
problem.out_atol.as_ref().unwrap(),
problem.out_rtol.unwrap(),
15.0,
);
}
}
}

0 comments on commit d5fee3a

Please sign in to comment.