Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RDataReporting::observables_likelihood #2627

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/amici/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ enum class RDataReporting {
full,
residuals,
likelihood,
observables_likelihood,
};

/** boundary conditions for splines */
Expand Down
7 changes: 7 additions & 0 deletions include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,13 @@ class ReturnData : public ModelDimensions {
*/
void initializeLikelihoodReporting(bool quadratic_llh);

/**
* @brief initializes storage for observables + likelihood reporting mode
* @param quadratic_llh whether model defines a quadratic nllh and computing
* res, sres and FIM makes sense.
*/
void initializeObservablesLikelihoodReporting(bool quadratic_llh);

/**
* @brief initializes storage for residual reporting mode
* @param enable_res whether residuals are to be computed
Expand Down
41 changes: 41 additions & 0 deletions python/tests/test_swig_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,44 @@ def test_python_exceptions(sbml_example_presimulation_module):
):
# rethrow=True
runAmiciSimulation(solver, None, model.get(), True)


def test_reporting_mode_obs_llh(sbml_example_presimulation_module):
model_module = sbml_example_presimulation_module
model = model_module.getModel()
solver = model.getSolver()

solver.setReturnDataReportingMode(
amici.RDataReporting.observables_likelihood
)
solver.setSensitivityOrder(amici.SensitivityOrder.first)

for sens_method in (
amici.SensitivityMethod.none,
amici.SensitivityMethod.forward,
amici.SensitivityMethod.adjoint,
):
solver.setSensitivityMethod(sens_method)
rdata = amici.runAmiciSimulation(
model, solver, amici.ExpData(1, 1, 1, [1])
)
assert (
rdata.rdata_reporting
== amici.RDataReporting.observables_likelihood
)

assert rdata.y.size > 0
assert rdata.sigmay.size > 0
assert rdata.J is None

match solver.getSensitivityMethod():
case amici.SensitivityMethod.none:
assert rdata.sllh is None
case amici.SensitivityMethod.forward:
assert rdata.sy.size > 0
assert rdata.ssigmay.size > 0
assert rdata.sllh.size > 0
case amici.SensitivityMethod.adjoint:
assert rdata.sy is None
assert rdata.ssigmay is None
assert rdata.sllh.size > 0
21 changes: 20 additions & 1 deletion src/rdata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,17 @@
case RDataReporting::likelihood:
initializeLikelihoodReporting(quadratic_llh);
break;

case RDataReporting::observables_likelihood:
initializeObservablesLikelihoodReporting(quadratic_llh);
break;
}
}

void ReturnData::initializeLikelihoodReporting(bool enable_fim) {
llh = getNaN();
chi2 = getNaN();
if (sensi >= SensitivityOrder::first) {
if (sensi >= SensitivityOrder::first && sensi_meth != SensitivityMethod::none) {
sllh.resize(nplist, getNaN());
if (sensi >= SensitivityOrder::second)
s2llh.resize(nplist * (nJ - 1), getNaN());
Expand All @@ -78,6 +82,21 @@
}
}

void ReturnData::initializeObservablesLikelihoodReporting(bool enable_fim) {
initializeLikelihoodReporting(enable_fim);

y.resize(nt * ny, 0.0);
sigmay.resize(nt * ny, 0.0);

if ((sensi_meth == SensitivityMethod::forward
&& sensi >= SensitivityOrder::first)
|| sensi >= SensitivityOrder::second) {

sy.resize(nt * ny * nplist, 0.0);
ssigmay.resize(nt * ny * nplist, 0.0);
}
}

Check warning on line 98 in src/rdata.cpp

View check run for this annotation

Codecov / codecov/patch

src/rdata.cpp#L98

Added line #L98 was not covered by tests

void ReturnData::initializeResidualReporting(bool enable_res) {
y.resize(nt * ny, 0.0);
sigmay.resize(nt * ny, 0.0);
Expand Down
Loading