Skip to content

Commit

Permalink
Always include timepoints in NaN/Inf warnings (#2347)
Browse files Browse the repository at this point in the history
Always include timepoints in NaN/Inf warnings

Closes #2328


---------

Co-authored-by: Fabian Fröhlich <[email protected]>
  • Loading branch information
dweindl and FFroehlich authored Mar 5, 2024
1 parent 2235ab7 commit 777deb0
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 61 deletions.
7 changes: 5 additions & 2 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1334,10 +1334,12 @@ class Model : public AbstractModel, public ModelDimensions {
*
* @param array
* @param model_quantity The model quantity `array` corresponds to
* @param t Current timepoint
* @return
*/
int checkFinite(
gsl::span<realtype const> array, ModelQuantity model_quantity
gsl::span<realtype const> array, ModelQuantity model_quantity,
realtype t
) const;
/**
* @brief Check if the given array has only finite elements.
Expand All @@ -1347,11 +1349,12 @@ class Model : public AbstractModel, public ModelDimensions {
* @param array Flattened matrix
* @param model_quantity The model quantity `array` corresponds to
* @param num_cols Number of columns of the non-flattened matrix
* @param t Current timepoint
* @return
*/
int checkFinite(
gsl::span<realtype const> array, ModelQuantity model_quantity,
size_t num_cols
size_t num_cols, realtype t
) const;

/**
Expand Down
4 changes: 2 additions & 2 deletions src/amici.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ std::unique_ptr<ReturnData> runAmiciSimulation(

try {
rdata->processSimulationObjects(
preeq.get(), fwd.get(), bwd_success ? bwd.get() : nullptr, posteq.get(),
model, solver, edata
preeq.get(), fwd.get(), bwd_success ? bwd.get() : nullptr,
posteq.get(), model, solver, edata
);
} catch (std::exception const& ex) {
rdata->status = AMICI_ERROR;
Expand Down
69 changes: 41 additions & 28 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ void Model::initializeStates(AmiVector& x) {
std::copy(x0_solver.cbegin(), x0_solver.cend(), x.data());
}

checkFinite(x.getVector(), ModelQuantity::x0);
checkFinite(x.getVector(), ModelQuantity::x0, t0());
}

void Model::initializeSplines() {
Expand Down Expand Up @@ -1482,7 +1482,7 @@ void Model::addStateEventUpdate(
);

if (always_check_finite_) {
checkFinite(derived_state_.deltax_, ModelQuantity::deltax);
checkFinite(derived_state_.deltax_, ModelQuantity::deltax, t);
}

// update
Expand Down Expand Up @@ -1536,7 +1536,7 @@ void Model::addAdjointStateEventUpdate(
);

if (always_check_finite_) {
checkFinite(derived_state_.deltaxB_, ModelQuantity::deltaxB);
checkFinite(derived_state_.deltaxB_, ModelQuantity::deltaxB, t);
}

// apply update
Expand Down Expand Up @@ -1582,7 +1582,7 @@ void Model::updateHeavisideB(int const* rootsfound) {
}

int Model::checkFinite(
gsl::span<realtype const> array, ModelQuantity model_quantity
gsl::span<realtype const> array, ModelQuantity model_quantity, realtype t
) const {
auto it = std::find_if(array.begin(), array.end(), [](realtype x) {
return !std::isfinite(x);
Expand Down Expand Up @@ -1654,31 +1654,35 @@ int Model::checkFinite(
gsl_ExpectsDebug(false);
model_quantity_str = std::to_string(static_cast<int>(model_quantity));
}
if (logger)
if (logger) {
auto t_msg = std::isfinite(t)
? std::string(" at t=" + std::to_string(t) + " ")
: std::string();

logger->log(
LogSeverity::warning, msg_id,
"AMICI encountered a %s value for %s[%i] (%s)",
"AMICI encountered a %s value for %s[%i] (%s)%s",
non_finite_type.c_str(), model_quantity_str.c_str(),
gsl::narrow<int>(flat_index), element_id.c_str()
gsl::narrow<int>(flat_index), element_id.c_str(), t_msg.c_str()
);

}
// check upstream, without infinite recursion
if (model_quantity != ModelQuantity::k && model_quantity != ModelQuantity::p
&& model_quantity != ModelQuantity::ts) {
checkFinite(state_.fixedParameters, ModelQuantity::k);
checkFinite(state_.unscaledParameters, ModelQuantity::p);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts);
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
if (!always_check_finite_ && model_quantity != ModelQuantity::w) {
// don't check twice if always_check_finite_ is true
checkFinite(derived_state_.w_, ModelQuantity::w);
checkFinite(derived_state_.w_, ModelQuantity::w, t);
}
}
return AMICI_RECOVERABLE_ERROR;
}

int Model::checkFinite(
gsl::span<realtype const> array, ModelQuantity model_quantity,
size_t num_cols
size_t num_cols, realtype t
) const {
auto it = std::find_if(array.begin(), array.end(), [](realtype x) {
return !std::isfinite(x);
Expand Down Expand Up @@ -1768,19 +1772,25 @@ int Model::checkFinite(
model_quantity_str = std::to_string(static_cast<int>(model_quantity));
}

if (logger)
if (logger) {
auto t_msg = std::isfinite(t)
? std::string(" at t=" + std::to_string(t) + " ")
: std::string();

logger->log(
LogSeverity::warning, msg_id,
"AMICI encountered a %s value for %s[%i] (%s, %s)",
"AMICI encountered a %s value for %s[%i] (%s, %s)%s",
non_finite_type.c_str(), model_quantity_str.c_str(),
gsl::narrow<int>(flat_index), row_id.c_str(), col_id.c_str()
gsl::narrow<int>(flat_index), row_id.c_str(), col_id.c_str(),
t_msg.c_str()
);
}

// check upstream
checkFinite(state_.fixedParameters, ModelQuantity::k);
checkFinite(state_.unscaledParameters, ModelQuantity::p);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts);
checkFinite(derived_state_.w_, ModelQuantity::w);
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
checkFinite(derived_state_.w_, ModelQuantity::w, t);

return AMICI_RECOVERABLE_ERROR;
}
Expand Down Expand Up @@ -1868,10 +1878,10 @@ int Model::checkFinite(SUNMatrix m, ModelQuantity model_quantity, realtype t)
);

// check upstream
checkFinite(state_.fixedParameters, ModelQuantity::k);
checkFinite(state_.unscaledParameters, ModelQuantity::p);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts);
checkFinite(derived_state_.w_, ModelQuantity::w);
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
checkFinite(derived_state_.w_, ModelQuantity::w, t);

return AMICI_RECOVERABLE_ERROR;
}
Expand All @@ -1895,7 +1905,7 @@ void Model::fx0(AmiVector& x) {
state_.unscaledParameters.data(), state_.fixedParameters.data()
);

checkFinite(derived_state_.x_rdata_, ModelQuantity::x0_rdata);
checkFinite(derived_state_.x_rdata_, ModelQuantity::x0_rdata, t0());
}

void Model::fx0_fixedParameters(AmiVector& x) {
Expand Down Expand Up @@ -1982,7 +1992,10 @@ void Model::fx_rdata(AmiVector& x_rdata, AmiVector const& x) {
state_.unscaledParameters.data(), state_.fixedParameters.data()
);
if (always_check_finite_)
checkFinite(x_rdata.getVector(), ModelQuantity::x_rdata);
checkFinite(
x_rdata.getVector(), ModelQuantity::x_rdata,
std::numeric_limits<realtype>::quiet_NaN()
);
}

void Model::fsx_rdata(
Expand Down Expand Up @@ -2072,7 +2085,7 @@ void Model::fy(realtype const t, AmiVector const& x) {

if (always_check_finite_) {
checkFinite(
gsl::make_span(derived_state_.y_.data(), ny), ModelQuantity::y
gsl::make_span(derived_state_.y_.data(), ny), ModelQuantity::y, t
);
}
}
Expand Down Expand Up @@ -2875,7 +2888,7 @@ void Model::fw(realtype const t, realtype const* x, bool include_static) {
state_.spl_.data(), include_static);

if (always_check_finite_) {
checkFinite(derived_state_.w_, ModelQuantity::w);
checkFinite(derived_state_.w_, ModelQuantity::w, t);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/model_dae.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ void Model_DAE::fJDiag(
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_);
derived_state_.J_.refresh();
derived_state_.J_.to_diag(JDiag.getNVector());
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag) != AMICI_SUCCESS)
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag, t)
!= AMICI_SUCCESS)
throw AmiException("Evaluation of fJDiag failed!");
}

Expand Down
3 changes: 2 additions & 1 deletion src/model_ode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ void Model_ODE::fJDiag(
AmiVector const& x, AmiVector const& /*dx*/
) {
fJDiag(t, JDiag.getNVector(), x.getNVector());
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag) != AMICI_SUCCESS)
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag, t)
!= AMICI_SUCCESS)
throw AmiException("Evaluation of fJDiag failed!");
}

Expand Down
34 changes: 19 additions & 15 deletions src/solver_cvodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ fJB(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot, SUNMatrix JB,
Expects(model);

model->fJB(t, x, xB, xBdot, JB);
return model->checkFinite(gsl::make_span(JB), ModelQuantity::JB);
return model->checkFinite(gsl::make_span(JB), ModelQuantity::JB, t);
}

/**
Expand Down Expand Up @@ -978,7 +978,7 @@ static int fJSparseB(
Expects(model);

model->fJSparseB(t, x, xB, xBdot, JB);
return model->checkFinite(gsl::make_span(JB), ModelQuantity::JB);
return model->checkFinite(gsl::make_span(JB), ModelQuantity::JB, t);
}

/**
Expand Down Expand Up @@ -1041,7 +1041,7 @@ fJv(N_Vector v, N_Vector Jv, realtype t, N_Vector x, N_Vector /*xdot*/,
Expects(model);

model->fJv(v, Jv, t, x);
return model->checkFinite(gsl::make_span(Jv), ModelQuantity::Jv);
return model->checkFinite(gsl::make_span(Jv), ModelQuantity::Jv, t);
}

/**
Expand All @@ -1067,7 +1067,7 @@ static int fJvB(
Expects(model);

model->fJvB(vB, JvB, t, x, xB);
return model->checkFinite(gsl::make_span(JvB), ModelQuantity::JvB);
return model->checkFinite(gsl::make_span(JvB), ModelQuantity::JvB, t);
}

/**
Expand All @@ -1094,7 +1094,7 @@ static int froot(realtype t, N_Vector x, realtype* root, void* user_data) {
model->froot(t, x, gsl::make_span<realtype>(root, model->ne_solver));
}
return model->checkFinite(
gsl::make_span<realtype>(root, model->ne_solver), ModelQuantity::root
gsl::make_span<realtype>(root, model->ne_solver), ModelQuantity::root, t
);
}

Expand All @@ -1119,7 +1119,7 @@ static int fxdot(realtype t, N_Vector x, N_Vector xdot, void* user_data) {
}

if (t > 1e200
&& !model->checkFinite(gsl::make_span(x), ModelQuantity::xdot)) {
&& !model->checkFinite(gsl::make_span(x), ModelQuantity::xdot, t)) {
/* when t is large (typically ~1e300), CVODES may pass all NaN x
to fxdot from which we typically cannot recover. To save time
on normal execution, we do not always want to check finiteness
Expand All @@ -1128,7 +1128,7 @@ static int fxdot(realtype t, N_Vector x, N_Vector xdot, void* user_data) {
}

model->fxdot(t, x, xdot);
return model->checkFinite(gsl::make_span(xdot), ModelQuantity::xdot);
return model->checkFinite(gsl::make_span(xdot), ModelQuantity::xdot, t);
}

/**
Expand All @@ -1154,7 +1154,7 @@ fxBdot(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot, void* user_data) {
}

model->fxBdot(t, x, xB, xBdot);
return model->checkFinite(gsl::make_span(xBdot), ModelQuantity::xBdot);
return model->checkFinite(gsl::make_span(xBdot), ModelQuantity::xBdot, t);
}

/**
Expand All @@ -1174,7 +1174,7 @@ fqBdot(realtype t, N_Vector x, N_Vector xB, N_Vector qBdot, void* user_data) {
Expects(model);

model->fqBdot(t, x, xB, qBdot);
return model->checkFinite(gsl::make_span(qBdot), ModelQuantity::qBdot);
return model->checkFinite(gsl::make_span(qBdot), ModelQuantity::qBdot, t);
}

/**
Expand All @@ -1193,7 +1193,9 @@ static int fxBdot_ss(realtype t, N_Vector xB, N_Vector xBdot, void* user_data) {
Expects(model);

model->fxBdot_ss(t, xB, xBdot);
return model->checkFinite(gsl::make_span(xBdot), ModelQuantity::xBdot_ss);
return model->checkFinite(
gsl::make_span(xBdot), ModelQuantity::xBdot_ss, t
);
}

/**
Expand All @@ -1212,7 +1214,9 @@ static int fqBdot_ss(realtype t, N_Vector xB, N_Vector qBdot, void* user_data) {
Expects(model);

model->fqBdot_ss(t, xB, qBdot);
return model->checkFinite(gsl::make_span(qBdot), ModelQuantity::qBdot_ss);
return model->checkFinite(
gsl::make_span(qBdot), ModelQuantity::qBdot_ss, t
);
}

/**
Expand All @@ -1228,8 +1232,8 @@ static int fqBdot_ss(realtype t, N_Vector xB, N_Vector qBdot, void* user_data) {
* @return status flag indicating successful execution
*/
static int fJSparseB_ss(
realtype /*t*/, N_Vector /*x*/, N_Vector xBdot, SUNMatrix JB,
void* user_data, N_Vector /*tmp1*/, N_Vector /*tmp2*/, N_Vector /*tmp3*/
realtype t, N_Vector /*x*/, N_Vector xBdot, SUNMatrix JB, void* user_data,
N_Vector /*tmp1*/, N_Vector /*tmp2*/, N_Vector /*tmp3*/
) {
auto typed_udata = static_cast<CVodeSolver::user_data_type*>(user_data);
Expects(typed_udata);
Expand All @@ -1238,7 +1242,7 @@ static int fJSparseB_ss(

model->fJSparseB_ss(JB);
return model->checkFinite(
gsl::make_span(xBdot), ModelQuantity::JSparseB_ss
gsl::make_span(xBdot), ModelQuantity::JSparseB_ss, t
);
}

Expand Down Expand Up @@ -1267,7 +1271,7 @@ static int fsxdot(
Expects(model);

model->fsxdot(t, x, ip, sx, sxdot);
return model->checkFinite(gsl::make_span(sxdot), ModelQuantity::sxdot);
return model->checkFinite(gsl::make_span(sxdot), ModelQuantity::sxdot, t);
}

bool operator==(CVodeSolver const& a, CVodeSolver const& b) {
Expand Down
Loading

0 comments on commit 777deb0

Please sign in to comment.