From 9f562667ba74d1dfdd86514f89155644daeec4ba Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Sun, 20 Oct 2024 10:08:17 +0200 Subject: [PATCH] Fix crashes when errors occur at output timepoints (#2555) Fixes a bug that lead to program termination if a root-after-reinitialization error (potentially also others) occurred at an output timepoint, because an non-existing/invalid SimulationState for that timepoint was accessed. See #2491 for further details. Fixes #2491. Also avoid some unnecessary copying (during which previously the segfault occurred if this bug triggered in non-debug builds). --- include/amici/forwardproblem.h | 17 +++++++++++++++-- src/rdata.cpp | 6 +++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/include/amici/forwardproblem.h b/include/amici/forwardproblem.h index ffcca63666..40a4147f4e 100644 --- a/include/amici/forwardproblem.h +++ b/include/amici/forwardproblem.h @@ -2,6 +2,7 @@ #define AMICI_FORWARDPROBLEM_H #include "amici/defines.h" +#include "amici/edata.h" #include "amici/misc.h" #include "amici/model.h" #include "amici/vector.h" @@ -195,7 +196,7 @@ class ForwardProblem { if (model->getTimepoint(it) == initial_state_.t) return getInitialSimulationState(); auto map_iter = timepoint_states_.find(model->getTimepoint(it)); - assert(map_iter != timepoint_states_.end()); + Ensures(map_iter != timepoint_states_.end()); return map_iter->second; }; @@ -441,8 +442,20 @@ class FinalStateStorer : public ContextManager { * @brief destructor, stores simulation state */ ~FinalStateStorer() { - if (fwd_) + if (fwd_) { fwd_->final_state_ = fwd_->getSimulationState(); + // if there is an associated output timepoint, also store it in + // timepoint_states if it's not present there. + // this may happen if there is an error just at + // (or indistinguishably before) an output timepoint + auto final_time = fwd_->getFinalTime(); + auto const timepoints = fwd_->model->getTimepoints(); + if (!fwd_->timepoint_states_.count(final_time) + && std::find(timepoints.cbegin(), timepoints.cend(), final_time) + != timepoints.cend()) { + fwd_->timepoint_states_[final_time] = fwd_->final_state_; + } + } } private: diff --git a/src/rdata.cpp b/src/rdata.cpp index a2fa02dd62..4ec983af2b 100644 --- a/src/rdata.cpp +++ b/src/rdata.cpp @@ -241,7 +241,7 @@ void ReturnData::processForwardProblem( if (edata) initializeObjectiveFunction(model.hasQuadraticLLH()); - auto initialState = fwd.getInitialSimulationState(); + auto const& initialState = fwd.getInitialSimulationState(); if (initialState.x.getLength() == 0 && model.nx_solver > 0) return; // if x wasn't set forward problem failed during initialization @@ -259,7 +259,7 @@ void ReturnData::processForwardProblem( realtype tf = fwd.getFinalTime(); for (int it = 0; it < model.nt(); it++) { if (model.getTimepoint(it) <= tf) { - auto simulation_state = fwd.getSimulationStateTimepoint(it); + auto const simulation_state = fwd.getSimulationStateTimepoint(it); model.setModelState(simulation_state.state); getDataOutput(it, model, simulation_state, edata); } else { @@ -273,7 +273,7 @@ void ReturnData::processForwardProblem( if (nz > 0) { auto rootidx = fwd.getRootIndexes(); for (int iroot = 0; iroot <= fwd.getEventCounter(); iroot++) { - auto simulation_state = fwd.getSimulationStateEvent(iroot); + auto const simulation_state = fwd.getSimulationStateEvent(iroot); model.setModelState(simulation_state.state); getEventOutput( simulation_state.t, rootidx.at(iroot), model, simulation_state,