Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Numerical Stability of Bernoulli CDF functions #2784

Merged
merged 27 commits into from
Sep 19, 2023
Merged

Improve Numerical Stability of Bernoulli CDF functions #2784

merged 27 commits into from
Sep 19, 2023

Conversation

andrjohns
Copy link
Collaborator

Summary

This PR updates the Bernoulli CDF functions (_cdf, _lcdf, and _lccdf) to operate on the log scale as much as possible, to avoid issues with underflow and resolution around 1

Tests

Additional mix/prob tests have been added to ensure that the gradients aren't impacted (prim behaviour covered by the distribution tests)

Side Effects

N/A

Release notes

Improved numerical stability of Bernoulli CDF functions

Checklist

  • Math issue Improve Bernoulli (LC)CDF Numerical Stability #2783

  • Copyright holder: Andrew Johnson

    The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
    - Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
    - Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)

  • the basic tests are passing

    • unit tests pass (to run, use: ./runTests.py test/unit)
    • header checks pass, (make test-headers)
    • dependencies checks pass, (make test-math-dependencies)
    • docs build, (make doxygen)
    • code passes the built in C++ standards checks (make cpplint)
  • the code is written in idiomatic C++ and changes are documented in the doxygen

  • the new changes are tested

Copy link
Contributor

@bob-carpenter bob-carpenter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks---looks great. I think there are a couple name changes that would make this much easier to follow.

stan/math/prim/prob/bernoulli_lcdf.hpp Outdated Show resolved Hide resolved
stan/math/prim/prob/bernoulli_cdf.hpp Outdated Show resolved Hide resolved
stan/math/prim/prob/bernoulli_lcdf.hpp Outdated Show resolved Hide resolved
@bob-carpenter
Copy link
Contributor

@andrjohns: please let me know when this is ready to review and merge. Thanks!

@andrjohns
Copy link
Collaborator Author

Thanks @bob-carpenter! This is ready for another look. I've updated the vectorisation by adding a prim implementation of the select function that is currently used by the OpenCL code for the same purpose (ternary operations that can be agnostic between scalar and Eigen inputs)

@andrjohns
Copy link
Collaborator Author

@SteveBronder when you have a minute (no rush at all), can you have a look at this PR? It involves re-implementing the the OpenCL code's select function, so would be great to have eyes on from someone that knows how it should be behaving

@spinkney
Copy link
Collaborator

@bob-carpenter are you able to re-review this?

@SteveBronder are you able to double check the opencl select code?

@spinkney spinkney closed this Oct 26, 2022
@spinkney spinkney reopened this Oct 26, 2022
@spinkney
Copy link
Collaborator

dismiss the closed/reopened, I hit the trackpad on my laptop by accident

@bob-carpenter bob-carpenter dismissed their stale review October 27, 2022 10:06

don't think I understand the current C++ well enough

@bob-carpenter
Copy link
Contributor

Thanks for the heads up. I just dismissed my review so that someone else could review it. I still don't feel I understand our new C++ conventions well enough to review PRs.

@SteveBronder
Copy link
Collaborator

Sorry didn't have time today but Tuesday I can look at this

Copy link
Collaborator

@SteveBronder SteveBronder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few Qs around the new version of select. I also think we should just write an any() function that for bool just returns the input and for Eigen types holding bools calls the .any() method. Would make things simpler to read

Comment on lines 23 to 25
inline auto select(const bool c, const T_true y_true, const T_false y_false) {
return c ? y_true : y_false;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if @t4c1 still checks github, but I'm not sure if we need common_type here or if auto is fine? I wouldn't mind just using return_type_t<>, though that will only work with arithmetic types since return_type_t has a minimum of double as the returned type. We could just write another another overload to handle the double integral case though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still get notifications if pinged. auto will here be same as T_true (that is how ternary operator works), so some common type is a better idea. Not sure if retrun_type will do promotion to var even if neither T_true nor T_false are var, but we do not want that here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto will here be same as T_true

I've done some tests and it doesn't look like an issue when mixing types: https://godbolt.org/z/dvcxvvxhs

But let me know if I've missed something basic!

Comment on lines 41 to 43
return y_true
.binaryExpr(y_false, [&](auto&& x, auto&& y) { return c ? x : y; })
.eval();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If c is constant here should we just be returning y_true or y_false? We just need to use promotion rules on the output types scalar value with promote_scalar_t<return_type_t<T_true, T_false>>

Comment on lines 65 to 70
if (c) {
return y_true;
}

return y_true.unaryExpr([&](auto&& y) { return y_false; });
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd use

if () {
} else {
}

with promote_type_t again.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is true for all of them.

Comment on lines 133 to 135
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
return c.select(y_true, y_false).eval();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work if y_true has a double scalar type and y_false has an integer scalar type?

Comment on lines 57 to 58
}
if (sum(n_arr >= 1)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
if (sum(n_arr >= 1)) {
} else if (sum(n_arr >= 1)) {

Comment on lines 55 to 56
if (sum(n_arr < 0)) {
return ops_partials.build(NEGATIVE_INFTY);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could just write an any() function that takes in a scalar or vector and returns true or false. Think that would just be easier to read imo

@andrjohns
Copy link
Collaborator Author

@SteveBronder I'll add an any() function, since I agree that would be super handy. Should I open a separate PR for the select() and any() functions and tests, or are you happy for me to include them in this one?

@andrjohns
Copy link
Collaborator Author

I'll update this PR once the helper functions from #2852 have been added and merged

@andrjohns
Copy link
Collaborator Author

@SteveBronder would you mind having another look at this when you get a minute? No rush

Copy link
Collaborator

@SteveBronder SteveBronder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@andrjohns andrjohns merged commit 9f2689e into stan-dev:develop Sep 19, 2023
7 checks passed
@andrjohns andrjohns deleted the issue-2783-bernoulli-cdf-stable branch September 19, 2023 13:23
@WardBrian
Copy link
Member

develop tests have been failing in the distribution tests since this was merged (https://jenkins.flatironinstitute.org/blue/organizations/jenkins/Stan%2FMath/detail/develop/194/pipeline/500)

I'm guessing this is due to -DSTAN_TEST_ROW_VECTORS which we don't use in CI for PRs

In file included from test/prob/bernoulli/bernoulli_cdf_00000_generated_ffv_test.cpp:3:
In file included from ./test/prob/test_fixture_distr.hpp:4:
In file included from ./stan/math/mix.hpp:4:
In file included from ./stan/math/mix/meta.hpp:6:
In file included from ./stan/math/fwd/core.hpp:4:
In file included from ./stan/math/fwd/core/fvar.hpp:4:
In file included from ./stan/math/prim/meta.hpp:72:
In file included from ./stan/math/prim/meta/append_return_type.hpp:4:
In file included from ./stan/math/prim/fun/Eigen.hpp:22:
In file included from lib/eigen_3.4.0/Eigen/Dense:1:
In file included from lib/eigen_3.4.0/unsupported/Eigen/../../Eigen/Core:295:
lib/eigen_3.4.0/Eigen/src/Core/PlainObjectBase.h:970:7: error: static_assert failed "INVALID_MATRIX_TEMPLATE_PARAMETERS"
      EIGEN_STATIC_ASSERT((EIGEN_IMPLIES(MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1, (int(Options)&RowMajor)==RowMajor)
      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
lib/eigen_3.4.0/Eigen/src/Core/util/StaticAssert.h:33:40: note: expanded from macro 'EIGEN_STATIC_ASSERT'
    #define EIGEN_STATIC_ASSERT(X,MSG) static_assert(X,#MSG);
                                       ^             ~
lib/eigen_3.4.0/Eigen/src/Core/Map.h:159:30: note: in instantiation of member function 'Eigen::PlainObjectBase<Eigen::Array<stan::math::var_value<double, void>, 1, -1, 0, 1, -1> >::_check_template_params' requested here
      PlainObjectType::Base::_check_template_params();
                             ^
./stan/math/rev/core/arena_matrix.hpp:62:9: note: in instantiation of member function 'Eigen::Map<Eigen::Array<stan::math::var_value<double, void>, 1, -1, 0, 1, -1>, 0, Eigen::Stride<0, 0> >::Map' requested here
      : Base::Map(

@andrjohns are you able to take a look at this soon? Otherwise I think we may need to revert this to unblock our CI

@andrjohns
Copy link
Collaborator Author

Ah damn, yeah I'll have a look now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants