Skip to content

Commit

Permalink
Split expressions into static and dynamic (#2303)
Browse files Browse the repository at this point in the history
Split expressions in `w` and its derivatives into dynamic (explicitly or implicitly time-dependent) and static ones.
Evaluate static ones only when needed, i.e. after (re)initializing x_rdata or parameters.

See #1269
  • Loading branch information
dweindl authored Feb 27, 2024
1 parent 6a1cc3c commit eb3fd4a
Show file tree
Hide file tree
Showing 59 changed files with 999 additions and 645 deletions.
3 changes: 3 additions & 0 deletions .github/actions/setup-amici-cpp/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ runs:
- run: echo "ENABLE_GCOV_COVERAGE=TRUE" >> $GITHUB_ENV
shell: bash

- run: echo "PYTHONFAULTHANDLER=1" >> $GITHUB_ENV
shell: bash

- name: Set up Sonar tools
uses: ./.github/actions/setup-sonar-tools

Expand Down
46 changes: 20 additions & 26 deletions include/amici/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -820,11 +820,15 @@ class AbstractModel {
* @param h Heaviside vector
* @param tcl total abundances for conservation laws
* @param spl spline value vector
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
virtual void
fw(realtype* w, realtype const t, realtype const* x, realtype const* p,
realtype const* k, realtype const* h, realtype const* tcl,
realtype const* spl);
realtype const* spl, bool include_static = true);

/**
* @brief Model-specific sparse implementation of dwdp
Expand All @@ -840,12 +844,16 @@ class AbstractModel {
* @param spl spline value vector
* @param sspl sensitivities of spline values vector w.r.t. parameters \f$ p
* \f$
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
virtual void fdwdp(
realtype* dwdp, realtype const t, realtype const* x, realtype const* p,
realtype const* k, realtype const* h, realtype const* w,
realtype const* tcl, realtype const* stcl, realtype const* spl,
realtype const* sspl
realtype const* sspl, bool include_static = true
);

/**
Expand All @@ -860,28 +868,6 @@ class AbstractModel {
*/
virtual void fdwdp_rowvals(SUNMatrixWrapper& dwdp);

/**
* @brief Model-specific sensitivity implementation of dwdp
* @param dwdp Recurring terms in xdot, parameter derivative
* @param t timepoint
* @param x vector with the states
* @param p parameter vector
* @param k constants vector
* @param h Heaviside vector
* @param w vector with helper variables
* @param tcl total abundances for conservation laws
* @param stcl sensitivities of total abundances for conservation laws
* @param spl spline value vector
* @param sspl sensitivities of spline values vector
* @param ip sensitivity parameter index
*/
virtual void fdwdp(
realtype* dwdp, realtype const t, realtype const* x, realtype const* p,
realtype const* k, realtype const* h, realtype const* w,
realtype const* tcl, realtype const* stcl, realtype const* spl,
realtype const* sspl, int ip
);

/**
* @brief Model-specific implementation of dwdx, data part
* @param dwdx Recurring terms in xdot, state derivative
Expand All @@ -893,11 +879,15 @@ class AbstractModel {
* @param w vector with helper variables
* @param tcl total abundances for conservation laws
* @param spl spline value vector
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
virtual void fdwdx(
realtype* dwdx, realtype const t, realtype const* x, realtype const* p,
realtype const* k, realtype const* h, realtype const* w,
realtype const* tcl, realtype const* spl
realtype const* tcl, realtype const* spl, bool include_static = true
);

/**
Expand All @@ -922,11 +912,15 @@ class AbstractModel {
* @param h Heaviside vector
* @param w vector with helper variables
* @param tcl Total abundances for conservation laws
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
virtual void fdwdw(
realtype* dwdw, realtype t, realtype const* x, realtype const* p,
realtype const* k, realtype const* h, realtype const* w,
realtype const* tcl
realtype const* tcl, bool include_static = true
);

/**
Expand Down
36 changes: 32 additions & 4 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,18 @@ class Model : public AbstractModel, public ModelDimensions {
bool computeSensitivities, std::vector<int>& roots_found
);

/**
* @brief Re-initialize model properties after changing simulation context.
* @param t Timepoint
* @param x Reference to state variables
* @param sx Reference to state variable sensitivities
* @param computeSensitivities Flag indicating whether sensitivities are to
* be computed
*/
void reinitialize(
realtype t, AmiVector& x, AmiVectorArray& sx, bool computeSensitivities
);

/**
* @brief Initialize model properties.
* @param xB Adjoint state variables
Expand Down Expand Up @@ -1828,29 +1840,45 @@ class Model : public AbstractModel, public ModelDimensions {
* @brief Compute recurring terms in xdot.
* @param t Timepoint
* @param x Array with the states
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
void fw(realtype t, realtype const* x);
void fw(realtype t, realtype const* x, bool include_static = true);

/**
* @brief Compute parameter derivative for recurring terms in xdot.
* @param t Timepoint
* @param x Array with the states
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
void fdwdp(realtype t, realtype const* x);
void fdwdp(realtype t, realtype const* x, bool include_static = true);

/**
* @brief Compute state derivative for recurring terms in xdot.
* @param t Timepoint
* @param x Array with the states
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
void fdwdx(realtype t, realtype const* x);
void fdwdx(realtype t, realtype const* x, bool include_static = true);

/**
* @brief Compute self derivative for recurring terms in xdot.
* @param t Timepoint
* @param x Array with the states
* @param include_static Whether to (re-)evaluate only dynamic expressions
* (false) or also static expressions (true).
* Dynamic expressions are those that depend directly or indirectly on time,
* static expressions are those that don't.
*/
void fdwdw(realtype t, realtype const* x);
void fdwdw(realtype t, realtype const* x, bool include_static = true);

/**
* @brief Compute fx_rdata.
Expand Down
6 changes: 3 additions & 3 deletions matlab/@amifun/getArgs.m
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@
case 'dJrzdsigma'
this.argstr = '(double *dJrzdsigma, const int iz, const realtype *p, const realtype *k, const double *rz, const double *sigmaz)';
case 'w'
this.argstr = '(realtype *w, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl, const realtype *spl)';
this.argstr = '(realtype *w, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl, const realtype *spl, bool include_static)';
case 'dwdp'
this.argstr = '(realtype *dwdp, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl, const realtype *stcl, const realtype *spl, const realtype *sspl)';
this.argstr = '(realtype *dwdp, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl, const realtype *stcl, const realtype *spl, const realtype *sspl, bool include_static)';
case 'dwdx'
this.argstr = '(realtype *dwdx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl, const realtype *spl)';
this.argstr = '(realtype *dwdx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl, const realtype *spl, bool include_static)';
case 'M'
this.argstr = '(realtype *M, const realtype t, const realtype *x, const realtype *p, const realtype *k)';
otherwise
Expand Down
3 changes: 2 additions & 1 deletion matlab/@amimodel/generateC.m
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ function generateC(this)
end
fprintf(fid,'};\n\n');
fprintf(fid,['} // namespace model_' this.modelname '\n\n']);
fprintf(fid,'} // namespace amici \n\n');
fprintf(fid,'} // namespace amici\n\n');
fprintf(fid,['#endif /* _amici_' this.modelname '_h */\n']);
fclose(fid);

Expand Down Expand Up @@ -253,6 +253,7 @@ function generateC(this)

argstr = strrep(argstr,'realtype','');
argstr = strrep(argstr,'int','');
argstr = strrep(argstr,'bool','');
argstr = strrep(argstr,'const','');
argstr = strrep(argstr,'double','');
argstr = strrep(argstr,'SUNMatrixContent_Sparse','');
Expand Down
37 changes: 24 additions & 13 deletions models/model_calvetti/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Build AMICI model
cmake_minimum_required(VERSION 3.15)
cmake_policy(VERSION 3.15...3.27)

# cmake >=3.27
if(POLICY CMP0144)
cmake_policy(SET CMP0144 NEW)
endif(POLICY CMP0144)

project(model_calvetti)

Expand All @@ -14,7 +20,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif()
foreach(flag ${MY_CXX_FLAGS})
unset(CUR_FLAG_SUPPORTED CACHE)
check_cxx_compiler_flag(-Werror ${flag} CUR_FLAG_SUPPORTED)
check_cxx_compiler_flag(${flag} CUR_FLAG_SUPPORTED)
if(${CUR_FLAG_SUPPORTED})
string(APPEND CMAKE_CXX_FLAGS " ${flag}")
endif()
Expand All @@ -33,6 +39,23 @@ find_package(Amici REQUIRED HINTS
${CMAKE_CURRENT_LIST_DIR}/../../build)
message(STATUS "Found AMICI ${Amici_DIR}")

# Debug build?
if("$ENV{ENABLE_AMICI_DEBUGGING}" OR "$ENV{ENABLE_GCOV_COVERAGE}")
add_compile_options(-UNDEBUG)
if(MSVC)
add_compile_options(-DEBUG)
else()
add_compile_options(-O0 -g)
endif()
set(CMAKE_BUILD_TYPE "Debug")
endif()

# coverage options
if($ENV{ENABLE_GCOV_COVERAGE})
string(APPEND CMAKE_CXX_FLAGS_DEBUG " --coverage")
string(APPEND CMAKE_EXE_LINKER_FLAGS_DEBUG " --coverage")
endif()

set(MODEL_DIR ${CMAKE_CURRENT_LIST_DIR})

set(SRC_LIST_LIB ${MODEL_DIR}/JSparse.cpp
Expand Down Expand Up @@ -73,18 +96,6 @@ if(NOT "${AMICI_PYTHON_BUILD_EXT_ONLY}")
target_link_libraries(simulate_${PROJECT_NAME} ${PROJECT_NAME})
endif()

# Debug build?
if("$ENV{ENABLE_AMICI_DEBUGGING}" OR "$ENV{ENABLE_GCOV_COVERAGE}")
add_compile_options(-UNDEBUG -O0 -g)
set(CMAKE_BUILD_TYPE "Debug")
endif()

# coverage options
if($ENV{ENABLE_GCOV_COVERAGE})
string(APPEND CMAKE_CXX_FLAGS_DEBUG " --coverage")
string(APPEND CMAKE_EXE_LINKER_FLAGS_DEBUG " --coverage")
endif()

# SWIG
option(ENABLE_SWIG "Build swig/python library?" ON)
if(ENABLE_SWIG)
Expand Down
2 changes: 1 addition & 1 deletion models/model_calvetti/dwdx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace amici {

namespace model_model_calvetti{

void dwdx_model_calvetti(realtype *dwdx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl, const realtype *spl) {
void dwdx_model_calvetti(realtype *dwdx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl, const realtype *spl, bool include_static) {
dwdx[0] = 1.0/(x[0]*x[0]*x[0])*-2.0;
dwdx[1] = k[1]*w[15]*dwdx[0];
dwdx[2] = dwdx[1];
Expand Down
59 changes: 28 additions & 31 deletions models/model_calvetti/main.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#include <iostream>

#include <amici/amici.h> /* AMICI base functions */
#include "wrapfunctions.h" /* model-provided functions */
#include "wrapfunctions.h" /* model-provided functions */
#include <amici/amici.h> /* AMICI base functions */

template < class T >
std::ostream& operator << (std::ostream& os, const std::vector<T>& v)
{
template <class T>
std::ostream& operator<<(std::ostream& os, std::vector<T> const& v) {
os << "[";
for (typename std::vector<T>::const_iterator ii = v.begin(); ii != v.end(); ++ii)
{
for (typename std::vector<T>::const_iterator ii = v.begin(); ii != v.end();
++ii) {
os << " " << *ii;
}
os << "]";
Expand All @@ -21,9 +20,9 @@ std::ostream& operator << (std::ostream& os, const std::vector<T>& v)
*/

int main() {
std::cout<<"********************************"<<std::endl;
std::cout<<"** Running forward simulation **"<<std::endl;
std::cout<<"********************************"<<std::endl<<std::endl;
std::cout << "********************************" << std::endl;
std::cout << "** Running forward simulation **" << std::endl;
std::cout << "********************************" << std::endl << std::endl;

// Create a model instance
auto model = amici::generic_model::getModel();
Expand All @@ -44,21 +43,20 @@ int main() {

// Print observable time course
auto observable_ids = model->getObservableIds();
std::cout<<"Simulated observables for timepoints "<<rdata->ts<<"\n\n";
for(int i_observable = 0; i_observable < rdata->ny; ++i_observable) {
std::cout<<observable_ids[i_observable]<<":\n\t";
for(int i_time = 0; i_time < rdata->nt; ++i_time) {
std::cout << "Simulated observables for timepoints " << rdata->ts << "\n\n";
for (int i_observable = 0; i_observable < rdata->ny; ++i_observable) {
std::cout << observable_ids[i_observable] << ":\n\t";
for (int i_time = 0; i_time < rdata->nt; ++i_time) {
// rdata->y is a flat 2D array in row-major ordering
std::cout<<rdata->y[i_time * rdata->ny + i_observable]<<" ";
std::cout << rdata->y[i_time * rdata->ny + i_observable] << " ";
}
std::cout<<std::endl<<std::endl;
std::cout << std::endl << std::endl;
}


std::cout<<std::endl;
std::cout<<"**********************************"<<std::endl;
std::cout<<"** Forward sensitivity analysis **"<<std::endl;
std::cout<<"**********************************"<<std::endl<<std::endl;
std::cout << std::endl;
std::cout << "**********************************" << std::endl;
std::cout << "** Forward sensitivity analysis **" << std::endl;
std::cout << "**********************************" << std::endl << std::endl;

// Enable first-order sensitivity analysis
solver->setSensitivityOrder(amici::SensitivityOrder::first);
Expand All @@ -78,18 +76,17 @@ int main() {
auto state_ids = model->getStateIds();
auto parameter_ids = model->getParameterIds();

std::cout<<"State sensitivities for timepoint "
<<rdata->ts[i_time]
<<std::endl;//nt x nplist x nx
for(int i_state= 0; i_state < rdata->nx; ++i_state) {
std::cout<<"\td("<<state_ids[i_state]<<")/d("
<<parameter_ids[model->plist(i_nplist)]<<") = ";
std::cout << "State sensitivities for timepoint " << rdata->ts[i_time]
<< std::endl; // nt x nplist x nx
for (int i_state = 0; i_state < rdata->nx; ++i_state) {
std::cout << "\td(" << state_ids[i_state] << ")/d("
<< parameter_ids[model->plist(i_nplist)] << ") = ";

// rdata->sx is a flat 3D array in row-major ordering
std::cout<<rdata->sx[i_time * rdata->nplist * rdata->nx
+ i_nplist * rdata->nx
+ i_state];
std::cout<<std::endl;
std::cout << rdata->sx
[i_time * rdata->nplist * rdata->nx
+ i_nplist * rdata->nx + i_state];
std::cout << std::endl;
}

return 0;
Expand Down
Loading

0 comments on commit eb3fd4a

Please sign in to comment.