Skip to content

Commit

Permalink
Merge branch 'develop' into update_sundials
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Oct 3, 2024
2 parents 9590335 + f66f0b0 commit 271b057
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 299 deletions.
139 changes: 23 additions & 116 deletions include/amici/newton_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
#define amici_newton_solver_h

#include "amici/solver.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"

#include <memory>

namespace amici {

class Model;
Expand All @@ -26,36 +23,33 @@ class NewtonSolver {
* @brief Initializes solver according to the dimensions in the provided
* model
*
* @param model pointer to the model object
* @param model the model object
* @param linsol_type type of linear solver to use
* @param sunctx SUNDIALS context
*/
explicit NewtonSolver(Model const& model, SUNContext sunctx);
explicit NewtonSolver(
Model const& model, LinearSolver linsol_type, SUNContext sunctx
);

/**
* @brief Factory method to create a NewtonSolver based on linsolType
*
* @param simulationSolver solver with settings
* @param model pointer to the model instance
* @return NewtonSolver according to the specified linsolType
*/
static std::unique_ptr<NewtonSolver>
getSolver(Solver const& simulationSolver, Model const& model);
NewtonSolver(NewtonSolver const&) = delete;

NewtonSolver& operator=(NewtonSolver const& other) = delete;

/**
* @brief Computes the solution of one Newton iteration
*
* @param delta containing the RHS of the linear system, will be
* overwritten by solution to the linear system
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
*/
void getStep(AmiVector& delta, Model& model, SimulationState const& state);

/**
* @brief Computes steady state sensitivities
*
* @param sx pointer to state variable sensitivities
* @param model pointer to the model instance
* @param sx state variable sensitivities
* @param model the model instance
* @param state current simulation state
*/
void computeNewtonSensis(
Expand All @@ -66,51 +60,45 @@ class NewtonSolver {
* @brief Writes the Jacobian for the Newton iteration and passes it to the
* linear solver
*
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
*/
virtual void prepareLinearSystem(Model& model, SimulationState const& state)
= 0;
void prepareLinearSystem(Model& model, SimulationState const& state);

/**
* Writes the Jacobian (JB) for the Newton iteration and passes it to the
* linear solver
*
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
*/
virtual void
prepareLinearSystemB(Model& model, SimulationState const& state)
= 0;
void prepareLinearSystemB(Model& model, SimulationState const& state);

/**
* @brief Solves the linear system for the Newton step
*
* @param rhs containing the RHS of the linear system, will be
* overwritten by solution to the linear system
*/
virtual void solveLinearSystem(AmiVector& rhs) = 0;
void solveLinearSystem(AmiVector& rhs);

/**
* @brief Reinitialize the linear solver
*
*/
virtual void reinitialize() = 0;
void reinitialize();

/**
* @brief Checks whether linear system is singular
* @brief Checks whether the linear system is singular
*
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
* @return boolean indicating whether the linear system is singular
* (condition number < 1/machine precision)
*/
virtual bool is_singular(Model& model, SimulationState const& state) const
= 0;

virtual ~NewtonSolver() = default;
bool is_singular(Model& model, SimulationState const& state) const;

protected:
private:
/** dummy rhs, used as dummy argument when computing J and JB */
AmiVector xdot_;
/** dummy state, attached to linear solver */
Expand All @@ -120,90 +108,9 @@ class NewtonSolver {
/** dummy differential adjoint state, used as dummy argument when computing
* JB */
AmiVector dxB_;
};

/**
* @brief The NewtonSolverDense provides access to the dense linear solver for
* the Newton method.
*/

class NewtonSolverDense : public NewtonSolver {

public:
/**
* @brief constructor for sparse solver
*
* @param model model instance that provides problem dimensions
* @param sunctx SUNDIALS context
*/
explicit NewtonSolverDense(Model const& model, SUNContext sunctx);

NewtonSolverDense(NewtonSolverDense const&) = delete;

NewtonSolverDense& operator=(NewtonSolverDense const& other) = delete;

~NewtonSolverDense() override;

void solveLinearSystem(AmiVector& rhs) override;

void
prepareLinearSystem(Model& model, SimulationState const& state) override;

void
prepareLinearSystemB(Model& model, SimulationState const& state) override;

void reinitialize() override;

bool is_singular(Model& model, SimulationState const& state) const override;

private:
/** temporary storage of Jacobian */
SUNMatrixWrapper Jtmp_;

/** dense linear solver */
SUNLinearSolver linsol_{nullptr};
};

/**
* @brief The NewtonSolverSparse provides access to the sparse linear solver for
* the Newton method.
*/

class NewtonSolverSparse : public NewtonSolver {

public:
/**
* @brief constructor for dense solver
*
* @param model model instance that provides problem dimensions
* @param sunctx SUNDIALS context
*/
explicit NewtonSolverSparse(Model const& model, SUNContext sunctx);

NewtonSolverSparse(NewtonSolverSparse const&) = delete;

NewtonSolverSparse& operator=(NewtonSolverSparse const& other) = delete;

~NewtonSolverSparse() override;

void solveLinearSystem(AmiVector& rhs) override;

void
prepareLinearSystem(Model& model, SimulationState const& state) override;

void
prepareLinearSystemB(Model& model, SimulationState const& state) override;

bool is_singular(Model& model, SimulationState const& state) const override;

void reinitialize() override;

private:
/** temporary storage of Jacobian */
SUNMatrixWrapper Jtmp_;

/** sparse linear solver */
SUNLinearSolver linsol_{nullptr};
/** linear solver */
std::unique_ptr<SUNLinSolWrapper> linsol_;
};

} // namespace amici
Expand Down
2 changes: 1 addition & 1 deletion include/amici/steadystateproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ class SteadystateProblem {
realtype rtol_quad_{NAN};

/** newton solver */
std::unique_ptr<NewtonSolver> newton_solver_{nullptr};
NewtonSolver newton_solver_;

/** damping factor flag */
NewtonDampingFactorMode damping_factor_mode_{NewtonDampingFactorMode::on};
Expand Down
8 changes: 8 additions & 0 deletions include/amici/sundials_linsol_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
void setOrdering(StateOrdering ordering);

/**
* @brief Checks whether the linear system is singular
*
* @return boolean indicating whether the linear system is singular
* (condition number < 1/machine precision)
*/
bool is_singular() const;
};

#ifdef SUNDIALS_SUPERLUMT
Expand Down
23 changes: 19 additions & 4 deletions python/examples/example_steadystate/ExampleSteadystate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,10 @@
"source": [
"model = model_module.getModel()\n",
"\n",
"print(\"Model name: \", model.getName())\n",
"print(\"Model parameters:\", model.getParameterIds())\n",
"print(\"Model outputs: \", model.getObservableIds())\n",
"print(\"Model states: \", model.getStateIds())"
"print(\"Model name: \", model.getName())\n",
"print(\"Model parameters: \", model.getParameterIds())\n",
"print(\"Model outputs: \", model.getObservableIds())\n",
"print(\"Model state variables: \", model.getStateIds())"
]
},
{
Expand Down Expand Up @@ -985,6 +985,21 @@
"print(\"Log-likelihood %f\" % rdata[\"llh\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "The provided measurements can be visualized together with the simulation results by passing the `Expdata` to `amici.plotting.plot_observable_trajectories`:"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"amici.plotting.plot_observable_trajectories(rdata, edata=edata)\n",
"plt.legend(loc=\"center left\", bbox_to_anchor=(1.04, 0.5))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading

0 comments on commit 271b057

Please sign in to comment.