diff --git a/make/tests b/make/tests index 30b195a3870..6a90618d62f 100644 --- a/make/tests +++ b/make/tests @@ -64,7 +64,7 @@ $(EXPRESSION_TESTS) : $(LIBSUNDIALS) # CVODES tests ## -CVODES_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*cvodes*_test.cpp) $(call findfiles,test,*_bdf_*_test.cpp) $(call findfiles,test,*_adams_*_test.cpp) $(call findfiles,test,*_ode_typed_*test.cpp)) +CVODES_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*cvodes*_test.cpp) $(call findfiles,test,*_bdf_*_test.cpp) $(call findfiles,test,*_adams_*_test.cpp) $(call findfiles,test,*_ode_typed_*test.cpp) $(call findfiles,test,*_ode_adjoint_typed_*test.cpp)) $(CVODES_TESTS) : $(LIBSUNDIALS) diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index 36368d443ee..a05362fc79d 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -8,24 +8,10 @@ namespace stan { namespace math { -inline void zero_adjoints(); - -template * = nullptr> -inline void zero_adjoints(T& x, Pargs&... args); - -template -inline void zero_adjoints(var& x, Pargs&... args); - -template -inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args); - -template * = nullptr> -inline void zero_adjoints(std::vector& x, Pargs&... args); - /** * End of recursion for set_zero_adjoints */ -inline void zero_adjoints() {} +inline void zero_adjoints() noexcept {} /** * Do nothing for non-autodiff arguments. Recursively call zero_adjoints @@ -37,10 +23,8 @@ inline void zero_adjoints() {} * @param x current argument * @param args rest of arguments to zero */ -template *> -inline void zero_adjoints(T& x, Pargs&... args) { - zero_adjoints(args...); -} +template * = nullptr> +inline void zero_adjoints(T& x) noexcept {} /** * Zero the adjoint of the vari in the first argument. Recursively call @@ -52,11 +36,7 @@ inline void zero_adjoints(T& x, Pargs&... args) { * @param x current argument * @param args rest of arguments to zero */ -template -inline void zero_adjoints(var& x, Pargs&... args) { - x.vi_->set_zero_adjoint(); - zero_adjoints(args...); -} +inline void zero_adjoints(var& x) { x.adj() = 0; } /** * Zero the adjoints of the varis of every var in an Eigen::Matrix @@ -68,11 +48,10 @@ inline void zero_adjoints(var& x, Pargs&... args) { * @param x current argument * @param args rest of arguments to zero */ -template -inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args) { +template * = nullptr> +inline void zero_adjoints(EigMat& x) { for (size_t i = 0; i < x.size(); ++i) - x.coeffRef(i).vi_->set_zero_adjoint(); - zero_adjoints(args...); + x.coeffRef(i).adj() = 0; } /** @@ -85,11 +64,12 @@ inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args) { * @param x current argument * @param args rest of arguments to zero */ -template *> -inline void zero_adjoints(std::vector& x, Pargs&... args) { - for (size_t i = 0; i < x.size(); ++i) +template * = nullptr> +inline void zero_adjoints(StdVec& x) { + for (size_t i = 0; i < x.size(); ++i) { zero_adjoints(x[i]); - zero_adjoints(args...); + } } } // namespace math diff --git a/stan/math/rev/functor.hpp b/stan/math/rev/functor.hpp index 265d4b3a364..4e36d2e7b14 100644 --- a/stan/math/rev/functor.hpp +++ b/stan/math/rev/functor.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/functor/coupled_ode_system.hpp b/stan/math/rev/functor/coupled_ode_system.hpp index 83783d4d03f..1f043fe0e9a 100644 --- a/stan/math/rev/functor/coupled_ode_system.hpp +++ b/stan/math/rev/functor/coupled_ode_system.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -137,8 +138,10 @@ struct coupled_ode_system_impl { y_adjoints_ = y_vars.adj(); - // memset was faster than Eigen setZero - memset(args_adjoints_.data(), 0, sizeof(double) * num_args_vars); + if (args_adjoints_.size() > 0) { + memset(args_adjoints_.data(), 0, + sizeof(double) * args_adjoints_.size()); + } apply( [&](auto&&... args) { @@ -148,7 +151,8 @@ struct coupled_ode_system_impl { // The vars here do not live on the nested stack so must be zero'd // separately - apply([&](auto&&... args) { zero_adjoints(args...); }, local_args_tuple_); + stan::math::for_each([](auto&& arg) { zero_adjoints(arg); }, + local_args_tuple_); // No need to zero adjoints after last sweep if (i + 1 < N_) { diff --git a/stan/math/rev/functor/cvodes_integrator.hpp b/stan/math/rev/functor/cvodes_integrator.hpp index 4e40354540b..554b77e1f54 100644 --- a/stan/math/rev/functor/cvodes_integrator.hpp +++ b/stan/math/rev/functor/cvodes_integrator.hpp @@ -279,8 +279,11 @@ class cvodes_integrator { CVodeSetUserData(cvodes_mem, reinterpret_cast(this)), "CVodeSetUserData"); - cvodes_set_options(cvodes_mem, relative_tolerance_, absolute_tolerance_, - max_num_steps_); + cvodes_set_options(cvodes_mem, max_num_steps_); + + check_flag_sundials(CVodeSStolerances(cvodes_mem, relative_tolerance_, + absolute_tolerance_), + "CVodeSStolerances"); check_flag_sundials(CVodeSetLinearSolver(cvodes_mem, LS_, A_), "CVodeSetLinearSolver"); diff --git a/stan/math/rev/functor/cvodes_integrator_adjoint.hpp b/stan/math/rev/functor/cvodes_integrator_adjoint.hpp new file mode 100644 index 00000000000..d470c39f495 --- /dev/null +++ b/stan/math/rev/functor/cvodes_integrator_adjoint.hpp @@ -0,0 +1,764 @@ +#ifndef STAN_MATH_REV_FUNCTOR_CVODES_INTEGRATOR_ADJOINT_HPP +#define STAN_MATH_REV_FUNCTOR_CVODES_INTEGRATOR_ADJOINT_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Integrator interface for CVODES' adjoint ODE solvers (Adams & BDF + * methods). + * + * @tparam F Type of ODE right hand side + * @tparam T_y0 Type of scalars for initial state + * @tparam T_t0 Type of initial time + * @tparam T_ts Type of time-points where ODE solution is returned + * @tparam T_Args Types of pass-through parameters + */ +template +class cvodes_integrator_adjoint_vari : public vari_base { + using T_Return = return_type_t; + using T_y0_t0 = return_type_t; + + static constexpr bool is_var_ts_{is_var::value}; + static constexpr bool is_var_t0_{is_var::value}; + static constexpr bool is_var_y0_{is_var::value}; + static constexpr bool is_var_y0_t0_{is_var::value}; + static constexpr bool is_any_var_args_{ + disjunction>...>::value}; + static constexpr bool is_var_return_{is_var::value}; + static constexpr bool is_var_only_ts_{ + is_var_ts_ && !(is_var_t0_ || is_var_y0_t0_ || is_any_var_args_)}; + + arena_t> y_; + arena_t> ts_; + arena_t> y0_; + arena_t absolute_tolerance_forward_; + arena_t absolute_tolerance_backward_; + arena_t state_forward_; + arena_t state_backward_; + size_t num_args_vars_; + arena_t quad_; + arena_t t0_; + + double relative_tolerance_forward_; + double relative_tolerance_backward_; + double relative_tolerance_quadrature_; + double absolute_tolerance_quadrature_; + long int max_num_steps_; // NOLINT(runtime/int) + long int num_steps_between_checkpoints_; // NOLINT(runtime/int) + size_t N_; + std::ostream* msgs_; + vari** args_varis_; + int interpolation_polynomial_; + int solver_forward_; + int solver_backward_; + int index_backward_; + bool backward_is_initialized_{false}; + + /** + * Since the CVODES solver manages memory with malloc calls, these resources + * must be freed using a destructor call (which is not being called for the + * vari class). + */ + struct cvodes_solver : public chainable_alloc { + const std::string function_name_str_; + const std::decay_t f_; + const size_t N_; + N_Vector nv_state_forward_; + N_Vector nv_state_backward_; + N_Vector nv_quad_; + N_Vector nv_absolute_tolerance_forward_; + N_Vector nv_absolute_tolerance_backward_; + SUNMatrix A_forward_; + SUNLinearSolver LS_forward_; + SUNMatrix A_backward_; + SUNLinearSolver LS_backward_; + void* cvodes_mem_; + std::vector> y_return_; + std::tuple local_args_tuple_; + const std::tuple< + promote_scalar_t>, T_Args>...> + value_of_args_tuple_; + + template + cvodes_solver(const char* function_name, FF&& f, size_t N, + size_t num_args_vars, size_t ts_size, int solver_forward, + StateFwd& state_forward, StateBwd& state_backward, Quad& quad, + AbsTolFwd& absolute_tolerance_forward, + AbsTolBwd& absolute_tolerance_backward, const T_Args&... args) + : chainable_alloc(), + f_(std::forward(f)), + function_name_str_(function_name), + y_return_(ts_size), + nv_state_forward_(N_VMake_Serial(N, state_forward.data())), + nv_state_backward_(N_VMake_Serial(N, state_backward.data())), + nv_quad_(N_VMake_Serial(num_args_vars, quad.data())), + nv_absolute_tolerance_forward_( + N_VMake_Serial(N, absolute_tolerance_forward.data())), + nv_absolute_tolerance_backward_( + N_VMake_Serial(N, absolute_tolerance_backward.data())), + A_forward_(SUNDenseMatrix(N, N)), + A_backward_(SUNDenseMatrix(N, N)), + LS_forward_( + N == 0 ? nullptr + : SUNDenseLinearSolver(nv_state_forward_, A_forward_)), + LS_backward_( + N == 0 ? nullptr + : SUNDenseLinearSolver(nv_state_backward_, A_backward_)), + N_(N), + cvodes_mem_(CVodeCreate(solver_forward)), + local_args_tuple_(deep_copy_vars(args)...), + value_of_args_tuple_(value_of(args)...) { + if (cvodes_mem_ == nullptr) { + throw std::runtime_error("CVodeCreate failed to allocate memory"); + } + } + + virtual ~cvodes_solver() { + SUNMatDestroy(A_forward_); + SUNMatDestroy(A_backward_); + if (N_ > 0) { + SUNLinSolFree(LS_forward_); + SUNLinSolFree(LS_backward_); + } + N_VDestroy_Serial(nv_state_forward_); + N_VDestroy_Serial(nv_state_backward_); + N_VDestroy_Serial(nv_quad_); + N_VDestroy_Serial(nv_absolute_tolerance_forward_); + N_VDestroy_Serial(nv_absolute_tolerance_backward_); + + CVodeFree(&cvodes_mem_); + } + }; + cvodes_solver* solver_{nullptr}; + + public: + /** + * Construct cvodes_integrator object. Note: All arguments must be stored as + * copies if in doubt. The reason is that the references can go out of scope, + * since the work done from the integrator is in the chain method. + * + * @param function_name Calling function name (for printing debugging + * messages) + * @param f Right hand side of the ODE + * @param y0 Initial state + * @param t0 Initial time + * @param ts Times at which to solve the ODE at. All values must be sorted and + * not less than t0. + * @param relative_tolerance_forward Relative tolerance for forward problem + * passed to CVODES + * @param absolute_tolerance_forward Absolute tolerance per ODE state for + * forward problem passed to CVODES + * @param relative_tolerance_backward Relative tolerance for backward problem + * passed to CVODES + * @param absolute_tolerance_backward Absolute tolerance per ODE state for + * backward problem passed to CVODES + * @param relative_tolerance_quadrature Relative tolerance for quadrature + * problem passed to CVODES + * @param absolute_tolerance_quadrature Absolute tolerance for quadrature + * problem passed to CVODES + * @param max_num_steps Upper limit on the number of integration steps to + * take between each output (error if exceeded) + * @param num_steps_between_checkpoints Number of integrator steps after which + * a checkpoint is stored for the backward pass + * @param interpolation_polynomial type of polynomial used for interpolation + * @param solver_forward solver used for forward pass + * @param solver_backward solver used for backward pass + * take between each output (error if exceeded) + * @param[in, out] msgs the print stream for warning messages + * @param args Extra arguments passed unmodified through to ODE right hand + * side function + * @return Solution to ODE at times \p ts + * @return a vector of states, each state being a vector of the + * same size as the state variable, corresponding to a time in ts. + */ + template * = nullptr> + cvodes_integrator_adjoint_vari( + const char* function_name, FF&& f, const T_y0& y0, const T_t0& t0, + const std::vector& ts, double relative_tolerance_forward, + const Eigen::VectorXd& absolute_tolerance_forward, + double relative_tolerance_backward, + const Eigen::VectorXd& absolute_tolerance_backward, + double relative_tolerance_quadrature, + double absolute_tolerance_quadrature, + long int max_num_steps, // NOLINT(runtime/int) + long int num_steps_between_checkpoints, // NOLINT(runtime/int) + int interpolation_polynomial, int solver_forward, int solver_backward, + std::ostream* msgs, const T_Args&... args) + : vari_base(), + y_(ts.size()), + ts_(ts.begin(), ts.end()), + y0_(y0), + absolute_tolerance_forward_(absolute_tolerance_forward), + absolute_tolerance_backward_(absolute_tolerance_backward), + state_forward_(value_of(y0)), + state_backward_(y0.size()), + num_args_vars_(count_vars(args...)), + quad_(num_args_vars_), + t0_(t0), + relative_tolerance_forward_(relative_tolerance_forward), + relative_tolerance_backward_(relative_tolerance_backward), + relative_tolerance_quadrature_(relative_tolerance_quadrature), + absolute_tolerance_quadrature_(absolute_tolerance_quadrature), + max_num_steps_(max_num_steps), + num_steps_between_checkpoints_(num_steps_between_checkpoints), + N_(y0.size()), + msgs_(msgs), + args_varis_([&args..., num_vars = this->num_args_vars_]() { + vari** vari_mem + = ChainableStack::instance_->memalloc_.alloc_array( + num_vars); + save_varis(vari_mem, args...); + return vari_mem; + }()), + interpolation_polynomial_(interpolation_polynomial), + solver_forward_(solver_forward), + solver_backward_(solver_backward), + backward_is_initialized_(false), + solver_(nullptr) { + check_finite(function_name, "initial state", y0); + check_finite(function_name, "initial time", t0); + check_finite(function_name, "times", ts); + + check_nonzero_size(function_name, "times", ts); + check_nonzero_size(function_name, "initial state", y0); + check_sorted(function_name, "times", ts); + check_less(function_name, "initial time", t0, ts[0]); + check_positive_finite(function_name, "relative_tolerance_forward", + relative_tolerance_forward_); + check_positive_finite(function_name, "absolute_tolerance_forward", + absolute_tolerance_forward_); + check_size_match(function_name, "absolute_tolerance_forward", + absolute_tolerance_forward_.size(), "states", N_); + check_positive_finite(function_name, "relative_tolerance_backward", + relative_tolerance_backward_); + check_positive_finite(function_name, "absolute_tolerance_backward", + absolute_tolerance_backward_); + check_size_match(function_name, "absolute_tolerance_backward", + absolute_tolerance_backward_.size(), "states", N_); + check_positive_finite(function_name, "relative_tolerance_quadrature", + relative_tolerance_quadrature_); + check_positive_finite(function_name, "absolute_tolerance_quadrature", + absolute_tolerance_quadrature_); + check_positive(function_name, "max_num_steps", max_num_steps_); + check_positive(function_name, "num_steps_between_checkpoints", + num_steps_between_checkpoints_); + // for polynomial: 1=CV_HERMITE / 2=CV_POLYNOMIAL + if (interpolation_polynomial_ != 1 && interpolation_polynomial_ != 2) + invalid_argument(function_name, "interpolation_polynomial", + interpolation_polynomial_, "", + ", must be 1 for Hermite or 2 for polynomial " + "interpolation of ODE solution"); + // 1=Adams=CV_ADAMS, 2=BDF=CV_BDF + if (solver_forward_ != 1 && solver_forward_ != 2) + invalid_argument(function_name, "solver_forward", solver_forward_, "", + ", must be 1 for Adams or 2 for BDF forward solver"); + if (solver_backward_ != 1 && solver_backward_ != 2) + invalid_argument(function_name, "solver_backward", solver_backward_, "", + ", must be 1 for Adams or 2 for BDF backward solver"); + + solver_ = new cvodes_solver( + function_name, f, N_, num_args_vars_, ts_.size(), solver_forward_, + state_forward_, state_backward_, quad_, absolute_tolerance_forward_, + absolute_tolerance_backward_, args...); + + stan::math::for_each( + [func_name = function_name](auto&& arg) { + check_finite(func_name, "ode parameters and data", arg); + }, + solver_->local_args_tuple_); + + check_flag_sundials( + CVodeInit(solver_->cvodes_mem_, &cvodes_integrator_adjoint_vari::cv_rhs, + value_of(t0_), solver_->nv_state_forward_), + "CVodeInit"); + + // Assign pointer to this as user data + check_flag_sundials( + CVodeSetUserData(solver_->cvodes_mem_, reinterpret_cast(this)), + "CVodeSetUserData"); + + cvodes_set_options(solver_->cvodes_mem_, max_num_steps_); + + check_flag_sundials( + CVodeSVtolerances(solver_->cvodes_mem_, relative_tolerance_forward_, + solver_->nv_absolute_tolerance_forward_), + "CVodeSVtolerances"); + + check_flag_sundials( + CVodeSetLinearSolver(solver_->cvodes_mem_, solver_->LS_forward_, + solver_->A_forward_), + "CVodeSetLinearSolver"); + + check_flag_sundials( + CVodeSetJacFn(solver_->cvodes_mem_, + &cvodes_integrator_adjoint_vari::cv_jacobian_rhs_states), + "CVodeSetJacFn"); + + // initialize backward sensitivity system of CVODES as needed + if (is_var_return_ && !is_var_only_ts_) { + check_flag_sundials( + CVodeAdjInit(solver_->cvodes_mem_, num_steps_between_checkpoints_, + interpolation_polynomial_), + "CVodeAdjInit"); + } + + /** + * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of + * times, { t1, t2, t3, ... } using the requested forward solver of CVODES. + */ + const auto ts_dbl = value_of(ts_); + + double t_init = value_of(t0_); + for (size_t n = 0; n < ts_dbl.size(); ++n) { + double t_final = ts_dbl[n]; + if (t_final != t_init) { + if (is_var_return_ && !is_var_only_ts_) { + int ncheck; + + int error_code + = CVodeF(solver_->cvodes_mem_, t_final, + solver_->nv_state_forward_, &t_init, CV_NORMAL, &ncheck); + + if (unlikely(error_code == CV_TOO_MUCH_WORK)) { + throw_domain_error(solver_->function_name_str_.c_str(), "", t_final, + "Failed to integrate to next output time (", + ") in less than max_num_steps steps"); + } else { + check_flag_sundials(error_code, "CVodeF"); + } + } else { + int error_code + = CVode(solver_->cvodes_mem_, t_final, solver_->nv_state_forward_, + &t_init, CV_NORMAL); + + if (unlikely(error_code == CV_TOO_MUCH_WORK)) { + throw_domain_error(solver_->function_name_str_.c_str(), "", t_final, + "Failed to integrate to next output time (", + ") in less than max_num_steps steps"); + } else { + check_flag_sundials(error_code, "CVode"); + } + } + } + store_state(n, state_forward_, solver_->y_return_[n]); + + t_init = t_final; + } + ChainableStack::instance_->var_stack_.push_back(this); + } + + private: + /** + * Overloads which setup the states returned from the forward solve. In case + * the return type is a double only, then no autodiff is needed. In case of + * autodiff then non-chaining varis are setup accordingly. + */ + void store_state(std::size_t n, const Eigen::VectorXd& state, + Eigen::Matrix& state_return) { + y_[n] = state; + state_return.resize(N_); + for (size_t i = 0; i < N_; i++) { + state_return.coeffRef(i) = var(new vari(state.coeff(i), false)); + } + } + + void store_state(std::size_t n, const Eigen::VectorXd& state, + Eigen::Matrix& state_return) { + y_[n] = state; + state_return = state; + } + + public: + /** + * Obtain solution of ODE. + * + * @return std::vector of Eigen::Matrix of the states of the ODE, one for each + * solution time (excluding the initial state) + */ + std::vector> solution() noexcept { + return solver_->y_return_; + } + + /** + * No-op for setting adjoints since this class does not own any adjoints. + */ + void set_zero_adjoint() final{}; + + void chain() final { + if (!is_var_return_) { + return; + } + + // for sensitivities wrt to ts we do not need to run the backward + // integration + if (is_var_ts_) { + Eigen::VectorXd step_sens = Eigen::VectorXd::Zero(N_); + for (int i = 0; i < ts_.size(); ++i) { + for (int j = 0; j < N_; ++j) { + step_sens.coeffRef(j) + += forward_as(solver_->y_return_[i].coeff(j)).adj(); + } + + adjoint_of(ts_[i]) += step_sens.dot( + rhs(value_of(ts_[i]), y_[i], solver_->value_of_args_tuple_)); + step_sens.setZero(); + } + + if (is_var_only_ts_) { + return; + } + } + + state_backward_.setZero(); + quad_.setZero(); + + // At every time step, collect the adjoints from the output + // variables and re-initialize the solver + double t_init = value_of(ts_.back()); + for (int i = ts_.size() - 1; i >= 0; --i) { + // Take in the adjoints from all the output variables at this point + // in time + for (int j = 0; j < N_; ++j) { + state_backward_.coeffRef(j) + += forward_as(solver_->y_return_[i].coeff(j)).adj(); + } + + double t_final = value_of((i > 0) ? ts_[i - 1] : t0_); + if (t_final != t_init) { + if (unlikely(!backward_is_initialized_)) { + check_flag_sundials(CVodeCreateB(solver_->cvodes_mem_, + solver_backward_, &index_backward_), + "CVodeCreateB"); + + check_flag_sundials( + CVodeSetUserDataB(solver_->cvodes_mem_, index_backward_, + reinterpret_cast(this)), + "CVodeSetUserDataB"); + + // initialize CVODES backward machinery. + // the states of the backward problem *are* the adjoints + // of the ode states + check_flag_sundials( + CVodeInitB(solver_->cvodes_mem_, index_backward_, + &cvodes_integrator_adjoint_vari::cv_rhs_adj, t_init, + solver_->nv_state_backward_), + "CVodeInitB"); + + check_flag_sundials( + CVodeSVtolerancesB(solver_->cvodes_mem_, index_backward_, + relative_tolerance_backward_, + solver_->nv_absolute_tolerance_backward_), + "CVodeSVtolerancesB"); + + check_flag_sundials( + CVodeSetMaxNumStepsB(solver_->cvodes_mem_, index_backward_, + max_num_steps_), + "CVodeSetMaxNumStepsB"); + + check_flag_sundials(CVodeSetLinearSolverB( + solver_->cvodes_mem_, index_backward_, + solver_->LS_backward_, solver_->A_backward_), + "CVodeSetLinearSolverB"); + + check_flag_sundials( + CVodeSetJacFnB( + solver_->cvodes_mem_, index_backward_, + &cvodes_integrator_adjoint_vari::cv_jacobian_rhs_adj_states), + "CVodeSetJacFnB"); + + // Allocate space for backwards quadrature needed when + // parameters vary. + if (is_any_var_args_) { + check_flag_sundials( + CVodeQuadInitB(solver_->cvodes_mem_, index_backward_, + &cvodes_integrator_adjoint_vari::cv_quad_rhs_adj, + solver_->nv_quad_), + "CVodeQuadInitB"); + + check_flag_sundials( + CVodeQuadSStolerancesB(solver_->cvodes_mem_, index_backward_, + relative_tolerance_quadrature_, + absolute_tolerance_quadrature_), + "CVodeQuadSStolerancesB"); + + check_flag_sundials(CVodeSetQuadErrConB(solver_->cvodes_mem_, + index_backward_, SUNTRUE), + "CVodeSetQuadErrConB"); + } + + backward_is_initialized_ = true; + } else { + // just re-initialize the solver + check_flag_sundials( + CVodeReInitB(solver_->cvodes_mem_, index_backward_, t_init, + solver_->nv_state_backward_), + "CVodeReInitB"); + + if (is_any_var_args_) { + check_flag_sundials( + CVodeQuadReInitB(solver_->cvodes_mem_, index_backward_, + solver_->nv_quad_), + "CVodeQuadReInitB"); + } + } + + int error_code = CVodeB(solver_->cvodes_mem_, t_final, CV_NORMAL); + + if (unlikely(error_code == CV_TOO_MUCH_WORK)) { + throw_domain_error(solver_->function_name_str_.c_str(), "", t_final, + "Failed to integrate backward to output time (", + ") in less than max_num_steps steps"); + } else { + check_flag_sundials(error_code, "CVodeB"); + } + + // obtain adjoint states and update t_init to time point + // reached of t_final + check_flag_sundials(CVodeGetB(solver_->cvodes_mem_, index_backward_, + &t_init, solver_->nv_state_backward_), + "CVodeGetB"); + + if (is_any_var_args_) { + check_flag_sundials( + CVodeGetQuadB(solver_->cvodes_mem_, index_backward_, &t_init, + solver_->nv_quad_), + "CVodeGetQuadB"); + } + } + } + + if (is_var_t0_) { + adjoint_of(t0_) += -state_backward_.dot( + rhs(t_init, value_of(y0_), solver_->value_of_args_tuple_)); + } + + // After integrating all the way back to t0, we finally have the + // the adjoints we wanted + // These are the dlog_density / d(initial_conditions[s]) adjoints + if (is_var_y0_t0_) { + forward_as>>(y0_).adj() + += state_backward_; + } + + // These are the dlog_density / d(parameters[s]) adjoints + if (is_any_var_args_) { + for (size_t s = 0; s < num_args_vars_; ++s) { + args_varis_[s]->adj_ += quad_.coeff(s); + } + } + } + + private: + /** + * Call the ODE RHS with given tuple. + */ + template + constexpr auto rhs(double t, const yT& y, + const std::tuple& args_tuple) const { + return apply( + [&](auto&&... args) { return solver_->f_(t, y, msgs_, args...); }, + args_tuple); + } + + /** + * Utility to cast user memory pointer passed in from CVODES to actual typed + * object pointer. + */ + constexpr static cvodes_integrator_adjoint_vari* cast_to_self(void* mem) { + return static_cast(mem); + } + + /** + * Calculates the ODE RHS, dy_dt, using the user-supplied functor at + * the given time t and state y. + */ + inline int rhs(double t, const double* y, double*& dy_dt) const { + const Eigen::VectorXd y_vec = Eigen::Map(y, N_); + const Eigen::VectorXd dy_dt_vec + = rhs(t, y_vec, solver_->value_of_args_tuple_); + check_size_match(solver_->function_name_str_.c_str(), "dy_dt", + dy_dt_vec.size(), "states", N_); + Eigen::Map(dy_dt, N_) = dy_dt_vec; + return 0; + } + + /** + * Implements the function of type CVRhsFn which is the user-defined + * ODE RHS passed to CVODES. + */ + constexpr static int cv_rhs(realtype t, N_Vector y, N_Vector ydot, + void* user_data) { + return cast_to_self(user_data)->rhs(t, NV_DATA_S(y), NV_DATA_S(ydot)); + } + + /* + * Calculate the adjoint sensitivity RHS for varying initial conditions + * and parameters + * + * Equation 2.23 in the cvs_guide. + * + * @param[in] t time + * @param[in] y state of the base ODE system + * @param[in] yB state of the adjoint ODE system + * @param[out] yBdot evaluation of adjoint ODE RHS + */ + inline int rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector yBdot) const { + const nested_rev_autodiff nested; + + Eigen::Matrix y_vars( + Eigen::Map(NV_DATA_S(y), N_)); + Eigen::Matrix f_y_t_vars + = rhs(t, y_vars, solver_->value_of_args_tuple_); + check_size_match(solver_->function_name_str_.c_str(), "dy_dt", + f_y_t_vars.size(), "states", N_); + f_y_t_vars.adj() = -Eigen::Map(NV_DATA_S(yB), N_); + grad(); + Eigen::Map(NV_DATA_S(yBdot), N_) = y_vars.adj(); + return 0; + } + + /** + * Implements the function of type CVRhsFnB which is the + * RHS of the backward ODE system. + */ + constexpr static int cv_rhs_adj(realtype t, N_Vector y, N_Vector yB, + N_Vector yBdot, void* user_data) { + return cast_to_self(user_data)->rhs_adj(t, y, yB, yBdot); + } + + /* + * Calculate the RHS for the quadrature part of the adjoint ODE + * problem. + * + * This is the integrand of equation 2.22 in the cvs_guide. + * + * @param[in] t time + * @param[in] y state of the base ODE system + * @param[in] yB state of the adjoint ODE system + * @param[out] qBdot evaluation of adjoint ODE quadrature RHS + */ + inline int quad_rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector qBdot) { + Eigen::Map y_vec(NV_DATA_S(y), N_); + const nested_rev_autodiff nested; + + // The vars here do not live on the nested stack so must be zero'd + // separately + stan::math::for_each([](auto&& arg) { zero_adjoints(arg); }, + solver_->local_args_tuple_); + Eigen::Matrix f_y_t_vars + = rhs(t, y_vec, solver_->local_args_tuple_); + check_size_match(solver_->function_name_str_.c_str(), "dy_dt", + f_y_t_vars.size(), "states", N_); + f_y_t_vars.adj() = -Eigen::Map(NV_DATA_S(yB), N_); + grad(); + apply( + [&qBdot](auto&&... args) { + accumulate_adjoints(NV_DATA_S(qBdot), args...); + }, + solver_->local_args_tuple_); + return 0; + } + + /** + * Implements the function of type CVQuadRhsFnB which is the + * RHS of the backward ODE system's quadrature. + */ + constexpr static int cv_quad_rhs_adj(realtype t, N_Vector y, N_Vector yB, + N_Vector qBdot, void* user_data) { + return cast_to_self(user_data)->quad_rhs_adj(t, y, yB, qBdot); + } + + /** + * Calculates the jacobian of the ODE RHS wrt to its states y at the + * given time-point t and state y. + */ + inline int jacobian_rhs_states(double t, N_Vector y, SUNMatrix J) const { + Eigen::Map Jfy(SM_DATA_D(J), N_, N_); + + nested_rev_autodiff nested; + + Eigen::Matrix y_var( + Eigen::Map(NV_DATA_S(y), N_)); + Eigen::Matrix fy_var + = rhs(t, y_var, solver_->value_of_args_tuple_); + + check_size_match(solver_->function_name_str_.c_str(), "dy_dt", + fy_var.size(), "states", N_); + + grad(fy_var.coeffRef(0).vi_); + Jfy.col(0) = y_var.adj(); + for (int i = 1; i < fy_var.size(); ++i) { + nested.set_zero_all_adjoints(); + grad(fy_var.coeffRef(i).vi_); + Jfy.col(i) = y_var.adj(); + } + Jfy.transposeInPlace(); + return 0; + } + + /** + * Implements the function of type CVDlsJacFn which is the + * user-defined callback for CVODES to calculate the jacobian of the + * ode_rhs wrt to the states y. The jacobian is stored in column + * major format. + */ + constexpr static int cv_jacobian_rhs_states(realtype t, N_Vector y, + N_Vector fy, SUNMatrix J, + void* user_data, N_Vector tmp1, + N_Vector tmp2, N_Vector tmp3) { + return cast_to_self(user_data)->jacobian_rhs_states(t, y, J); + } + + /* + * Calculate the Jacobian of the RHS of the adjoint ODE (see rhs_adj + * below for citation for how this is done) + * + * @param[in] t Time + * @param[in] y State of system + * @param[out] J CVode structure where output is to be stored + */ + inline int jacobian_rhs_adj_states(double t, N_Vector y, SUNMatrix J) const { + // J_adj_y = -1 * transpose(J_y) + int error_code = jacobian_rhs_states(t, y, J); + + Eigen::Map J_adj_y(SM_DATA_D(J), N_, N_); + J_adj_y.transposeInPlace(); + J_adj_y.array() *= -1.0; + return error_code; + } + + /** + * Implements the CVLsJacFnB function for evaluating the jacobian of + * the adjoint problem wrt to the backward states. + */ + constexpr static int cv_jacobian_rhs_adj_states(realtype t, N_Vector y, + N_Vector yB, N_Vector fyB, + SUNMatrix J, void* user_data, + N_Vector tmp1, N_Vector tmp2, + N_Vector tmp3) { + return cast_to_self(user_data)->jacobian_rhs_adj_states(t, y, J); + } +}; // cvodes integrator adjoint vari + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/rev/functor/cvodes_utils.hpp b/stan/math/rev/functor/cvodes_utils.hpp index d747f223959..8ab960b6519 100644 --- a/stan/math/rev/functor/cvodes_utils.hpp +++ b/stan/math/rev/functor/cvodes_utils.hpp @@ -21,16 +21,13 @@ extern "C" inline void cvodes_err_handler(int error_code, const char* module, } } -inline void cvodes_set_options(void* cvodes_mem, double rel_tol, double abs_tol, +inline void cvodes_set_options(void* cvodes_mem, // NOLINTNEXTLINE(runtime/int) long int max_num_steps) { // forward CVode errors to noop error handler CVodeSetErrHandlerFn(cvodes_mem, cvodes_err_handler, nullptr); // Initialize solver parameters - check_flag_sundials(CVodeSStolerances(cvodes_mem, rel_tol, abs_tol), - "CVodeSStolerances"); - check_flag_sundials(CVodeSetMaxNumSteps(cvodes_mem, max_num_steps), "CVodeSetMaxNumSteps"); diff --git a/stan/math/rev/functor/ode_adjoint.hpp b/stan/math/rev/functor/ode_adjoint.hpp new file mode 100644 index 00000000000..9c866032981 --- /dev/null +++ b/stan/math/rev/functor/ode_adjoint.hpp @@ -0,0 +1,265 @@ +#ifndef STAN_MATH_REV_FUNCTOR_ODE_ADJOINT_HPP +#define STAN_MATH_REV_FUNCTOR_ODE_ADJOINT_HPP + +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of + * times, { t1, t2, t3, ... } using the stiff backward differentiation formula + * BDF solver or the non-stiff Adams solver from CVODES. The ODE system is + * integrated using the adjoint sensitivity approach of CVODES. + * + * \p f must define an operator() with the signature as: + * template + * Eigen::Matrix, Eigen::Dynamic, 1> + * operator()(const T_t& t, const Eigen::Matrix& y, + * std::ostream* msgs, const T_Args&... args); + * + * t is the time, y is the state, msgs is a stream for error messages, and args + * are optional arguments passed to the ODE solve function (which are passed + * through to \p f without modification). + * + * @tparam F Type of ODE right hand side + * @tparam T_y0 Type of initial state + * @tparam T_t0 Type of initial time + * @tparam T_ts Type of output times + * @tparam T_Args Types of pass-through parameters + * + * @param function_name Calling function name (for printing debugging messages) + * @param f Right hand side of the ODE + * @param y0 Initial state + * @param t0 Initial time + * @param ts Times at which to solve the ODE at. All values must be sorted and + * not less than t0. + * @param relative_tolerance_forward Relative tolerance for forward problem + * passed to CVODES + * @param absolute_tolerance_forward Absolute tolerance per ODE state for + * forward problem passed to CVODES + * @param relative_tolerance_backward Relative tolerance for backward problem + * passed to CVODES + * @param absolute_tolerance_backward Absolute tolerance per ODE state for + * backward problem passed to CVODES + * @param relative_tolerance_quadrature Relative tolerance for quadrature + * problem passed to CVODES + * @param absolute_tolerance_quadrature Absolute tolerance for quadrature + * problem passed to CVODES + * @param max_num_steps Upper limit on the number of integration steps to + * take between each output (error if exceeded) + * @param num_steps_between_checkpoints Number of integrator steps after which a + * checkpoint is stored for the backward pass + * @param interpolation_polynomial type of polynomial used for interpolation + * @param solver_forward solver used for forward pass + * @param solver_backward solver used for backward pass + * @param[in, out] msgs the print stream for warning messages + * @param args Extra arguments passed unmodified through to ODE right hand side + * @return An `std::vector` of Eigen column vectors with scalars equal to + * the least upper bound of `T_y0`, `T_t0`, `T_ts`, and the lambda's arguments. + * This represents the solution to ODE at times \p ts + */ +template * = nullptr, + require_any_not_st_arithmetic* = nullptr> +auto ode_adjoint_impl( + const char* function_name, F&& f, const T_y0& y0, const T_t0& t0, + const std::vector& ts, double relative_tolerance_forward, + const T_abs_tol_fwd& absolute_tolerance_forward, + double relative_tolerance_backward, + const T_abs_tol_bwd& absolute_tolerance_backward, + double relative_tolerance_quadrature, double absolute_tolerance_quadrature, + long int max_num_steps, // NOLINT(runtime/int) + long int num_steps_between_checkpoints, // NOLINT(runtime/int) + int interpolation_polynomial, int solver_forward, int solver_backward, + std::ostream* msgs, const T_Args&... args) { + using integrator_vari + = cvodes_integrator_adjoint_vari, T_t0, T_ts, + plain_type_t...>; + auto integrator = new integrator_vari( + function_name, std::forward(f), y0, t0, ts, relative_tolerance_forward, + absolute_tolerance_forward, relative_tolerance_backward, + absolute_tolerance_backward, relative_tolerance_quadrature, + absolute_tolerance_quadrature, max_num_steps, + num_steps_between_checkpoints, interpolation_polynomial, solver_forward, + solver_backward, msgs, args...); + return integrator->solution(); +} + +/** + * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of + * times, { t1, t2, t3, ... } using the stiff backward differentiation formula + * BDF solver or the non-stiff Adams solver from CVODES. The ODE system is + * integrated using the adjoint sensitivity approach of CVODES. This + * implementation handles the case of a double return type which ensures that no + * resources are left on the AD stack. + * + * \p f must define an operator() with the signature as: + * template + * Eigen::Matrix, Eigen::Dynamic, 1> + * operator()(const T_t& t, const Eigen::Matrix& y, + * std::ostream* msgs, const T_Args&... args); + * + * t is the time, y is the state, msgs is a stream for error messages, and args + * are optional arguments passed to the ODE solve function (which are passed + * through to \p f without modification). + * + * @tparam F Type of ODE right hand side + * @tparam T_y0 Type of initial state + * @tparam T_t0 Type of initial time + * @tparam T_ts Type of output times + * @tparam T_Args Types of pass-through parameters + * + * @param function_name Calling function name (for printing debugging messages) + * @param f Right hand side of the ODE + * @param y0 Initial state + * @param t0 Initial time + * @param ts Times at which to solve the ODE at. All values must be sorted and + * not less than t0. + * @param relative_tolerance_forward Relative tolerance for forward problem + * passed to CVODES + * @param absolute_tolerance_forward Absolute tolerance per ODE state for + * forward problem passed to CVODES + * @param relative_tolerance_backward Relative tolerance for backward problem + * passed to CVODES + * @param absolute_tolerance_backward Absolute tolerance per ODE state for + * backward problem passed to CVODES + * @param relative_tolerance_quadrature Relative tolerance for quadrature + * problem passed to CVODES + * @param absolute_tolerance_quadrature Absolute tolerance for quadrature + * problem passed to CVODES + * @param max_num_steps Upper limit on the number of integration steps to + * take between each output (error if exceeded) + * @param num_steps_between_checkpoints Number of integrator steps after which a + * checkpoint is stored for the backward pass + * @param interpolation_polynomial type of polynomial used for interpolation + * @param solver_forward solver used for forward pass + * @param solver_backward solver used for backward pass + * @param[in, out] msgs the print stream for warning messages + * @param args Extra arguments passed unmodified through to ODE right hand side + * @return An `std::vector` of Eigen column vectors with scalars equal to + * the least upper bound of `T_y0`, `T_t0`, `T_ts`, and the lambda's arguments. + * This represents the solution to ODE at times \p ts + */ +template * = nullptr, + require_all_st_arithmetic* = nullptr> +std::vector ode_adjoint_impl( + const char* function_name, F&& f, const T_y0& y0, const T_t0& t0, + const std::vector& ts, double relative_tolerance_forward, + const T_abs_tol_fwd& absolute_tolerance_forward, + double relative_tolerance_backward, + const T_abs_tol_bwd& absolute_tolerance_backward, + double relative_tolerance_quadrature, double absolute_tolerance_quadrature, + long int max_num_steps, // NOLINT(runtime/int) + long int num_steps_between_checkpoints, // NOLINT(runtime/int) + int interpolation_polynomial, int solver_forward, int solver_backward, + std::ostream* msgs, const T_Args&... args) { + std::vector ode_solution; + { + nested_rev_autodiff nested; + + using integrator_vari + = cvodes_integrator_adjoint_vari, T_t0, T_ts, + plain_type_t...>; + + auto integrator = new integrator_vari( + function_name, std::forward(f), y0, t0, ts, + relative_tolerance_forward, absolute_tolerance_forward, + relative_tolerance_backward, absolute_tolerance_backward, + relative_tolerance_quadrature, absolute_tolerance_quadrature, + max_num_steps, num_steps_between_checkpoints, interpolation_polynomial, + solver_forward, solver_backward, msgs, args...); + + ode_solution = integrator->solution(); + } + return ode_solution; +} + +/** + * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of + * times, { t1, t2, t3, ... } using the stiff backward differentiation formula + * BDF solver or the non-stiff Adams solver from CVODES. The ODE system is + * integrated using the adjoint sensitivity approach of CVODES. + * + * \p f must define an operator() with the signature as: + * template + * Eigen::Matrix, Eigen::Dynamic, 1> + * operator()(const T_t& t, const Eigen::Matrix& y, + * std::ostream* msgs, const T_Args&... args); + * + * t is the time, y is the state, msgs is a stream for error messages, and args + * are optional arguments passed to the ODE solve function (which are passed + * through to \p f without modification). + * + * @tparam F Type of ODE right hand side + * @tparam T_y0 Type of initial state + * @tparam T_t0 Type of initial time + * @tparam T_ts Type of output times + * @tparam T_Args Types of pass-through parameters + * + * @param f Right hand side of the ODE + * @param y0 Initial state + * @param t0 Initial time + * @param ts Times at which to solve the ODE at. All values must be sorted and + * not less than t0. + * @param relative_tolerance_forward Relative tolerance for forward problem + * passed to CVODES + * @param absolute_tolerance_forward Absolute tolerance per ODE state for + * forward problem passed to CVODES + * @param relative_tolerance_backward Relative tolerance for backward problem + * passed to CVODES + * @param absolute_tolerance_backward Absolute tolerance per ODE state for + * backward problem passed to CVODES + * @param relative_tolerance_quadrature Relative tolerance for quadrature + * problem passed to CVODES + * @param absolute_tolerance_quadrature Absolute tolerance for quadrature + * problem passed to CVODES + * @param max_num_steps Upper limit on the number of integration steps to + * take between each output (error if exceeded) + * @param num_steps_between_checkpoints Number of integrator steps after which a + * checkpoint is stored for the backward pass + * @param interpolation_polynomial type of polynomial used for interpolation + * @param solver_forward solver used for forward pass + * @param solver_backward solver used for backward pass + * @param[in, out] msgs the print stream for warning messages + * @param args Extra arguments passed unmodified through to ODE right hand side + * @return An `std::vector` of Eigen column vectors with scalars equal to + * the least upper bound of `T_y0`, `T_t0`, `T_ts`, and the lambda's arguments. + * This represents the solution to ODE at times \p ts + */ +template * = nullptr> +auto ode_adjoint_tol_ctl( + F&& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, + double relative_tolerance_forward, + const T_abs_tol_fwd& absolute_tolerance_forward, + double relative_tolerance_backward, + const T_abs_tol_bwd& absolute_tolerance_backward, + double relative_tolerance_quadrature, double absolute_tolerance_quadrature, + long int max_num_steps, // NOLINT(runtime/int) + long int num_steps_between_checkpoints, // NOLINT(runtime/int) + int interpolation_polynomial, int solver_forward, int solver_backward, + std::ostream* msgs, const T_Args&... args) { + return ode_adjoint_impl( + "ode_adjoint_tol_ctl", std::forward(f), y0, t0, ts, + relative_tolerance_forward, absolute_tolerance_forward, + relative_tolerance_backward, absolute_tolerance_backward, + relative_tolerance_quadrature, absolute_tolerance_quadrature, + max_num_steps, num_steps_between_checkpoints, interpolation_polynomial, + solver_forward, solver_backward, msgs, args...); +} + +} // namespace math +} // namespace stan +#endif diff --git a/test/unit/math/prim/functor/harmonic_oscillator.hpp b/test/unit/math/prim/functor/harmonic_oscillator.hpp index 458c5c2ca4b..4a0903906f0 100644 --- a/test/unit/math/prim/functor/harmonic_oscillator.hpp +++ b/test/unit/math/prim/functor/harmonic_oscillator.hpp @@ -30,9 +30,8 @@ struct harm_osc_ode_fun { struct harm_osc_ode_fun_eigen { template - inline auto operator()(const T0& t_in, - const Eigen::Matrix& y_in, - std::ostream* msgs, const std::vector& theta, + inline auto operator()(const T0& t_in, const T1& y_in, std::ostream* msgs, + const std::vector& theta, const std::vector& x, const std::vector& x_int) const { if (y_in.size() != 2) @@ -73,9 +72,8 @@ struct harm_osc_ode_data_fun { struct harm_osc_ode_data_fun_eigen { template - inline auto operator()(const T0& t_in, - const Eigen::Matrix& y_in, - std::ostream* msgs, const std::vector& theta, + inline auto operator()(const T0& t_in, const T1& y_in, std::ostream* msgs, + const std::vector& theta, const std::vector& x, const std::vector& x_int) const { if (y_in.size() != 2) diff --git a/test/unit/math/prim/functor/ode_test_functors.hpp b/test/unit/math/prim/functor/ode_test_functors.hpp index 633f04c1634..3de5ef75f0f 100644 --- a/test/unit/math/prim/functor/ode_test_functors.hpp +++ b/test/unit/math/prim/functor/ode_test_functors.hpp @@ -28,8 +28,8 @@ auto sum_(Vec&& arg) { struct CosArg1 { template inline Eigen::Matrix, Eigen::Dynamic, 1> - operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const T_Args&... a) const { + operator()(const T0& t, const T1& y, std::ostream* msgs, + const T_Args&... a) const { std::vector::type> vec = {sum_(a)...}; Eigen::Matrix, Eigen::Dynamic, 1> out(1); @@ -41,8 +41,8 @@ struct CosArg1 { struct Cos2Arg { template inline Eigen::Matrix, Eigen::Dynamic, 1> - operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const T2& a, const T3& b) const { + operator()(const T0& t, const T1& y, std::ostream* msgs, const T2& a, + const T3& b) const { Eigen::Matrix, Eigen::Dynamic, 1> out(1); out << stan::math::cos((sum_(a) + sum_(b)) * t); return out; @@ -52,8 +52,8 @@ struct Cos2Arg { struct CosArgWrongSize { template inline Eigen::Matrix, Eigen::Dynamic, 1> - operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const T_Args&... a) const { + operator()(const T0& t, const T1& y, std::ostream* msgs, + const T_Args&... a) const { std::vector::type> vec = {sum_(a)...}; Eigen::Matrix, Eigen::Dynamic, 1> out(2); diff --git a/test/unit/math/rev/core/zero_adjoints_test.cpp b/test/unit/math/rev/core/zero_adjoints_test.cpp index 07ab008e47b..9f0fbf0c719 100644 --- a/test/unit/math/rev/core/zero_adjoints_test.cpp +++ b/test/unit/math/rev/core/zero_adjoints_test.cpp @@ -1,8 +1,9 @@ -#include -#include #include +#include #include #include +#include +#include TEST(AgradRev, zero_arithmetic) { int a = 1.0; @@ -30,8 +31,9 @@ TEST(AgradRev, zero_arithmetic) { stan::math::zero_adjoints(vc); stan::math::zero_adjoints(vd); stan::math::zero_adjoints(ve); - - stan::math::zero_adjoints(a, b, va, vb, c, d, e, vva, vvb, vc, vd, ve); + stan::math::for_each( + [](auto&& x) { stan::math::zero_adjoints(x); }, + std::forward_as_tuple(a, b, va, vb, c, d, e, vva, vvb, vc, vd, ve)); } TEST(AgradRev, zero_var) { @@ -214,7 +216,8 @@ TEST(AgradRev, zero_multi) { std::vector e(5, 1); std::vector f(5, 1.0); - stan::math::zero_adjoints(a, b, c, d, e, f); + stan::math::for_each([](auto&& x) { stan::math::zero_adjoints(x); }, + std::forward_as_tuple(a, b, c, d, e, f)); EXPECT_FLOAT_EQ(c.vi_->adj_, 0.0); for (size_t i = 0; i < d.size(); ++i) EXPECT_FLOAT_EQ(d[i].vi_->adj_, 0.0); diff --git a/test/unit/math/rev/functor/cos_ode_typed_test.cpp b/test/unit/math/rev/functor/cos_ode_typed_test.cpp index 1327745c2e3..e72f749e277 100644 --- a/test/unit/math/rev/functor/cos_ode_typed_test.cpp +++ b/test/unit/math/rev/functor/cos_ode_typed_test.cpp @@ -16,8 +16,9 @@ using ode_test_tuple = std::tuple; * Outer product of test types */ using cos_arg_test_types = boost::mp11::mp_product< - ode_test_tuple, ::testing::Types >; + ode_test_tuple, + ::testing::Types >; TYPED_TEST_SUITE_P(cos_arg_test); TYPED_TEST_P(cos_arg_test, y0_error) { diff --git a/test/unit/math/rev/functor/fho_ode_typed_ts_test.cpp b/test/unit/math/rev/functor/fho_ode_typed_ts_test.cpp index 17ed94dc23e..f405cad4952 100644 --- a/test/unit/math/rev/functor/fho_ode_typed_ts_test.cpp +++ b/test/unit/math/rev/functor/fho_ode_typed_ts_test.cpp @@ -18,7 +18,7 @@ using ode_test_tuple = std::tuple; using forced_harm_osc_ts_test_types = boost::mp11::mp_product< ode_test_tuple, ::testing::Types, + ode_rk45_functor, ode_adjoint_functor>, ::testing::Types >, // time ::testing::Types >, // y0 ::testing::Types > // theta diff --git a/test/unit/math/rev/functor/lorenz_ode_typed_fd_test.cpp b/test/unit/math/rev/functor/lorenz_ode_typed_fd_test.cpp index e882930d38e..a979aeb0e42 100644 --- a/test/unit/math/rev/functor/lorenz_ode_typed_fd_test.cpp +++ b/test/unit/math/rev/functor/lorenz_ode_typed_fd_test.cpp @@ -16,8 +16,9 @@ using ode_test_tuple = std::tuple; * Outer product of test types */ using lorenz_test_types = boost::mp11::mp_product< - ode_test_tuple, ::testing::Types>; + ode_test_tuple, + ::testing::Types>; TYPED_TEST_SUITE_P(lorenz_test); TYPED_TEST_P(lorenz_test, param_and_data_finite_diff) { diff --git a/test/unit/math/rev/functor/ode_test_functors.hpp b/test/unit/math/rev/functor/ode_test_functors.hpp index 1abf22f1138..997fc4ab075 100644 --- a/test/unit/math/rev/functor/ode_test_functors.hpp +++ b/test/unit/math/rev/functor/ode_test_functors.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #define STAN_DEF_ODE_SOLVER_FUNCTOR(solver_name, solver_func) \ @@ -15,7 +16,7 @@ typename... Args, stan::require_eigen_vector_t* = nullptr> \ std::vector, \ Eigen::Dynamic, 1>> \ - operator()(const F& f, const T_y0& y0, T_t0 t0, \ + operator()(const F& f, const T_y0& y0, const T_t0& t0, \ const std::vector& ts, std::ostream* msgs, \ const Args&... args) { \ return solver_func(f, y0, t0, ts, msgs, args...); \ @@ -25,7 +26,7 @@ typename... Args, stan::require_eigen_vector_t* = nullptr> \ std::vector, \ Eigen::Dynamic, 1>> \ - operator()(const F& f, const T_y0& y0_arg, T_t0 t0, \ + operator()(const F& f, const T_y0& y0_arg, const T_t0& t0, \ const std::vector& ts, double rtol, double atol, \ size_t max_num_steps, std::ostream* msgs, \ const Args&... args) { \ @@ -61,4 +62,73 @@ STAN_DEF_STD_ODE_SOLVER_FUNCTOR(integrate_ode_bdf, STAN_DEF_STD_ODE_SOLVER_FUNCTOR(integrate_ode_rk45, stan::math::integrate_ode_rk45); +struct ode_adjoint_functor { + const std::string functor_name = "ode_adjoint"; + + template * = nullptr> + std::vector, + Eigen::Dynamic, 1>> + operator()(const F& f, const T_y0& y0, const T_t0& t0, + const std::vector& ts, std::ostream* msgs, + const Args&... args) { + return (*this)(f, y0, t0, ts, 1E-10, 1E-10, 1000000, msgs, args...); + } + + template * = nullptr> + std::vector, + Eigen::Dynamic, 1>> + operator()(const F& f, const T_y0& y0_arg, const T_t0& t0, + const std::vector& ts, double relative_tolerance, + double absolute_tolerance, size_t max_num_steps, + std::ostream* msgs, const Args&... args) { + const int N = y0_arg.size(); + const double relative_tolerance_forward = relative_tolerance / 8.0; + const double relative_tolerance_backward = relative_tolerance / 4.0; + const double relative_tolerance_quadrature = relative_tolerance; + const Eigen::VectorXd absolute_tolerance_forward + = Eigen::VectorXd::Constant(N, absolute_tolerance / 6.0); + const Eigen::VectorXd absolute_tolerance_backward + = Eigen::VectorXd::Constant(N, absolute_tolerance / 3.0); + const double absolute_tolerance_quadrature = absolute_tolerance; + const long int num_steps_between_checkpoints = 150; // NOLINT(runtime/int) + const int interpolation_polynomial = CV_HERMITE; + const int solver_forward = CV_BDF; + const int solver_backward = CV_ADAMS; + + return stan::math::ode_adjoint_tol_ctl( + f, y0_arg, t0, ts, relative_tolerance_forward, + absolute_tolerance_forward, relative_tolerance_backward, + absolute_tolerance_backward, relative_tolerance_quadrature, + absolute_tolerance_quadrature, max_num_steps, + num_steps_between_checkpoints, interpolation_polynomial, solver_forward, + solver_backward, msgs, args...); + } + + template * = nullptr> + std::vector, + Eigen::Dynamic, 1>> + operator()(const F& f, const T_y0& y0, const T_t0& t0, + const std::vector& ts, double relative_tolerance_forward, + const Eigen::VectorXd& absolute_tolerance_forward, + double relative_tolerance_backward, + const Eigen::VectorXd& absolute_tolerance_backward, + double relative_tolerance_quadrature, + double absolute_tolerance_quadrature, + long int max_num_steps, // NOLINT(runtime/int) + long int num_steps_between_checkpoints, // NOLINT(runtime/int) + int interpolation_polynomial, int solver_forward, + int solver_backward, std::ostream* msgs, const T_Args&... args) { + return stan::math::ode_adjoint_tol_ctl( + f, y0, t0, ts, relative_tolerance_forward, absolute_tolerance_forward, + relative_tolerance_backward, absolute_tolerance_backward, + relative_tolerance_quadrature, absolute_tolerance_quadrature, + max_num_steps, num_steps_between_checkpoints, interpolation_polynomial, + solver_forward, solver_backward, msgs, args...); + } +}; + #endif diff --git a/test/unit/math/rev/functor/sho_ode_adjoint_typed_error_test.cpp b/test/unit/math/rev/functor/sho_ode_adjoint_typed_error_test.cpp new file mode 100644 index 00000000000..89d21c789ef --- /dev/null +++ b/test/unit/math/rev/functor/sho_ode_adjoint_typed_error_test.cpp @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include +#include + +/** + * + * Use same solver functor type for both w & w/o tolerance control + */ +template +using ode_test_tuple = std::tuple; + +/** + * Outer product of test types + */ +using harmonic_oscillator_ctl_test_types = boost::mp11::mp_product< + ode_test_tuple, ::testing::Types, + ::testing::Types>, // t + ::testing::Types>, // y0 + ::testing::Types> // theta + >; + +TYPED_TEST_SUITE_P(harmonic_oscillator_ctl_test); +TYPED_TEST_P(harmonic_oscillator_ctl_test, no_error) { this->test_good(); } +TYPED_TEST_P(harmonic_oscillator_ctl_test, error_conditions) { + this->test_bad(); +} +TYPED_TEST_P(harmonic_oscillator_ctl_test, value) { + stan::math::nested_rev_autodiff nested; + + this->test_value(0.0); + this->test_value(1.0); + this->test_value(-1.0); + + if (std::is_same, double>::value + && std::is_same, double>::value + && std::is_same, double>::value) { + EXPECT_EQ(stan::math::nested_size(), 0); + } else { + EXPECT_GT(stan::math::nested_size(), 0); + } +} + +REGISTER_TYPED_TEST_SUITE_P(harmonic_oscillator_ctl_test, no_error, + error_conditions, value); +INSTANTIATE_TYPED_TEST_SUITE_P(StanOde, harmonic_oscillator_ctl_test, + harmonic_oscillator_ctl_test_types); diff --git a/test/unit/math/rev/functor/sho_ode_typed_error_test.cpp b/test/unit/math/rev/functor/sho_ode_typed_error_test.cpp index c85ce707482..aebf9b1d216 100644 --- a/test/unit/math/rev/functor/sho_ode_typed_error_test.cpp +++ b/test/unit/math/rev/functor/sho_ode_typed_error_test.cpp @@ -18,7 +18,7 @@ using ode_test_tuple = std::tuple; using harmonic_oscillator_test_types = boost::mp11::mp_product< ode_test_tuple, ::testing::Types, + ode_rk45_functor, ode_adjoint_functor>, ::testing::Types>, // t ::testing::Types>, // y0 ::testing::Types> // theta diff --git a/test/unit/math/rev/functor/sho_ode_typed_test.cpp b/test/unit/math/rev/functor/sho_ode_typed_test.cpp index fc57ce6a489..9d8b56f4855 100644 --- a/test/unit/math/rev/functor/sho_ode_typed_test.cpp +++ b/test/unit/math/rev/functor/sho_ode_typed_test.cpp @@ -19,7 +19,7 @@ using ode_test_tuple = std::tuple; using harmonic_oscillator_fd_test_types = boost::mp11::mp_product< ode_test_tuple, ::testing::Types, + ode_rk45_functor, ode_adjoint_functor>, ::testing::Types, // t ::testing::Types, // y0 ::testing::Types // theta @@ -90,7 +90,7 @@ INSTANTIATE_TYPED_TEST_SUITE_P(StanOde, harmonic_oscillator_data_test, using harmonic_oscillator_test_types = boost::mp11::mp_product< ode_test_tuple, ::testing::Types, + ode_rk45_functor, ode_adjoint_functor>, ::testing::Types, // t ::testing::Types >, // y0 ::testing::Types > // theta @@ -114,6 +114,10 @@ TYPED_TEST_P(harmonic_oscillator_t0_ad_test, t0_ad) { ode_bdf_functor>::value) { this->test_t0_ad(1e-7); } + if (std::is_same, + ode_adjoint_functor>::value) { + this->test_t0_ad(1e-7); + } } REGISTER_TYPED_TEST_SUITE_P(harmonic_oscillator_t0_ad_test, t0_ad); INSTANTIATE_TYPED_TEST_SUITE_P(StanOde, harmonic_oscillator_t0_ad_test, diff --git a/test/unit/math/rev/functor/test_fixture_ode.hpp b/test/unit/math/rev/functor/test_fixture_ode.hpp index 5707b467810..b11fb16b4a2 100644 --- a/test/unit/math/rev/functor/test_fixture_ode.hpp +++ b/test/unit/math/rev/functor/test_fixture_ode.hpp @@ -44,6 +44,7 @@ * - ode_ckrk_functor * - ode_bdf_functor * - ode_rk45_functor + * - ode_adjoint_functor * - integrate_ode_adams_functor * - integrate_ode_bdf_functor * - integrate_ode_rk45_functor @@ -109,6 +110,8 @@ */ template struct ODETestFixture : public ::testing::Test { + virtual void TearDown() { stan::math::recover_memory(); } + /** * test ODE solver pass * diff --git a/test/unit/math/rev/functor/test_fixture_ode_fho.hpp b/test/unit/math/rev/functor/test_fixture_ode_fho.hpp index 9e78a45accf..ac38801a388 100644 --- a/test/unit/math/rev/functor/test_fixture_ode_fho.hpp +++ b/test/unit/math/rev/functor/test_fixture_ode_fho.hpp @@ -17,8 +17,8 @@ struct forced_harm_osc_base { struct forced_harm_osc { template inline Eigen::Matrix, -1, 1> operator()( - const T0& t_in, const Eigen::Matrix& y_in, - std::ostream* msgs, const std::vector& theta) const { + const T0& t_in, const T1& y_in, std::ostream* msgs, + const T2& theta) const { if (y_in.size() != 2) throw std::domain_error( "this function was called with inconsistent state"); diff --git a/test/unit/math/rev/functor/test_fixture_ode_lorenz.hpp b/test/unit/math/rev/functor/test_fixture_ode_lorenz.hpp index c7c1171aed9..a1a50aa0f55 100644 --- a/test/unit/math/rev/functor/test_fixture_ode_lorenz.hpp +++ b/test/unit/math/rev/functor/test_fixture_ode_lorenz.hpp @@ -17,8 +17,8 @@ struct lorenz_ode_base { struct lorenz_rhs { template inline Eigen::Matrix, -1, 1> operator()( - const T0& t_in, const Eigen::Matrix& y_in, - std::ostream* msgs, const std::vector& theta) const { + const T0& t_in, const T1& y_in, std::ostream* msgs, + const T2& theta) const { Eigen::Matrix, -1, 1> res(3); res << theta.at(0) * (y_in(1) - y_in(0)), theta.at(1) * y_in(0) - y_in(1) - y_in(0) * y_in(2), diff --git a/test/unit/math/rev/functor/test_fixture_ode_sho.hpp b/test/unit/math/rev/functor/test_fixture_ode_sho.hpp index 14c25bea139..560d3d375c6 100644 --- a/test/unit/math/rev/functor/test_fixture_ode_sho.hpp +++ b/test/unit/math/rev/functor/test_fixture_ode_sho.hpp @@ -17,11 +17,10 @@ template struct harmonic_oscillator_ode_base { struct sho_square_fun { - template + template inline Eigen::Matrix, -1, 1> operator()( - const T0& t_in, const Eigen::Matrix& y_in, - std::ostream* msgs, const std::vector& theta, - const std::vector& x, const std::vector& x_int) const { + const T0& t_in, const T1& y_in, std::ostream* msgs, const T2& theta, + const T3& x, const T4& x_int) const { if (y_in.size() != 2) throw std::domain_error("Functor called with inconsistent state"); @@ -269,7 +268,6 @@ template struct harmonic_oscillator_t0_ad_test : public harmonic_oscillator_ode_base, public ODETestFixture> { - stan::math::nested_rev_autodiff nested; stan::math::var t0v; harmonic_oscillator_t0_ad_test() @@ -284,19 +282,24 @@ struct harmonic_oscillator_t0_ad_test } void test_t0_ad(double tol) { + stan::math::nested_rev_autodiff nested; auto res = apply_solver(); res[0][0].grad(); EXPECT_NEAR(t0v.adj(), -0.66360742442816977871, tol); nested.set_zero_all_adjoints(); + t0v.adj() = 0.0; res[0][1].grad(); EXPECT_NEAR(t0v.adj(), 0.23542843380353062344, tol); nested.set_zero_all_adjoints(); + t0v.adj() = 0.0; res[1][0].grad(); EXPECT_NEAR(t0v.adj(), -0.2464078910913158893, tol); nested.set_zero_all_adjoints(); + t0v.adj() = 0.0; res[1][1].grad(); EXPECT_NEAR(t0v.adj(), -0.38494826636037426937, tol); nested.set_zero_all_adjoints(); + t0v.adj() = 0.0; } }; diff --git a/test/unit/math/rev/functor/test_fixture_ode_sho_ctl.hpp b/test/unit/math/rev/functor/test_fixture_ode_sho_ctl.hpp new file mode 100644 index 00000000000..1c95d1d25dd --- /dev/null +++ b/test/unit/math/rev/functor/test_fixture_ode_sho_ctl.hpp @@ -0,0 +1,312 @@ +#ifndef STAN_MATH_TEST_FIXTURE_ODE_SHO_CTL_HPP +#define STAN_MATH_TEST_FIXTURE_ODE_SHO_CTL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * Inheriting base type, various fixtures differs by the type of ODE functor + * used in apply_solver calls, intended for different kind of + * tests. This harmonic oscillator test class is intended for use with the + * adjoint ODE solver which has additional control parameters. + * + */ +template +struct harmonic_oscillator_ctl_test + : public harmonic_oscillator_ode_base, + public ODETestFixture> { + double rtol_f; + double rtol_b; + double rtol_q; + Eigen::VectorXd atol_f; + Eigen::VectorXd atol_b; + double atol_q; + int num_steps_check; + int inter_poly; + int solv_f; + int solv_b; + + harmonic_oscillator_ctl_test() + : harmonic_oscillator_ode_base(), + rtol_f(this->rtol / 8.0), + rtol_b(this->rtol / 4.0), + rtol_q(this->rtol), + atol_f(Eigen::VectorXd::Constant(this->y0.size(), this->atol / 6.0)), + atol_b(Eigen::VectorXd::Constant(this->y0.size(), this->atol / 3.0)), + atol_q(this->atol), + num_steps_check(100), + inter_poly(CV_HERMITE), + solv_f(CV_BDF), + solv_b(CV_ADAMS) {} + + auto apply_solver() { + std::tuple_element_t<0, T> sol; + return sol(this->f_eigen, this->y0, this->t0, this->ts, nullptr, + this->theta, this->x_r, this->x_i); + } + + template + auto apply_solver(T1&& init, T2&& theta_in) { + std::tuple_element_t<0, T> sol; + return sol(this->f_eigen, init, this->t0, this->ts, nullptr, theta_in, + this->x_r, this->x_i); + } + + auto apply_solver_tol() { + std::tuple_element_t<1, T> sol; + return sol(this->f_eigen, this->y0, this->t0, this->ts, this->rtol, + this->atol, this->max_num_step, nullptr, this->theta, this->x_r, + this->x_i); + } + + auto apply_solver_tol_ctl() { + std::tuple_element_t<0, T> sol; + return sol(this->f_eigen, this->y0, this->t0, this->ts, this->rtol_f, + this->atol_f, this->rtol_b, this->atol_b, this->rtol_q, + this->atol_q, this->max_num_step, this->num_steps_check, + this->inter_poly, this->solv_f, this->solv_b, nullptr, + this->theta, this->x_r, this->x_i); + } + + void test_bad() { + const auto y0_(this->y0); + + this->y0.resize(0); + EXPECT_THROW_MSG(apply_solver(), std::invalid_argument, + "initial state has size 0"); + this->y0 = y0_; + + const auto t0_ = this->t0; + this->t0 = 2.0; + EXPECT_THROW_MSG(apply_solver(), std::domain_error, + "initial time is 2, but must be less than 0.1"); + this->t0 = t0_; + + const auto ts_ = this->ts; + this->ts.resize(0); + EXPECT_THROW_MSG(apply_solver(), std::invalid_argument, "times has size 0"); + this->ts = ts_; + + this->ts.resize(2); + this->ts[0] = 3.0; + this->ts[1] = 1.0; + EXPECT_THROW_MSG(apply_solver(), std::domain_error, + "times is not a valid sorted vector"); + this->ts = ts_; + + const double rtol_f_ = this->rtol_f; + this->rtol_f = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "relative_tolerance_forward"); + this->rtol_f = rtol_f_; + + const double atol_f_ = this->atol_f(0); + this->atol_f(0) = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "absolute_tolerance_forward"); + this->atol_f(0) = atol_f_; + + this->atol_f.resize(1); + this->atol_f(0) = atol_f_; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::invalid_argument, + "absolute_tolerance_forward"); + this->atol_f.resize(2); + this->atol_f(0) = atol_f_; + this->atol_f(1) = atol_f_; + + const double rtol_b_ = this->rtol_b; + this->rtol_b = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "relative_tolerance_backward"); + this->rtol_b = rtol_b_; + + const double atol_b_ = this->atol_b(0); + this->atol_b(0) = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "absolute_tolerance_backward"); + this->atol_b(0) = atol_b_; + + this->atol_b.resize(1); + this->atol_b(0) = atol_b_; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::invalid_argument, + "absolute_tolerance_backward"); + this->atol_b.resize(2); + this->atol_b(0) = atol_b_; + this->atol_b(1) = atol_b_; + + const double rtol_q_ = this->rtol_q; + this->rtol_q = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "relative_tolerance_quadrature"); + this->rtol_q = rtol_q_; + + const double atol_q_ = this->atol_q; + this->atol_q = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "absolute_tolerance_quadrature"); + this->atol_q = atol_q_; + + const int max_num_step_ = this->max_num_step; + this->max_num_step = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "max_num_steps"); + this->max_num_step = max_num_step_; + + const int num_steps_check_ = this->num_steps_check; + this->num_steps_check = -1; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::domain_error, + "num_steps_between_checkpoints"); + this->num_steps_check = num_steps_check_; + + const int inter_poly_ = this->inter_poly; + this->inter_poly = 0; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::invalid_argument, + "interpolation_polynomial"); + this->inter_poly = inter_poly_; + + const int solv_f_ = this->solv_f; + this->solv_f = 0; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::invalid_argument, + "solver_forward"); + this->solv_f = solv_f_; + + const int solv_b_ = this->solv_b; + this->solv_b = 0; + EXPECT_THROW_MSG(apply_solver_tol_ctl(), std::invalid_argument, + "solver_backward"); + this->solv_b = solv_b_; + + const auto theta_ = this->theta; + const auto x_r_ = this->x_r; + const auto x_i_ = this->x_i; + + // NaN errors + double nan = std::numeric_limits::quiet_NaN(); + std::stringstream expected_is_nan; + expected_is_nan << "is " << nan; + + this->y0[0] = nan; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "initial state"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_nan.str()); + this->y0 = y0_; + + this->t0 = nan; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "initial time"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_nan.str()); + this->t0 = t0_; + + this->ts[0] = nan; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "times"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_nan.str()); + this->ts = ts_; + + this->theta[0] = nan; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "ode parameters and data"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_nan.str()); + this->theta = theta_; + + this->x_r.push_back(nan); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "ode parameters and data"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_nan.str()); + this->x_r = x_r_; + + // inf test + std::stringstream expected_is_inf; + expected_is_inf << "is " << std::numeric_limits::infinity(); + std::stringstream expected_is_neg_inf; + expected_is_neg_inf << "is " << -std::numeric_limits::infinity(); + double inf = std::numeric_limits::infinity(); + + this->y0[0] = inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "initial state"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_inf.str()); + this->y0[0] = -inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "initial state"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_neg_inf.str()); + this->y0 = y0_; + + this->t0 = inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "initial time"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_inf.str()); + this->t0 = -inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "initial time"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_neg_inf.str()); + this->t0 = t0_; + + this->ts.back() = inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "times"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_inf.str()); + this->ts.back() = -inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, "times"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_neg_inf.str()); + this->ts = ts_; + + this->theta[0] = inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "ode parameters and data"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_inf.str()); + this->theta[0] = -inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "ode parameters and data"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_neg_inf.str()); + this->theta = theta_; + + this->x_r = std::vector{inf}; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "ode parameters and data"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_inf.str()); + this->x_r[0] = -inf; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "ode parameters and data"); + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + expected_is_neg_inf.str()); + this->x_r = x_r_; + } + + void test_value(double t0_in) { + this->t0 = t0_in; + for (size_t i = 0; i < this->ts.size(); ++i) { + this->ts[i] = this->t0 + 0.1 * (i + 1); + } + + this->rtol = 1e-8; + this->atol = 1e-10; + this->max_num_step = 1e6; + auto res = apply_solver_tol(); + + EXPECT_NEAR(0.995029, stan::math::value_of(res[0][0]), 1e-5); + EXPECT_NEAR(-0.0990884, stan::math::value_of(res[0][1]), 1e-5); + + EXPECT_NEAR(-0.421907, stan::math::value_of(res[99][0]), 1e-5); + EXPECT_NEAR(0.246407, stan::math::value_of(res[99][1]), 1e-5); + } +}; + +#endif