Skip to content

Commit

Permalink
Allow subselection of state variables for steady-state simulations
Browse files Browse the repository at this point in the history
Closes #2368
  • Loading branch information
dweindl committed Mar 26, 2024
1 parent 790ab44 commit c6528c4
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 9 deletions.
43 changes: 43 additions & 0 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,40 @@ class Model : public AbstractModel, public ModelDimensions {
*/
virtual std::vector<double> get_trigger_timepoints() const;

/**
* @brief Get steady-state mask as std::vector.
*
* See `set_steadystate_mask` for details.
*
* @return Steady-state mask
*/
std::vector<double> get_steadystate_mask() const {
return steadystate_mask_.getVector();

Check warning on line 1492 in include/amici/model.h

View check run for this annotation

Codecov / codecov/patch

include/amici/model.h#L1491-L1492

Added lines #L1491 - L1492 were not covered by tests
};

/**
* @brief Get steady-state mask as AmiVector.
*
* See `set_steadystate_mask` for details.
* @return Steady-state mask
*/
AmiVector const& get_steadystate_mask_av() const {
return steadystate_mask_;
};

/**
* @brief Set steady-state mask.
*
* The mask is used to exclude certain state variables from the steady-state
* convergence check. Positive values indicate that the corresponding state
* variable should be included in the convergence check, while non-positive
* values indicate that the corresponding state variable should be excluded.
* An empty mask is interpreted as including all state variables.
*
* @param mask Mask of length `nx_solver`.
*/
void set_steadystate_mask(std::vector<double> const& mask);

/**
* Flag indicating whether for
* `amici::Solver::sensi_` == `amici::SensitivityOrder::second`
Expand Down Expand Up @@ -2087,6 +2121,15 @@ class Model : public AbstractModel, public ModelDimensions {

/** Simulation parameters, initial state, etc. */
SimulationParameters simulation_parameters_;

/**
* Mask for state variables that should be checked for steady state
* during pre-/post-equilibration. Positive values indicate that the
* corresponding state variable should be checked for steady state.
* Negative values indicate that the corresponding state variable should
* be ignored.
*/
AmiVector steadystate_mask_;
};

bool operator==(Model const& a, Model const& b);
Expand Down
6 changes: 4 additions & 2 deletions include/amici/steadystateproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,16 @@ class SteadystateProblem {
* w_i = 1 / ( rtol * x_i + atol )
* @param x current state (sx[ip] for sensitivities)
* @param xdot current rhs (sxdot[ip] for sensitivities)
* @param mask mask for state variables to include in WRMS norm.
* Positive value: include; non-positive value: exclude; empty: include all.
* @param atol absolute tolerance
* @param rtol relative tolerance
* @param ewt error weight vector
* @return root-mean-square norm
*/
realtype getWrmsNorm(
AmiVector const& x, AmiVector const& xdot, realtype atol, realtype rtol,
AmiVector& ewt
AmiVector const& x, AmiVector const& xdot, AmiVector const& mask,
realtype atol, realtype rtol, AmiVector& ewt
) const;

/**
Expand Down
59 changes: 59 additions & 0 deletions python/tests/test_preequilibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from amici.debugging import get_model_for_preeq
from numpy.testing import assert_allclose, assert_equal
from test_pysb import get_data
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory


@pytest.fixture
Expand Down Expand Up @@ -658,3 +659,61 @@ def test_get_model_for_preeq(preeq_fixture):
rdata1.sx,
rdata2.sx,
)


def test_partial_eq():
"""Check that partial equilibration is possible."""
from amici.antimony_import import antimony2amici

ant_str = """
model test_partial_eq
explodes = 1
explodes' = explodes
A = 1
B = 0
R: A -> B; k*A - k*B
k = 1
end
"""
module_name = "test_partial_eq"
with TemporaryDirectory(prefix=module_name) as outdir:
antimony2amici(
ant_str,
model_name=module_name,
output_dir=outdir,
)
model_module = amici.import_model_module(
module_name=module_name, module_path=outdir
)
amici_model = model_module.getModel()
amici_model.setTimepoints([np.inf])
amici_solver = amici_model.getSolver()
amici_solver.setRelativeToleranceSteadyState(1e-12)

# equilibration of `explodes` will fail
rdata = amici.runAmiciSimulation(amici_model, amici_solver)
assert rdata.status == amici.AMICI_ERROR
assert rdata.messages[0].identifier == "EQUILIBRATION_FAILURE"

# excluding `explodes` should enable equilibration
amici_model.set_steadystate_mask(
[
0 if state_id == "explodes" else 1
for state_id in amici_model.getStateIdsSolver()
]
)
rdata = amici.runAmiciSimulation(amici_model, amici_solver)
assert rdata.status == amici.AMICI_SUCCESS
assert_allclose(
rdata.by_id("A"),
0.5,
atol=amici_solver.getAbsoluteToleranceSteadyState(),
rtol=amici_solver.getRelativeToleranceSteadyState(),
)
assert_allclose(
rdata.by_id("B"),
0.5,
atol=amici_solver.getAbsoluteToleranceSteadyState(),
rtol=amici_solver.getRelativeToleranceSteadyState(),
)
assert rdata.t_last < 100
17 changes: 17 additions & 0 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3138,6 +3138,23 @@ std::vector<double> Model::get_trigger_timepoints() const {
return trigger_timepoints;
}

void Model::set_steadystate_mask(std::vector<double> const& mask) {

Check warning on line 3141 in src/model.cpp

View check run for this annotation

Codecov / codecov/patch

src/model.cpp#L3141

Added line #L3141 was not covered by tests
if (mask.size() == 0) {
if (steadystate_mask_.getLength() != 0) {
steadystate_mask_ = AmiVector();

Check warning on line 3144 in src/model.cpp

View check run for this annotation

Codecov / codecov/patch

src/model.cpp#L3144

Added line #L3144 was not covered by tests
}
return;

Check warning on line 3146 in src/model.cpp

View check run for this annotation

Codecov / codecov/patch

src/model.cpp#L3146

Added line #L3146 was not covered by tests
}

if (gsl::narrow<int>(mask.size()) != nx_solver)
throw AmiException(

Check warning on line 3150 in src/model.cpp

View check run for this annotation

Codecov / codecov/patch

src/model.cpp#L3150

Added line #L3150 was not covered by tests
"Steadystate mask has wrong size: %d, expected %d",
gsl::narrow<int>(mask.size()), nx_solver
);

Check warning on line 3153 in src/model.cpp

View check run for this annotation

Codecov / codecov/patch

src/model.cpp#L3153

Added line #L3153 was not covered by tests

steadystate_mask_ = AmiVector(mask);

Check warning on line 3155 in src/model.cpp

View check run for this annotation

Codecov / codecov/patch

src/model.cpp#L3155

Added line #L3155 was not covered by tests
}

const_N_Vector Model::computeX_pos(const_N_Vector x) {
if (any_state_non_negative_) {
for (int ix = 0; ix < derived_state_.x_pos_tmp_.getLength(); ++ix) {
Expand Down
27 changes: 20 additions & 7 deletions src/steadystateproblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,8 @@ bool SteadystateProblem::getSensitivityFlag(
}

realtype SteadystateProblem::getWrmsNorm(
AmiVector const& x, AmiVector const& xdot, realtype atol, realtype rtol,
AmiVector& ewt
AmiVector const& x, AmiVector const& xdot, AmiVector const& mask,
realtype atol, realtype rtol, AmiVector& ewt
) const {
/* Depending on what convergence we want to check (xdot, sxdot, xQBdot)
we need to pass ewt[QB], as xdot and xQBdot have different sizes */
Expand All @@ -522,7 +522,14 @@ realtype SteadystateProblem::getWrmsNorm(
N_VAddConst(ewt.getNVector(), atol, ewt.getNVector());
/* ewt = 1/ewt (ewt = 1/(rtol*x+atol)) */
N_VInv(ewt.getNVector(), ewt.getNVector());
/* wrms = sqrt(sum((xdot/ewt)**2)/n) where n = size of state vector */

// wrms = sqrt(sum((xdot/ewt)**2)/n) where n = size of state vector
if (mask.getLength()) {
return N_VWrmsNormMask(
const_cast<N_Vector>(xdot.getNVector()), ewt.getNVector(),
const_cast<N_Vector>(mask.getNVector())
);

Check warning on line 531 in src/steadystateproblem.cpp

View check run for this annotation

Codecov / codecov/patch

src/steadystateproblem.cpp#L528-L531

Added lines #L528 - L531 were not covered by tests
}
return N_VWrmsNorm(
const_cast<N_Vector>(xdot.getNVector()), ewt.getNVector()
);
Expand All @@ -543,7 +550,10 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
"Newton type convergence check is not implemented for adjoint "
"steady state computations. Stopping."
);
wrms = getWrmsNorm(xQB_, xQBdot_, atol_quad_, rtol_quad_, ewtQB_);
wrms = getWrmsNorm(
xQB_, xQBdot_, model.get_steadystate_mask_av(), atol_quad_,
rtol_quad_, ewtQB_

Check warning on line 555 in src/steadystateproblem.cpp

View check run for this annotation

Codecov / codecov/patch

src/steadystateproblem.cpp#L553-L555

Added lines #L553 - L555 were not covered by tests
);
} else {
/* If we're doing a forward simulation (with or without sensitivities:
Get RHS and compute weighted error norm */
Expand All @@ -552,7 +562,8 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
else
updateRightHandSide(model);
wrms = getWrmsNorm(
state_.x, newton_step_conv_ ? delta_ : xdot_, atol_, rtol_, ewt_
state_.x, newton_step_conv_ ? delta_ : xdot_,
model.get_steadystate_mask_av(), atol_, rtol_, ewt_
);
}
return wrms;
Expand All @@ -573,8 +584,10 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) {
);
if (newton_step_conv_)
newton_solver_->solveLinearSystem(xdot_);
wrms
= getWrmsNorm(state_.sx[ip], xdot_, atol_sensi_, rtol_sensi_, ewt_);
wrms = getWrmsNorm(
state_.sx[ip], xdot_, model.get_steadystate_mask_av(), atol_sensi_,
rtol_sensi_, ewt_
);
/* ideally this function would report the maximum of all wrms over
all ip, but for practical purposes we can just report the wrms for
the first ip where we know that the convergence threshold is not
Expand Down

0 comments on commit c6528c4

Please sign in to comment.