Skip to content

Commit

Permalink
Merge pull request #1905 from stan-dev/feature/adjoint-odes
Browse files Browse the repository at this point in the history
Feature/adjoint odes implements  #2486
  • Loading branch information
wds15 committed May 17, 2021
2 parents 1c31a77 + 01c9fa1 commit 5534e22
Show file tree
Hide file tree
Showing 23 changed files with 1,536 additions and 78 deletions.
2 changes: 1 addition & 1 deletion make/tests
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ $(EXPRESSION_TESTS) : $(LIBSUNDIALS)
# CVODES tests
##

CVODES_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*cvodes*_test.cpp) $(call findfiles,test,*_bdf_*_test.cpp) $(call findfiles,test,*_adams_*_test.cpp) $(call findfiles,test,*_ode_typed_*test.cpp))
CVODES_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*cvodes*_test.cpp) $(call findfiles,test,*_bdf_*_test.cpp) $(call findfiles,test,*_adams_*_test.cpp) $(call findfiles,test,*_ode_typed_*test.cpp) $(call findfiles,test,*_ode_adjoint_typed_*test.cpp))
$(CVODES_TESTS) : $(LIBSUNDIALS)


Expand Down
44 changes: 12 additions & 32 deletions stan/math/rev/core/zero_adjoints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,10 @@
namespace stan {
namespace math {

inline void zero_adjoints();

template <typename T, typename... Pargs, require_st_arithmetic<T>* = nullptr>
inline void zero_adjoints(T& x, Pargs&... args);

template <typename... Pargs>
inline void zero_adjoints(var& x, Pargs&... args);

template <int R, int C, typename... Pargs>
inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args);

template <typename T, typename... Pargs, require_st_autodiff<T>* = nullptr>
inline void zero_adjoints(std::vector<T>& x, Pargs&... args);

/**
* End of recursion for set_zero_adjoints
*/
inline void zero_adjoints() {}
inline void zero_adjoints() noexcept {}

/**
* Do nothing for non-autodiff arguments. Recursively call zero_adjoints
Expand All @@ -37,10 +23,8 @@ inline void zero_adjoints() {}
* @param x current argument
* @param args rest of arguments to zero
*/
template <typename T, typename... Pargs, require_st_arithmetic<T>*>
inline void zero_adjoints(T& x, Pargs&... args) {
zero_adjoints(args...);
}
template <typename T, require_st_arithmetic<T>* = nullptr>
inline void zero_adjoints(T& x) noexcept {}

/**
* Zero the adjoint of the vari in the first argument. Recursively call
Expand All @@ -52,11 +36,7 @@ inline void zero_adjoints(T& x, Pargs&... args) {
* @param x current argument
* @param args rest of arguments to zero
*/
template <typename... Pargs>
inline void zero_adjoints(var& x, Pargs&... args) {
x.vi_->set_zero_adjoint();
zero_adjoints(args...);
}
inline void zero_adjoints(var& x) { x.adj() = 0; }

/**
* Zero the adjoints of the varis of every var in an Eigen::Matrix
Expand All @@ -68,11 +48,10 @@ inline void zero_adjoints(var& x, Pargs&... args) {
* @param x current argument
* @param args rest of arguments to zero
*/
template <int R, int C, typename... Pargs>
inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args) {
template <typename EigMat, require_eigen_vt<is_autodiff, EigMat>* = nullptr>
inline void zero_adjoints(EigMat& x) {
for (size_t i = 0; i < x.size(); ++i)
x.coeffRef(i).vi_->set_zero_adjoint();
zero_adjoints(args...);
x.coeffRef(i).adj() = 0;
}

/**
Expand All @@ -85,11 +64,12 @@ inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args) {
* @param x current argument
* @param args rest of arguments to zero
*/
template <typename T, typename... Pargs, require_st_autodiff<T>*>
inline void zero_adjoints(std::vector<T>& x, Pargs&... args) {
for (size_t i = 0; i < x.size(); ++i)
template <typename StdVec,
require_std_vector_st<is_autodiff, StdVec>* = nullptr>
inline void zero_adjoints(StdVec& x) {
for (size_t i = 0; i < x.size(); ++i) {
zero_adjoints(x[i]);
zero_adjoints(args...);
}
}

} // namespace math
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <stan/math/rev/functor/integrate_ode_bdf.hpp>
#include <stan/math/rev/functor/ode_adams.hpp>
#include <stan/math/rev/functor/ode_bdf.hpp>
#include <stan/math/rev/functor/ode_adjoint.hpp>
#include <stan/math/rev/functor/ode_store_sensitivities.hpp>
#include <stan/math/rev/functor/jacobian.hpp>
#include <stan/math/rev/functor/kinsol_data.hpp>
Expand Down
10 changes: 7 additions & 3 deletions stan/math/rev/functor/coupled_ode_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/rev/fun/value_of.hpp>
#include <stan/math/prim/functor/for_each.hpp>
#include <stan/math/prim/err.hpp>
#include <stdexcept>
#include <ostream>
Expand Down Expand Up @@ -137,8 +138,10 @@ struct coupled_ode_system_impl<false, F, T_y0, Args...> {

y_adjoints_ = y_vars.adj();

// memset was faster than Eigen setZero
memset(args_adjoints_.data(), 0, sizeof(double) * num_args_vars);
if (args_adjoints_.size() > 0) {
memset(args_adjoints_.data(), 0,
sizeof(double) * args_adjoints_.size());
}

apply(
[&](auto&&... args) {
Expand All @@ -148,7 +151,8 @@ struct coupled_ode_system_impl<false, F, T_y0, Args...> {

// The vars here do not live on the nested stack so must be zero'd
// separately
apply([&](auto&&... args) { zero_adjoints(args...); }, local_args_tuple_);
stan::math::for_each([](auto&& arg) { zero_adjoints(arg); },
local_args_tuple_);

// No need to zero adjoints after last sweep
if (i + 1 < N_) {
Expand Down
7 changes: 5 additions & 2 deletions stan/math/rev/functor/cvodes_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ class cvodes_integrator {
CVodeSetUserData(cvodes_mem, reinterpret_cast<void*>(this)),
"CVodeSetUserData");

cvodes_set_options(cvodes_mem, relative_tolerance_, absolute_tolerance_,
max_num_steps_);
cvodes_set_options(cvodes_mem, max_num_steps_);

check_flag_sundials(CVodeSStolerances(cvodes_mem, relative_tolerance_,
absolute_tolerance_),
"CVodeSStolerances");

check_flag_sundials(CVodeSetLinearSolver(cvodes_mem, LS_, A_),
"CVodeSetLinearSolver");
Expand Down
Loading

0 comments on commit 5534e22

Please sign in to comment.