Skip to content

Commit

Permalink
template parameter fixes, reversed order of functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rok-cesnovar committed May 22, 2020
1 parent 94fdd3e commit 1b2b066
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 73 deletions.
104 changes: 52 additions & 52 deletions stan/math/prim/functor/ode_rk45.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,58 +12,6 @@
namespace stan {
namespace math {

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in Boost
* with a relative tolerance of 1e-10, an absolute tolerance of 1e-10, and
* taking a maximum of 1e8 steps.
*
* If the system of equations is stiff, <code>ode_bdf</code> will likely be
* faster.
*
* \p f must define an operator() with the signature as:
* template<typename T_t, typename T_y, typename... T_Args>
* Eigen::Matrix<stan::return_type_t<T_t, T_y, T_Args...>, Eigen::Dynamic, 1>
* operator()(const T_t& t, const Eigen::Matrix<T_y, Eigen::Dynamic, 1>& y,
* std::ostream* msgs, const T_Args&... args);
*
* t is the time, y is the state, msgs is a stream for error messages, and args
* are optional arguments passed to the ODE solve function (which are passed
* through to \p f without modification).
*
* @tparam F Type of ODE right hand side
* @tparam T_0 Type of initial time
* @tparam T_ts Type of output times
* @tparam T_Args Types of pass-through parameters
*
* @param f Right hand side of the ODE
* @param y0 Initial state
* @param t0 Initial time
* @param ts Times at which to solve the ODE at. All values must be sorted and
* not less than t0.
* @param relative_tolerance Relative tolerance passed to CVODES
* @param absolute_tolerance Absolute tolerance passed to CVODES
* @param max_num_steps Upper limit on the number of integration steps to
* take between each output (error if exceeded)
* @param[in, out] msgs the print stream for warning messages
* @param args Extra arguments passed unmodified through to ODE right hand side
* @return Solution to ODE at times \p ts
*/
template <typename F, typename T_initial, typename T_t0, typename T_ts,
typename... Args>
std::vector<
Eigen::Matrix<stan::return_type_t<T_initial, Args...>, Eigen::Dynamic, 1>>
ode_rk45(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
double t0, const std::vector<double>& ts, std::ostream* msgs,
const Args&... args) {
double relative_tolerance = 1e-10;
double absolute_tolerance = 1e-10;
long int max_num_steps = 1e8;

return ode_rk45_tol(f, y0, t0, ts, relative_tolerance, absolute_tolerance,
max_num_steps, msgs, args...);
}

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in Boost
Expand Down Expand Up @@ -192,6 +140,58 @@ ode_rk45_tol(const F& f,
return y;
}

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in Boost
* with a relative tolerance of 1e-10, an absolute tolerance of 1e-10, and
* taking a maximum of 1e8 steps.
*
* If the system of equations is stiff, <code>ode_bdf</code> will likely be
* faster.
*
* \p f must define an operator() with the signature as:
* template<typename T_t, typename T_y, typename... T_Args>
* Eigen::Matrix<stan::return_type_t<T_t, T_y, T_Args...>, Eigen::Dynamic, 1>
* operator()(const T_t& t, const Eigen::Matrix<T_y, Eigen::Dynamic, 1>& y,
* std::ostream* msgs, const T_Args&... args);
*
* t is the time, y is the state, msgs is a stream for error messages, and args
* are optional arguments passed to the ODE solve function (which are passed
* through to \p f without modification).
*
* @tparam F Type of ODE right hand side
* @tparam T_0 Type of initial time
* @tparam T_ts Type of output times
* @tparam T_Args Types of pass-through parameters
*
* @param f Right hand side of the ODE
* @param y0 Initial state
* @param t0 Initial time
* @param ts Times at which to solve the ODE at. All values must be sorted and
* not less than t0.
* @param relative_tolerance Relative tolerance passed to CVODES
* @param absolute_tolerance Absolute tolerance passed to CVODES
* @param max_num_steps Upper limit on the number of integration steps to
* take between each output (error if exceeded)
* @param[in, out] msgs the print stream for warning messages
* @param args Extra arguments passed unmodified through to ODE right hand side
* @return Solution to ODE at times \p ts
*/
template <typename F, typename T_initial, typename T_t0, typename T_ts,
typename... Args>
std::vector<
Eigen::Matrix<stan::return_type_t<T_initial, Args...>, Eigen::Dynamic, 1>>
ode_rk45(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
T_t0 t0, const std::vector<T_ts>& ts, std::ostream* msgs,
const Args&... args) {
double relative_tolerance = 1e-10;
double absolute_tolerance = 1e-10;
long int max_num_steps = 1e8;

return ode_rk45_tol(f, y0, t0, ts, relative_tolerance, absolute_tolerance,
max_num_steps, msgs, args...);
}

} // namespace math
} // namespace stan
#endif
42 changes: 21 additions & 21 deletions stan/math/rev/functor/ode_bdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ namespace math {
/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the stiff backward differentiation formula
* (BDF) solver in CVODES with a relative tolerance of 1e-10, an absolute
* tolerance of 1e-10, and taking a maximum of 1e8 steps.
* BDF solver from CVODES.
*
* \p f must define an operator() with the signature as:
* template<typename T_t, typename T_y, typename... T_Args>
Expand Down Expand Up @@ -47,21 +46,24 @@ template <typename F, typename T_initial, typename T_t0, typename T_ts,
typename... T_Args>
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, T_t0, T_ts, T_Args...>,
Eigen::Dynamic, 1>>
ode_bdf(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
const T_t0& t0, const std::vector<T_ts>& ts, std::ostream* msgs,
const T_Args&... args) {
double relative_tolerance = 1e-10;
double absolute_tolerance = 1e-10;
long int max_num_steps = 1e8;
ode_bdf_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
const T_t0& t0, const std::vector<T_ts>& ts,
double relative_tolerance, double absolute_tolerance,
long int max_num_steps, std::ostream* msgs, const T_Args&... args) {
auto integrator
= new stan::math::cvodes_integrator_vari<CV_BDF, F, T_initial, T_t0, T_ts,
T_Args...>(
f, y0, t0, ts, relative_tolerance, absolute_tolerance, max_num_steps,
msgs, args...);

return ode_bdf_tol(f, y0, t0, ts, relative_tolerance, absolute_tolerance,
max_num_steps, msgs, args...);
return (*integrator)();
}

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the stiff backward differentiation formula
* BDF solver from CVODES.
* (BDF) solver in CVODES with a relative tolerance of 1e-10, an absolute
* tolerance of 1e-10, and taking a maximum of 1e8 steps.
*
* \p f must define an operator() with the signature as:
* template<typename T_t, typename T_y, typename... T_Args>
Expand Down Expand Up @@ -95,17 +97,15 @@ template <typename F, typename T_initial, typename T_t0, typename T_ts,
typename... T_Args>
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, T_t0, T_ts, T_Args...>,
Eigen::Dynamic, 1>>
ode_bdf_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
const T_t0& t0, const std::vector<T_ts>& ts,
double relative_tolerance, double absolute_tolerance,
long int max_num_steps, std::ostream* msgs, const T_Args&... args) {
auto integrator
= new stan::math::cvodes_integrator_vari<CV_BDF, F, T_initial, T_t0, T_ts,
T_Args...>(
f, y0, t0, ts, relative_tolerance, absolute_tolerance, max_num_steps,
msgs, args...);
ode_bdf(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
const T_t0& t0, const std::vector<T_ts>& ts, std::ostream* msgs,
const T_Args&... args) {
double relative_tolerance = 1e-10;
double absolute_tolerance = 1e-10;
long int max_num_steps = 1e8;

return (*integrator)();
return ode_bdf_tol(f, y0, t0, ts, relative_tolerance, absolute_tolerance,
max_num_steps, msgs, args...);
}

} // namespace math
Expand Down

0 comments on commit 1b2b066

Please sign in to comment.