Improve Numerical Stability of Bernoulli CDF functions #2784

Merged: 27 commits, Sep 19, 2023
Changes from 17 commits
Commits
27 commits
2b8ae14
Numerically stable bernoulli cdf functions
andrjohns Jul 8, 2022
c2022be
CDF initial value
andrjohns Jul 8, 2022
a9d0800
cpplint
andrjohns Jul 8, 2022
bcfaec1
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 8, 2022
143f707
Add select function, vectorise cdf
andrjohns Jul 10, 2022
62d8640
Vectorise lcdf and lccdf
andrjohns Jul 10, 2022
0edeb5f
Update doc
andrjohns Jul 10, 2022
fd6d39c
Merge commit 'e4b9bdece4250e3455d663e3155c1d3d4965c10d' into HEAD
yashikno Jul 10, 2022
f03c35f
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 10, 2022
2b3cf01
Fix headers
andrjohns Jul 10, 2022
a371c2b
Fix return types
andrjohns Jul 11, 2022
2bade77
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 11, 2022
409c971
Ignore Eigen deprecation warning in expression tests
andrjohns Jul 11, 2022
6220f14
Update compiler flags
andrjohns Jul 12, 2022
2c11c42
Update includes
andrjohns Jul 12, 2022
136fc9c
Merge branch 'stan-dev:develop' into issue-2783-bernoulli-cdf-stable
andrjohns Sep 4, 2022
ccde5d5
Merge branch 'stan-dev:develop' into issue-2783-bernoulli-cdf-stable
andrjohns Oct 26, 2022
a4b384c
Merge branch 'develop' into issue-2783-bernoulli-cdf-stable
andrjohns Nov 14, 2022
da1f7a8
review comments
andrjohns Nov 14, 2022
abfa364
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Nov 14, 2022
6103b16
Merge branch 'stan-dev:develop' into issue-2783-bernoulli-cdf-stable
andrjohns Dec 2, 2022
53ce103
Merge branch 'develop' into issue-2783-bernoulli-cdf-stable
andrjohns Aug 11, 2023
9872fcf
Tidy includes
andrjohns Aug 11, 2023
bd11e54
Missing headers
andrjohns Aug 11, 2023
8f5cdb8
Reduce unnecessary computation
andrjohns Aug 11, 2023
cabafd6
Remove select broadcast hack
andrjohns Aug 16, 2023
47c4f9a
Fix broadcast logic
andrjohns Aug 19, 2023
2 changes: 1 addition & 1 deletion make/compiler_flags
@@ -196,7 +196,7 @@ endif
CXXFLAGS_OS += -D_REENTRANT

## silence warnings occuring due to the TBB and Eigen libraries
CXXFLAGS_WARNINGS += -Wno-ignored-attributes
CXXFLAGS_WARNINGS += -Wno-ignored-attributes -Wno-deprecated-declarations

################################################################################
# Setup OpenCL
17 changes: 1 addition & 16 deletions stan/math/opencl/kernel_generator/select.hpp
@@ -3,6 +3,7 @@
#ifdef STAN_OPENCL

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
@@ -150,22 +151,6 @@ select(T_condition&& condition, T_then&& then, T_else&& els) { // NOLINT
as_operation_cl(std::forward<T_else>(els))};
}

/**
* Scalar overload of the selection operation.
* @tparam T_then type of then scalar
* @tparam T_else type of else scalar
* @param condition condition
* @param then then result
* @param els else result
* @return `condition ? then : els`
*/
template <typename T_then, typename T_else,
require_all_arithmetic_t<T_then, T_else>* = nullptr>
inline std::common_type_t<T_then, T_else> select(bool condition, T_then then,
T_else els) {
return condition ? then : els;
}

/** @}*/
} // namespace math
} // namespace stan
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
@@ -304,6 +304,7 @@
#include <stan/math/prim/fun/scaled_add.hpp>
#include <stan/math/prim/fun/sd.hpp>
#include <stan/math/prim/fun/segment.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/signbit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
140 changes: 140 additions & 0 deletions stan/math/prim/fun/select.hpp
@@ -0,0 +1,140 @@
#ifndef STAN_MATH_PRIM_FUN_SELECT_HPP
#define STAN_MATH_PRIM_FUN_SELECT_HPP

#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {

/**
* Return the second argument if the first argument is true
* and otherwise return the third argument.
*
* <code>select(c, y1, y0) = c ? y1 : y0</code>.
*
* @tparam T_true type of the true argument
* @tparam T_false type of the false argument
* @param c Boolean condition value.
* @param y_true Value to return if condition is true.
* @param y_false Value to return if condition is false.
*/
template <typename T_true, typename T_false,
require_all_stan_scalar_t<T_true, T_false>* = nullptr>
inline auto select(const bool c, const T_true y_true, const T_false y_false) {
return c ? y_true : y_false;
}

Collaborator:
Not sure if @t4c1 still checks GitHub, but I'm not sure whether we need common_type here or if auto is fine. I wouldn't mind just using return_type_t<>, though that will only work with arithmetic types since return_type_t has a minimum of double as the returned type. We could just write another overload to handle the double/integral case, though.

Contributor:
I still get notifications if pinged. auto will here be the same as T_true (that is how the ternary operator works), so some common type is a better idea. I'm not sure if return_type_t will do promotion to var even if neither T_true nor T_false are var, but we do not want that here.

Collaborator Author:
auto will here be the same as T_true

I've done some tests and it doesn't look like an issue when mixing types: https://godbolt.org/z/dvcxvvxhs

But let me know if I've missed something basic!
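
For context, a minimal, self-contained sketch of the behaviour the godbolt link checks (hypothetical test code, not part of this PR; select_sketch is a stand-in name for the scalar overload above, without the Stan Math require_* constraints):

#include <iostream>
#include <type_traits>

// Hypothetical stand-in for the scalar overload above (the real function
// also constrains T_true/T_false with Stan Math require_* traits).
template <typename T_true, typename T_false>
inline auto select_sketch(bool c, T_true y_true, T_false y_false) {
  return c ? y_true : y_false;
}

int main() {
  auto a = select_sketch(true, 1, 2.5);   // int "then", double "else"
  auto b = select_sketch(false, 2.5, 7);  // double "then", int "else"
  // Both deduce double: the conditional operator applies the usual
  // arithmetic conversions, so auto is the common type, not simply T_true.
  static_assert(std::is_same<decltype(a), double>::value, "common type");
  static_assert(std::is_same<decltype(b), double>::value, "common type");
  std::cout << a << " " << b << "\n";  // prints: 1 7
  return 0;
}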


/**
* Return the second argument if the first argument is true
* and otherwise return the third argument. Overload for use with two Eigen
* objects.
*
* @tparam T_true type of the true argument
* @tparam T_false type of the false argument
* @param c Boolean condition value.
* @param y_true Value to return if condition is true.
* @param y_false Value to return if condition is false.
*/
template <typename T_true, typename T_false,
require_all_eigen_t<T_true, T_false>* = nullptr>
inline auto select(const bool c, const T_true y_true, const T_false y_false) {
return y_true
.binaryExpr(y_false, [&](auto&& x, auto&& y) { return c ? x : y; })
.eval();

Collaborator:
If c is constant here, should we just be returning y_true or y_false? We just need to use promotion rules on the output type's scalar value with promote_scalar_t<return_type_t<T_true, T_false>>.

}
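
A minimal sketch of one reading of the suggestion above (not the code that was merged; it assumes the Stan Math traits promote_scalar_t, return_type_t, plain_type_t, value_type_t and Eigen's cast<>(), and would live inside stan::math like the overloads above; select_branch is a hypothetical name): since c is a single bool rather than a per-element mask, the chosen branch can be returned whole, and only its scalar type needs promoting.

template <typename T_true, typename T_false,
          typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>,
                                              plain_type_t<T_true>>,
          require_all_eigen_t<T_true, T_false>* = nullptr>
inline ReturnT select_branch(const bool c, const T_true& y_true,
                             const T_false& y_false) {
  // Promote the scalar of whichever branch is chosen so that, e.g., an
  // integer-valued "else" argument comes back with the common scalar type.
  using scalar_t = value_type_t<ReturnT>;
  if (c) {
    return y_true.template cast<scalar_t>();
  } else {
    return y_false.template cast<scalar_t>();
  }
}

Whether the explicit cast is preferable to the binaryExpr form is the trade-off being discussed in this thread.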

/**
* Return the second argument if the first argument is true
* and otherwise return the third argument. Overload for use with an Eigen
* object (true case) and a scalar (false case). If chosen, the scalar is
* returned as an Eigen object of the same size and type as the Eigen argument.
*
* @tparam T_true type of the true argument
* @tparam T_false type of the false argument
* @param c Boolean condition value.
* @param y_true Value to return if condition is true.
* @param y_false Value to return if condition is false.
*/
template <typename T_true, typename T_false,
typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>,
plain_type_t<T_true>>,
require_eigen_t<T_true>* = nullptr,
require_stan_scalar_t<T_false>* = nullptr>
inline ReturnT select(const bool c, const T_true& y_true,
const T_false& y_false) {
if (c) {
return y_true;
}

return y_true.unaryExpr([&](auto&& y) { return y_false; });
}

Collaborator:
I'd use

if () {
} else {
}

with promote_scalar_t again.

Collaborator:
That is true for all of them.


/**
* Return the second argument if the first argument is true
* and otherwise return the third argument. Overload for use with a scalar
* (true case) and an Eigen object (false case). If chosen, the scalar is
* returned as an Eigen object of the same size and type as the Eigen argument.
*
* @tparam T_true type of the true argument
* @tparam T_false type of the false argument
* @param c Boolean condition value.
* @param y_true Value to return if condition is true.
* @param y_false Value to return if condition is false.
*/
template <typename T_true, typename T_false,
typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>,
plain_type_t<T_false>>,
require_stan_scalar_t<T_true>* = nullptr,
require_eigen_t<T_false>* = nullptr>
inline ReturnT select(const bool c, const T_true y_true,
const T_false y_false) {
if (c) {
return y_false.unaryExpr([&](auto&& y) { return y_true; });
}

return y_false;
}

/**
* Return the second argument if the first argument is true
* and otherwise return the third argument. Overload for use with an Eigen
* object of booleans, and two scalars. The chosen scalar is returned as an
* Eigen object of the same dimension as the input Eigen argument.
*
* @tparam T_bool type of Eigen boolean object
* @tparam T_true type of the true argument
* @tparam T_false type of the false argument
* @param c Eigen object of boolean condition values.
* @param y_true Value to return if condition is true.
* @param y_false Value to return if condition is false.
*/
template <typename T_bool, typename T_true, typename T_false,
require_eigen_array_t<T_bool>* = nullptr,
require_all_stan_scalar_t<T_true, T_false>* = nullptr>
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
return c.unaryExpr([&](bool cond) { return cond ? y_true : y_false; }).eval();
}

/**
* Return the second argument if the first argument is true
* and otherwise return the third argument. Overload for use with an Eigen
* object of booleans, and at least one Eigen object as input.
*
* @tparam T_bool type of Eigen boolean object
* @tparam T_true type of the true argument
* @tparam T_false type of the false argument
* @param c Eigen object of boolean condition values.
* @param y_true Value to return if condition is true.
* @param y_false Value to return if condition is false.
*/
template <typename T_bool, typename T_true, typename T_false,
require_eigen_array_t<T_bool>* = nullptr,
require_any_eigen_array_t<T_true, T_false>* = nullptr>
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
return c.select(y_true, y_false).eval();
}

Collaborator:
Does this work if y_true has a double scalar type and y_false has an integer scalar type?
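
A quick standalone check one could compile to probe that question (a hypothetical Eigen-only snippet, not part of this PR); casting the integer argument to the common scalar type makes the mismatch a non-issue either way:

#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::Array<bool, Eigen::Dynamic, 1> c(3);
  c << true, false, true;
  Eigen::ArrayXd y_true = Eigen::ArrayXd::Constant(3, 0.5);  // double scalars
  Eigen::ArrayXi y_false = Eigen::ArrayXi::Constant(3, 2);   // int scalars
  // Casting the integer array to double makes the scalar types agree
  // before handing both branches to Eigen's coefficient-wise select.
  std::cout << c.select(y_true, y_false.cast<double>()) << "\n";
  return 0;
}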


} // namespace math
} // namespace stan

#endif
38 changes: 10 additions & 28 deletions stan/math/prim/prob/bernoulli_cdf.hpp
@@ -4,6 +4,7 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
@@ -36,50 +37,31 @@ return_type_t<T_prob> bernoulli_cdf(const T_n& n, const T_prob& theta) {
check_consistent_sizes(function, "Random variable", n,
"Probability parameter", theta);
T_theta_ref theta_ref = theta;
const auto& n_arr = as_array_or_scalar(n);
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
1.0);

if (size_zero(n, theta)) {
return 1.0;
}

T_partials_return P(1.0);
operands_and_partials<T_theta_ref> ops_partials(theta_ref);

scalar_seq_view<T_n> n_vec(n);
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
size_t max_size_seq_view = max_size(n, theta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(n); i++) {
if (n_vec.val(i) < 0) {
return ops_partials.build(0.0);
}
if (sum(n_arr < 0)) {
return ops_partials.build(0.0);
}
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
const auto& log1m_theta = select(theta_arr == 1, 0.0, log1m(theta_arr));
const auto& P1 = select(n_arr == 0, log1m_theta, 0.0);

for (size_t i = 0; i < max_size_seq_view; i++) {
// Explicit results for extreme values
// The gradients are technically ill-defined, but treated as zero
if (n_vec.val(i) >= 1) {
continue;
}

const T_partials_return Pi = 1 - theta_vec.val(i);

P *= Pi;

if (!is_constant_all<T_prob>::value) {
ops_partials.edge1_.partials_[i] += -1 / Pi;
}
}
T_partials_return P = sum(P1);

if (!is_constant_all<T_prob>::value) {
for (size_t i = 0; i < stan::math::size(theta); ++i) {
ops_partials.edge1_.partials_[i] *= P;
}
ops_partials.edge1_.partials_ = select(n_arr == 0, -exp(P - P1), 0.0);
}
return ops_partials.build(P);
return ops_partials.build(exp(P));
}

} // namespace math
33 changes: 11 additions & 22 deletions stan/math/prim/prob/bernoulli_lccdf.hpp
@@ -8,6 +8,7 @@
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
@@ -33,50 +34,38 @@ template <typename T_n, typename T_prob,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_n, T_prob>* = nullptr>
return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) {
using T_partials_return = partials_return_t<T_n, T_prob>;
using T_theta_ref = ref_type_t<T_prob>;
using std::log;
static const char* function = "bernoulli_lccdf";
check_consistent_sizes(function, "Random variable", n,
"Probability parameter", theta);
T_theta_ref theta_ref = theta;
const auto& n_arr = as_array_or_scalar(n);
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
1.0);

if (size_zero(n, theta)) {
return 0.0;
}

T_partials_return P(0.0);
operands_and_partials<T_theta_ref> ops_partials(theta_ref);

scalar_seq_view<T_n> n_vec(n);
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
size_t max_size_seq_view = max_size(n, theta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(n); i++) {
const double n_dbl = n_vec.val(i);
if (n_dbl < 0) {
return ops_partials.build(0.0);
}
if (n_dbl >= 1) {
return ops_partials.build(NEGATIVE_INFTY);
}
if (sum(n_arr < 0)) {
return ops_partials.build(0.0);
}
if (sum(n_arr >= 1)) {

Collaborator:
Suggested change:
-  }
-  if (sum(n_arr >= 1)) {
+  } else if (sum(n_arr >= 1)) {

return ops_partials.build(NEGATIVE_INFTY);
}

for (size_t i = 0; i < max_size_seq_view; i++) {
const T_partials_return Pi = theta_vec.val(i);

P += log(Pi);
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);

if (!is_constant_all<T_prob>::value) {
ops_partials.edge1_.partials_[i] += inv(Pi);
}
if (!is_constant_all<T_prob>::value) {
ops_partials.edge1_.partials_ = select(true, inv(theta_arr), n_arr);
}

return ops_partials.build(P);
return ops_partials.build(sum(select(true, log(theta_arr), n_arr)));
}

} // namespace math
33 changes: 9 additions & 24 deletions stan/math/prim/prob/bernoulli_lcdf.hpp
@@ -8,6 +8,7 @@
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
@@ -33,52 +34,36 @@ template <typename T_n, typename T_prob,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_n, T_prob>* = nullptr>
return_type_t<T_prob> bernoulli_lcdf(const T_n& n, const T_prob& theta) {
using T_partials_return = partials_return_t<T_n, T_prob>;
using T_theta_ref = ref_type_t<T_prob>;
using std::log;
static const char* function = "bernoulli_lcdf";
check_consistent_sizes(function, "Random variable", n,
"Probability parameter", theta);
T_theta_ref theta_ref = theta;
const auto& n_arr = as_array_or_scalar(n);
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
1.0);

if (size_zero(n, theta)) {
return 0.0;
}

T_partials_return P(0.0);
operands_and_partials<T_theta_ref> ops_partials(theta_ref);

scalar_seq_view<T_n> n_vec(n);
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
size_t max_size_seq_view = max_size(n, theta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(n); i++) {
if (n_vec.val(i) < 0) {
return ops_partials.build(NEGATIVE_INFTY);
}
if (sum(n_arr < 0)) {
return ops_partials.build(NEGATIVE_INFTY);

Collaborator:
We could just write an any() function that takes in a scalar or vector and returns true or false. I think that would just be easier to read.
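
A minimal sketch of the suggested helper (hypothetical, not part of this PR; it assumes the Stan Math require_* traits and Eigen's DenseBase::any()):

#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {

// Scalar overload: treat any nonzero arithmetic condition as true.
template <typename T, require_arithmetic_t<T>* = nullptr>
inline bool any(const T& x) {
  return static_cast<bool>(x);
}

// Eigen overload: true if at least one coefficient evaluates to true,
// so checks like any(n_arr < 0) read as plain English.
template <typename T, require_eigen_t<T>* = nullptr>
inline bool any(const T& x) {
  return x.any();
}

}  // namespace math
}  // namespace stan

With something like this, the early returns above could read as if (any(n_arr < 0)) and if (any(n_arr >= 1)).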

}

for (size_t i = 0; i < max_size_seq_view; i++) {
// Explicit results for extreme values
// The gradients are technically ill-defined, but treated as zero
if (n_vec.val(i) >= 1) {
continue;
}

const T_partials_return Pi = 1 - theta_vec.val(i);

P += log(Pi);
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
const auto& log1m_theta = select(theta_arr == 1, 0.0, log1m(theta_arr));

if (!is_constant_all<T_prob>::value) {
ops_partials.edge1_.partials_[i] -= inv(Pi);
}
if (!is_constant_all<T_prob>::value) {
ops_partials.edge1_.partials_ = select(n_arr == 0, -exp(-log1m_theta), 0.0);
}

return ops_partials.build(P);
return ops_partials.build(sum(select(n_arr == 0, log1m_theta, 0.0)));
}

} // namespace math