From 7825fdf23d5157df30b82928e475f9f23523e331 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 28 Aug 2023 07:23:42 +0300 Subject: [PATCH 1/9] Cleanup numerical stability, constants, compound functions in kernels --- stan/math/opencl/kernels/device_functions/Phi.hpp | 4 ++-- .../math/opencl/kernels/device_functions/lbeta.hpp | 7 ++++--- .../kernels/device_functions/lgamma_stirling.hpp | 2 +- .../math/opencl/kernels/device_functions/logit.hpp | 2 +- .../opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp | 4 ++-- .../opencl/kernels/ordered_logistic_glm_lpmf.hpp | 14 ++------------ stan/math/opencl/kernels/ordered_logistic_lpmf.hpp | 14 ++------------ stan/math/opencl/kernels/tridiagonalization.hpp | 6 +++--- 8 files changed, 17 insertions(+), 36 deletions(-) diff --git a/stan/math/opencl/kernels/device_functions/Phi.hpp b/stan/math/opencl/kernels/device_functions/Phi.hpp index ba7ecb6a51f..6946fc6fd28 100644 --- a/stan/math/opencl/kernels/device_functions/Phi.hpp +++ b/stan/math/opencl/kernels/device_functions/Phi.hpp @@ -24,11 +24,11 @@ static const char* phi_device_function if (x < -37.5) { return 0; } else if (x < -5.0) { - return 0.5 * erfc(-1.0 / sqrt(2.0) * x); + return 0.5 * erfc(-M_SQRT1_2 * x); } else if (x > 8.25) { return 1; } else { - return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x)); + return 0.5 * (1.0 + erf(M_SQRT1_2 * x)); } } // \cond diff --git a/stan/math/opencl/kernels/device_functions/lbeta.hpp b/stan/math/opencl/kernels/device_functions/lbeta.hpp index f0aeb8061fb..e5c377672d4 100644 --- a/stan/math/opencl/kernels/device_functions/lbeta.hpp +++ b/stan/math/opencl/kernels/device_functions/lbeta.hpp @@ -95,12 +95,13 @@ static const char* lbeta_device_function return lgamma(x) + lgamma(y) - lgamma(x + y); } double x_over_xy = x / (x + y); + double log_xpy = log(x + y); if (x < LGAMMA_STIRLING_DIFF_USEFUL) { // y large, x small double stirling_diff = lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y); double stirling - = (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y)); + = (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy); return stirling + lgamma(x) + stirling_diff; } @@ -108,8 +109,8 @@ static const char* lbeta_device_function double stirling_diff = lgamma_stirling_diff(x) + lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y); - double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy) - + 0.5 * log(2.0 * M_PI) - 0.5 * log(y); + double stirling = (x - 0.5) * (log(x) - log_xpy) + y * log1p(-x_over_xy) + + 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y); return stirling + stirling_diff; } // \cond diff --git a/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp b/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp index 19fc87b968f..3934fe11600 100644 --- a/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp +++ b/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp @@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function * @return Stirling's approximation to lgamma(x). */ double lgamma_stirling(double x) { - return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x; + return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x; } // \cond ) "\n#endif\n"; // NOLINT diff --git a/stan/math/opencl/kernels/device_functions/logit.hpp b/stan/math/opencl/kernels/device_functions/logit.hpp index 8ec1bf8850b..b0b4524214e 100644 --- a/stan/math/opencl/kernels/device_functions/logit.hpp +++ b/stan/math/opencl/kernels/device_functions/logit.hpp @@ -49,7 +49,7 @@ static const char* logit_device_function * @param x argument * @return log odds of argument */ - double logit(double x) { return log(x / (1 - x)); } + double logit(double x) { return log(x) - log1m(x); } // \cond ) "\n#endif\n"; // NOLINT // \endcond diff --git a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp index f79b4449f69..b52284c0faa 100644 --- a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp @@ -92,9 +92,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY( double log_phi = log(phi); double logsumexp_theta_logphi; if (theta > log_phi) { - logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta)); + logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta); } else { - logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi)); + logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi); } double y_plus_phi = y + phi; if (need_logp1) { diff --git a/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp index 74341ff76cc..2f07e2dae93 100644 --- a/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp @@ -87,20 +87,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY( if (need_location_derivative || need_cuts_derivative) { double exp_cuts_diff = exp(cut_y2 - cut_y1); - if (cut2 > 0) { - double exp_m_cut2 = exp(-cut2); - d1 = exp_m_cut2 / (1 + exp_m_cut2); - } else { - d1 = 1 / (1 + exp(cut2)); - } + d1 = inv_logit(-cut2); d1 -= exp_cuts_diff / (exp_cuts_diff - 1); d2 = 1 / (1 - exp_cuts_diff); - if (cut1 > 0) { - double exp_m_cut1 = exp(-cut1); - d2 -= exp_m_cut1 / (1 + exp_m_cut1); - } else { - d2 -= 1 / (1 + exp(cut1)); - } + d2 -= inv_logit(-cut1); if (need_location_derivative) { location_derivative[gid] = d1 - d2; diff --git a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp index b06a6d47a46..89dda0d287b 100644 --- a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp @@ -83,20 +83,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY( if (need_lambda_derivative || need_cuts_derivative) { double exp_cuts_diff = exp(cut_y2 - cut_y1); - if (cut2 > 0) { - double exp_m_cut2 = exp(-cut2); - d1 = exp_m_cut2 / (1 + exp_m_cut2); - } else { - d1 = 1 / (1 + exp(cut2)); - } + d1 = inv_logit(-cut2); d1 -= exp_cuts_diff / (exp_cuts_diff - 1); d2 = 1 / (1 - exp_cuts_diff); - if (cut1 > 0) { - double exp_m_cut1 = exp(-cut1); - d2 -= exp_m_cut1 / (1 + exp_m_cut1); - } else { - d2 -= 1 / (1 + exp(cut1)); - } + d2 -= inv_logit(-cut1); if (need_lambda_derivative) { lambda_derivative[gid] = d1 - d2; diff --git a/stan/math/opencl/kernels/tridiagonalization.hpp b/stan/math/opencl/kernels/tridiagonalization.hpp index b0bf3d3288a..0e0c11ca5ed 100644 --- a/stan/math/opencl/kernels/tridiagonalization.hpp +++ b/stan/math/opencl/kernels/tridiagonalization.hpp @@ -84,7 +84,7 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY( q = q_local[0]; alpha = q_local[1]; if (q != 0) { - double multi = sqrt(2.) / q; + double multi = M_SQRT2 / q; // normalize the Householder vector for (int i = lid + 1; i < P_span; i += lsize) { P[P_start + i] *= multi; @@ -92,7 +92,7 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY( } if (gid == 0) { P[P_rows * (k + j + 1) + k + j] - = P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha; + = P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha; } } // \cond @@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY( v[i] -= acc * u[i]; } if (gid == 0) { - P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0]; + P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0]; } } // \cond From 71653a3141caf154238de6c4959ac9db64719dea Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 28 Aug 2023 00:27:40 -0400 Subject: [PATCH 2/9] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/opencl/kernels/device_functions/lbeta.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stan/math/opencl/kernels/device_functions/lbeta.hpp b/stan/math/opencl/kernels/device_functions/lbeta.hpp index e5c377672d4..0e8e4d139d3 100644 --- a/stan/math/opencl/kernels/device_functions/lbeta.hpp +++ b/stan/math/opencl/kernels/device_functions/lbeta.hpp @@ -109,7 +109,8 @@ static const char* lbeta_device_function double stirling_diff = lgamma_stirling_diff(x) + lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y); - double stirling = (x - 0.5) * (log(x) - log_xpy) + y * log1p(-x_over_xy) + double stirling = (x - 0.5) * (log(x) - log_xpy) + + y * log1p(-x_over_xy) + 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y); return stirling + stirling_diff; } From 7e4c5912551ed0586b337523985bd2b767844706 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 28 Aug 2023 10:11:12 +0300 Subject: [PATCH 3/9] Fix kernel inclusion --- stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp | 2 ++ stan/math/opencl/kernels/ordered_logistic_lpmf.hpp | 2 ++ 2 files changed, 4 insertions(+) diff --git a/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp index 2f07e2dae93..3b1727b1aa9 100644 --- a/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace stan { namespace math { @@ -171,6 +172,7 @@ const kernel_cl ordered_logistic_glm("ordered_logistic_glm", {log1p_exp_device_function, log1m_exp_device_function, + inv_logit_device_function, ordered_logistic_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); diff --git a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp index 89dda0d287b..a77f9e64688 100644 --- a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace stan { namespace math { @@ -165,6 +166,7 @@ const kernel_cl ordered_logistic("ordered_logistic", {log1p_exp_device_function, log1m_exp_device_function, + inv_logit_device_function, ordered_logistic_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); From febea5f092078187eea18208f68c6d788d102b3e Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 28 Aug 2023 03:12:38 -0400 Subject: [PATCH 4/9] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/opencl/kernels/ordered_logistic_lpmf.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp index a77f9e64688..7b5dbddc169 100644 --- a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp @@ -166,8 +166,7 @@ const kernel_cl ordered_logistic("ordered_logistic", {log1p_exp_device_function, log1m_exp_device_function, - inv_logit_device_function, - ordered_logistic_kernel_code}, + inv_logit_device_function, ordered_logistic_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); } // namespace opencl_kernels From 0a2a902c9d4081b11e98ad34b9175e14dd2c60a1 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 28 Aug 2023 11:44:24 +0300 Subject: [PATCH 5/9] Missing kernel inclide --- stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp index b52284c0faa..a87f20b4d47 100644 --- a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp @@ -4,6 +4,7 @@ #include #include +#include namespace stan { namespace math { @@ -197,6 +198,7 @@ const kernel_cl neg_binomial_2_log_glm("neg_binomial_2_log_glm", {digamma_device_function, + log1p_exp_device_function, neg_binomial_2_log_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); From 9c219fe20fedcb202857040eb5f1c49e04d8ac57 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 28 Aug 2023 04:45:32 -0400 Subject: [PATCH 6/9] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp index a87f20b4d47..0585df47654 100644 --- a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp @@ -197,8 +197,7 @@ const kernel_cl neg_binomial_2_log_glm("neg_binomial_2_log_glm", - {digamma_device_function, - log1p_exp_device_function, + {digamma_device_function, log1p_exp_device_function, neg_binomial_2_log_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); From 97c3d7f15007d6df1ba47c8c5218f7f1ddc2d8a9 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 28 Aug 2023 14:14:47 +0300 Subject: [PATCH 7/9] logit include --- stan/math/opencl/kernel_generator/elt_function_cl.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index 8c2c66ac226..87388162f18 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -307,7 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square, opencl_kernels::inv_square_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit, opencl_kernels::inv_logit_device_function) -ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function) +ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function, + opencl_kernels::log1m_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx, opencl_kernels::inv_logit_device_function, From 0aae3775eff72438335b6d86b39b4d81e6f4cc97 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 28 Aug 2023 20:10:12 +0300 Subject: [PATCH 8/9] Fix test failures --- stan/math/opencl/kernel_generator/elt_function_cl.hpp | 4 ++-- stan/math/opencl/kernels/device_functions/inv_logit.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index 87388162f18..a2851709e98 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -307,8 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square, opencl_kernels::inv_square_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit, opencl_kernels::inv_logit_device_function) -ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function, - opencl_kernels::log1m_device_function) +ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function, + opencl_kernels::logit_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx, opencl_kernels::inv_logit_device_function, diff --git a/stan/math/opencl/kernels/device_functions/inv_logit.hpp b/stan/math/opencl/kernels/device_functions/inv_logit.hpp index 9cb9d56cc27..34c526e4fae 100644 --- a/stan/math/opencl/kernels/device_functions/inv_logit.hpp +++ b/stan/math/opencl/kernels/device_functions/inv_logit.hpp @@ -56,7 +56,7 @@ static const char* inv_logit_device_function */ double inv_logit(double x) { if (x < 0) { - if (x < log(2.2204460492503131E-16)) { + if (x < log(DBL_EPSILON)) { return exp(x); } return exp(x) / (1 + exp(x)); From 8c010a0c345728706de1f24fc9203792d8613f5a Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 28 Aug 2023 13:11:15 -0400 Subject: [PATCH 9/9] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/opencl/kernel_generator/elt_function_cl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index a2851709e98..e7dedc1c315 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -308,7 +308,7 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square, ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit, opencl_kernels::inv_logit_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function, - opencl_kernels::logit_device_function) + opencl_kernels::logit_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx, opencl_kernels::inv_logit_device_function,