Skip to content

Commit

Permalink
Changes per Steve's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Jul 31, 2024
1 parent 40d0e4b commit cf9b012
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
3 changes: 3 additions & 0 deletions stan/math/prim/constraint/sum_to_zero_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline plain_type_t<Vec> sum_to_zero_constrain(const Vec& y) {
const auto Km1 = y.size();
if (unlikely(Km1 == 0)) {
return plain_type_t<Vec>(Eigen::VectorXd{{0}});
}
plain_type_t<Vec> x(Km1 + 1);
// copy the first Km1 elements
auto&& y_ref = to_ref(y);
Expand Down
4 changes: 1 addition & 3 deletions stan/math/prim/constraint/sum_to_zero_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ inline plain_type_t<Vec> sum_to_zero_free(const Vec& x) {
const auto& x_ref = to_ref(x);
check_sum_to_zero("stan::math::sum_to_zero_free", "sum_to_zero variable",
x_ref);
if (x_ref.size() == 0) {
return plain_type_t<Vec>(0);
}

return x_ref.head(x_ref.size() - 1);
}

Expand Down
7 changes: 6 additions & 1 deletion stan/math/prim/err/check_sum_to_zero.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ namespace math {
* @param function Function name (for error messages)
* @param name Variable name (for error messages)
* @param theta Vector to test
* @throw `std::invalid_argument` if `theta` is a 0-vector
* @throw `std::domain_error` if the vector does not sum to zero
*/
template <typename T, require_matrix_t<T>* = nullptr>
void check_sum_to_zero(const char* function, const char* name, const T& theta) {
using std::fabs;
// the size-zero case is technically a valid sum-to-zero vector,
// but it cannot be unconstrained to anything
check_nonzero_size(function, name, theta);
auto&& theta_ref = to_ref(value_of_rec(theta));
if (!(fabs(theta_ref.sum()) <= CONSTRAINT_TOLERANCE)) {
if (unlikely(!(fabs(theta_ref.sum()) <= CONSTRAINT_TOLERANCE))) {
[&]() STAN_COLD_PATH {
std::stringstream msg;
scalar_type_t<T> sum = theta_ref.sum();
Expand All @@ -52,6 +56,7 @@ void check_sum_to_zero(const char* function, const char* name, const T& theta) {
* @param function Function name (for error messages)
* @param name Variable name (for error messages)
* @param theta Vector to test.
* @throw `std::invalid_argument` if `theta` is a 0-vector
* @throw `std::domain_error` if the vector does not sum to zero
*/
template <typename T, require_std_vector_t<T>* = nullptr>
Expand Down
6 changes: 3 additions & 3 deletions stan/math/rev/constraint/sum_to_zero_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ inline auto sum_to_zero_constrain(const T& y) {

const auto N = y.size();
if (unlikely(N == 0)) {
return ret_type(Eigen::VectorXd{{0}});
return arena_t<ret_type>(Eigen::VectorXd{{0}});
}
Eigen::VectorXd x_val = Eigen::VectorXd::Zero(N + 1);
auto arena_y = to_arena(y);
Expand All @@ -43,7 +43,7 @@ inline auto sum_to_zero_constrain(const T& y) {
arena_y.adj().array() -= arena_x.adj_op()(N);
arena_y.adj() += arena_x.adj_op().head(N);
});
return ret_type(arena_x);
return arena_x;
}

/**
Expand All @@ -61,7 +61,7 @@ inline auto sum_to_zero_constrain(const T& y) {
* @return Zero-sum vector of dimensionality K.
*/
template <typename T, require_rev_col_vector_t<T>* = nullptr>
auto sum_to_zero_constrain(const T& y, scalar_type_t<T>& lp) {
inline auto sum_to_zero_constrain(const T& y, scalar_type_t<T>& lp) {
return sum_to_zero_constrain(y);
}

Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/prim/err/check_sum_to_zero_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
TEST(ErrorHandlingMatrix, checkSumToZero_edges) {
Eigen::Matrix<double, Eigen::Dynamic, 1> zero(0);

EXPECT_NO_THROW(
stan::math::check_sum_to_zero("checkSumToZero", "zero", zero));
EXPECT_THROW(stan::math::check_sum_to_zero("checkSumToZero", "zero", zero),
std::invalid_argument);

Eigen::Matrix<double, Eigen::Dynamic, 1> y_vec(1);
y_vec << 0.0;
Expand Down

0 comments on commit cf9b012

Please sign in to comment.