diff --git a/README.md b/README.md index b073d66..81eb38e 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,15 @@ DiffSol implements the following solvers: - A variable order Backwards Difference Formulae (BDF) solver, suitable for stiff problems and singular mass matrices. - A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver, suitable for moderately stiff problems and singular mass matrices. You can use your own butcher tableau or use one of the provided (`tr_bdf2` or `esdirk34`). -All solvers feature adaptive step-size control to given tolerances, dense output, event handling, stepping to specific times and forward sensitivity analysis. +All solvers feature: +- adaptive step-size control to given tolerances, +- dense output, +- event handling, +- stepping to specific times, +- numerical quadrature of an optional output function over time +- forward sensitivity analysis, +- backwards or adjoint sensitivity analysis, + For comparison, the BDF solvers are similar to MATLAB's `ode15s` solver, the `bdf` solver in SciPy's `solve_ivp` function, or the BDF solver in SUNDIALS. The ESDIRK solver using the provided `tr_bdf2` tableau is similar to MATLAB's `ode23t` solver. diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index a542e3c..7b231fc 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -194,11 +194,11 @@ where return Err(ode_solver_error!(InvalidTEval)); } - let mut write_out = |i: usize, y: &Eqn::V, g: Option<&Eqn::V>| { + let mut write_out = |i: usize, y: Option<&Eqn::V>, g: Option<&Eqn::V>| { let mut y_out = ret.column_mut(i); if let Some(g) = g { y_out.copy_from(g); - } else { + } else if let Some(y) = y { match problem.eqn.out() { Some(out) => y_out.copy_from(&out.call(y, t_eval[i])), None => y_out.copy_from(y), @@ -213,12 +213,12 @@ where while self.state().unwrap().t < *t { step_reason = self.step()?; } - let y = self.interpolate(*t)?; if problem.integrate_out { let g = self.interpolate_out(*t)?; - write_out(i, &y, Some(&g)); + write_out(i, None, Some(&g)); } else { - write_out(i, &y, None); + let y = self.interpolate(*t)?; + write_out(i, Some(&y), None); } } @@ -227,16 +227,171 @@ where step_reason = self.step()?; } if problem.integrate_out { - write_out( - t_eval.len() - 1, - self.state().unwrap().y, - Some(self.state().unwrap().g), - ); + write_out(t_eval.len() - 1, None, Some(self.state().unwrap().g)); } else { - write_out(t_eval.len() - 1, self.state().unwrap().y, None); + write_out(t_eval.len() - 1, Some(self.state().unwrap().y), None); } Ok(ret) } + + /// Using the provided state, solve the forwards and adjoint problem from the current time up to `final_time`. + /// An output function must be provided and the problem must be setup to integrate this output + /// function over time. Returns a tuple of `(g, sgs)`, where `g` is the vector of the integral + /// of the output function from the current time to `final_time`, and `sgs` is a `Vec` where + /// the ith element is the sensitivities of the ith element of `g` with respect to the + /// parameters. + #[allow(clippy::type_complexity)] + fn solve_adjoint( + mut self, + problem: &OdeSolverProblem, + state: Self::State, + final_time: Eqn::T, + max_steps_between_checkpoints: Option, + ) -> Result<(Eqn::V, Vec), DiffsolError> + where + Self: AdjointOdeSolverMethod, + Eqn: OdeEquationsAdjoint, + Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, + Self: Sized, + { + if problem.eqn.out().is_none() { + return Err(ode_solver_error!( + Other, + "Cannot solve adjoint without output function" + )); + } + if !problem.integrate_out { + return Err(ode_solver_error!( + Other, + "Cannot solve adjoint without integrating out" + )); + } + let max_steps_between_checkpoints = max_steps_between_checkpoints.unwrap_or(500); + self.set_problem(state, problem)?; + let t0 = self.state().unwrap().t; + let mut ts = vec![t0]; + let mut ys = vec![self.state().unwrap().y.clone()]; + let mut ydots = vec![self.state().unwrap().dy.clone()]; + + // do the main forward solve, saving checkpoints + self.set_stop_time(final_time)?; + let mut nsteps = 0; + let mut checkpoints = vec![self.checkpoint().unwrap()]; + while self.step()? != OdeSolverStopReason::TstopReached { + ts.push(self.state().unwrap().t); + ys.push(self.state().unwrap().y.clone()); + ydots.push(self.state().unwrap().dy.clone()); + nsteps += 1; + if nsteps > max_steps_between_checkpoints { + checkpoints.push(self.checkpoint().unwrap()); + nsteps = 0; + ts.clear(); + ys.clear(); + ydots.clear(); + } + } + ts.push(self.state().unwrap().t); + ys.push(self.state().unwrap().y.clone()); + ydots.push(self.state().unwrap().dy.clone()); + checkpoints.push(self.checkpoint().unwrap()); + + // save integrateed out function + let g = self.state().unwrap().g.clone(); + + // construct the adjoint solver + let last_segment = HermiteInterpolator::new(ys, ydots, ts); + let mut adjoint_solver = self.into_adjoint_solver(checkpoints, last_segment)?; + + // solve the adjoint problem + adjoint_solver.set_stop_time(t0).unwrap(); + while adjoint_solver.step()? != OdeSolverStopReason::TstopReached {} + + // correct the adjoint solution for the initial conditions + let mut state = adjoint_solver.take_state().unwrap(); + let state_mut = state.as_mut(); + adjoint_solver + .problem() + .unwrap() + .eqn + .correct_sg_for_init(t0, state_mut.s, state_mut.sg); + + // return the solution + Ok((g, state_mut.sg.to_owned())) + } + + /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]` + /// Returns a tuple `(y, sens)`, where `y` is a dense matrix of solution values at timepoints given by `t_eval`, + /// and `sens` is a Vec of dense matrices, the ith element of the Vec are the the sensitivities with respect to the ith parameter. + /// After the solver has finished, the internal state of the solver is at time `t_eval[t_eval.len()-1]`. + #[allow(clippy::type_complexity)] + fn solve_dense_sensitivities( + &mut self, + problem: &OdeSolverProblem, + state: Self::State, + t_eval: &[Eqn::T], + ) -> Result< + ( + ::M, + Vec<::M>, + ), + DiffsolError, + > + where + Self: SensitivitiesOdeSolverMethod, + Eqn: OdeEquationsSens, + Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, + Self: Sized, + { + if problem.integrate_out { + return Err(ode_solver_error!( + Other, + "Cannot integrate out when solving for sensitivities" + )); + } + self.set_problem_with_sensitivities(state, problem)?; + let nrows = problem.eqn.rhs().nstates(); + let mut ret = <::M as Matrix>::zeros(nrows, t_eval.len()); + let mut ret_sens = + vec![ + <::M as Matrix>::zeros(nrows, t_eval.len()); + problem.eqn.rhs().nparams() + ]; + + // check t_eval is increasing and all values are greater than or equal to the current time + let t0 = self.state().unwrap().t; + if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) { + return Err(ode_solver_error!(InvalidTEval)); + } + + // do loop + self.set_stop_time(t_eval[t_eval.len() - 1])?; + let mut step_reason = OdeSolverStopReason::InternalTimestep; + for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() { + while self.state().unwrap().t < *t { + step_reason = self.step()?; + } + let y = self.interpolate(*t)?; + ret.column_mut(i).copy_from(&y); + let s = self.interpolate_sens(*t)?; + for (j, s_j) in s.iter().enumerate() { + ret_sens[j].column_mut(i).copy_from(s_j); + } + } + + // do final step + while step_reason != OdeSolverStopReason::TstopReached { + step_reason = self.step()?; + } + let y = self.state().unwrap().y; + ret.column_mut(t_eval.len() - 1).copy_from(y); + let s = self.state().unwrap().s; + for (j, s_j) in s.iter().enumerate() { + ret_sens[j].column_mut(t_eval.len() - 1).copy_from(s_j); + } + Ok((ret, ret_sens)) + } } pub trait AugmentedOdeSolverMethod: OdeSolverMethod @@ -337,9 +492,11 @@ where #[cfg(test)] mod test { use crate::{ - ode_solver::test_models::exponential_decay::exponential_decay_problem, - ode_solver::test_models::exponential_decay::exponential_decay_problem_adjoint, scale, Bdf, - OdeSolverMethod, OdeSolverState, Vector, + ode_solver::test_models::exponential_decay::{ + exponential_decay_problem, exponential_decay_problem_adjoint, + exponential_decay_problem_sens, + }, + scale, Bdf, OdeSolverMethod, OdeSolverState, Vector, }; #[test] @@ -410,4 +567,55 @@ mod test { y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); } } + + #[test] + fn test_dense_solve_sensitivities() { + let mut s = Bdf::with_sensitivities(); + let (problem, soln) = exponential_decay_problem_sens::>(false); + + let state = OdeSolverState::new_with_sensitivities(&problem, &s).unwrap(); + let t_eval = soln.solution_points.iter().map(|p| p.t).collect::>(); + let (y, sens) = s + .solve_dense_sensitivities(&problem, state, t_eval.as_slice()) + .unwrap(); + for (i, soln_pt) in soln.solution_points.iter().enumerate() { + let y_i = y.column(i).into_owned(); + y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); + } + for (j, soln_pts) in soln.sens_solution_points.unwrap().iter().enumerate() { + for (i, soln_pt) in soln_pts.iter().enumerate() { + let sens_i = sens[j].column(i).into_owned(); + sens_i.assert_eq_norm( + &soln_pt.state, + problem.sens_atol.as_ref().unwrap(), + problem.sens_rtol.unwrap(), + 15.0, + ); + } + } + } + + #[test] + fn test_solve_adjoint() { + let s = Bdf::default(); + let (problem, soln) = exponential_decay_problem_adjoint::>(); + + let state = OdeSolverState::new(&problem, &s).unwrap(); + let final_time = soln.solution_points[soln.solution_points.len() - 1].t; + let (g, gs_adj) = s.solve_adjoint(&problem, state, final_time, None).unwrap(); + g.assert_eq_norm( + &soln.solution_points[soln.solution_points.len() - 1].state, + problem.out_atol.as_ref().unwrap(), + problem.out_rtol.unwrap(), + 15.0, + ); + for (j, soln_pts) in soln.sens_solution_points.unwrap().iter().enumerate() { + gs_adj[j].assert_eq_norm( + &soln_pts[0].state, + problem.out_atol.as_ref().unwrap(), + problem.out_rtol.unwrap(), + 15.0, + ); + } + } }