diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index 8c2c66ac226..e7dedc1c315 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -307,7 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square, opencl_kernels::inv_square_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit, opencl_kernels::inv_logit_device_function) -ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function) +ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function, + opencl_kernels::logit_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function) ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx, opencl_kernels::inv_logit_device_function, diff --git a/stan/math/opencl/kernels/device_functions/Phi.hpp b/stan/math/opencl/kernels/device_functions/Phi.hpp index ba7ecb6a51f..6946fc6fd28 100644 --- a/stan/math/opencl/kernels/device_functions/Phi.hpp +++ b/stan/math/opencl/kernels/device_functions/Phi.hpp @@ -24,11 +24,11 @@ static const char* phi_device_function if (x < -37.5) { return 0; } else if (x < -5.0) { - return 0.5 * erfc(-1.0 / sqrt(2.0) * x); + return 0.5 * erfc(-M_SQRT1_2 * x); } else if (x > 8.25) { return 1; } else { - return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x)); + return 0.5 * (1.0 + erf(M_SQRT1_2 * x)); } } // \cond diff --git a/stan/math/opencl/kernels/device_functions/inv_logit.hpp b/stan/math/opencl/kernels/device_functions/inv_logit.hpp index 9cb9d56cc27..34c526e4fae 100644 --- a/stan/math/opencl/kernels/device_functions/inv_logit.hpp +++ b/stan/math/opencl/kernels/device_functions/inv_logit.hpp @@ -56,7 +56,7 @@ static const char* inv_logit_device_function */ double inv_logit(double x) { if (x < 0) { - if (x < log(2.2204460492503131E-16)) { + if (x < log(DBL_EPSILON)) { return exp(x); } return exp(x) / (1 + exp(x)); diff --git a/stan/math/opencl/kernels/device_functions/lbeta.hpp b/stan/math/opencl/kernels/device_functions/lbeta.hpp index f0aeb8061fb..0e8e4d139d3 100644 --- a/stan/math/opencl/kernels/device_functions/lbeta.hpp +++ b/stan/math/opencl/kernels/device_functions/lbeta.hpp @@ -95,12 +95,13 @@ static const char* lbeta_device_function return lgamma(x) + lgamma(y) - lgamma(x + y); } double x_over_xy = x / (x + y); + double log_xpy = log(x + y); if (x < LGAMMA_STIRLING_DIFF_USEFUL) { // y large, x small double stirling_diff = lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y); double stirling - = (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y)); + = (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy); return stirling + lgamma(x) + stirling_diff; } @@ -108,8 +109,9 @@ static const char* lbeta_device_function double stirling_diff = lgamma_stirling_diff(x) + lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y); - double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy) - + 0.5 * log(2.0 * M_PI) - 0.5 * log(y); + double stirling = (x - 0.5) * (log(x) - log_xpy) + + y * log1p(-x_over_xy) + + 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y); return stirling + stirling_diff; } // \cond diff --git a/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp b/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp index 19fc87b968f..3934fe11600 100644 --- a/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp +++ b/stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp @@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function * @return Stirling's approximation to lgamma(x). */ double lgamma_stirling(double x) { - return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x; + return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x; } // \cond ) "\n#endif\n"; // NOLINT diff --git a/stan/math/opencl/kernels/device_functions/logit.hpp b/stan/math/opencl/kernels/device_functions/logit.hpp index 8ec1bf8850b..b0b4524214e 100644 --- a/stan/math/opencl/kernels/device_functions/logit.hpp +++ b/stan/math/opencl/kernels/device_functions/logit.hpp @@ -49,7 +49,7 @@ static const char* logit_device_function * @param x argument * @return log odds of argument */ - double logit(double x) { return log(x / (1 - x)); } + double logit(double x) { return log(x) - log1m(x); } // \cond ) "\n#endif\n"; // NOLINT // \endcond diff --git a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp index f79b4449f69..0585df47654 100644 --- a/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp @@ -4,6 +4,7 @@ #include #include +#include namespace stan { namespace math { @@ -92,9 +93,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY( double log_phi = log(phi); double logsumexp_theta_logphi; if (theta > log_phi) { - logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta)); + logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta); } else { - logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi)); + logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi); } double y_plus_phi = y + phi; if (need_logp1) { @@ -196,7 +197,7 @@ const kernel_cl neg_binomial_2_log_glm("neg_binomial_2_log_glm", - {digamma_device_function, + {digamma_device_function, log1p_exp_device_function, neg_binomial_2_log_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); diff --git a/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp index 74341ff76cc..3b1727b1aa9 100644 --- a/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace stan { namespace math { @@ -87,20 +88,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY( if (need_location_derivative || need_cuts_derivative) { double exp_cuts_diff = exp(cut_y2 - cut_y1); - if (cut2 > 0) { - double exp_m_cut2 = exp(-cut2); - d1 = exp_m_cut2 / (1 + exp_m_cut2); - } else { - d1 = 1 / (1 + exp(cut2)); - } + d1 = inv_logit(-cut2); d1 -= exp_cuts_diff / (exp_cuts_diff - 1); d2 = 1 / (1 - exp_cuts_diff); - if (cut1 > 0) { - double exp_m_cut1 = exp(-cut1); - d2 -= exp_m_cut1 / (1 + exp_m_cut1); - } else { - d2 -= 1 / (1 + exp(cut1)); - } + d2 -= inv_logit(-cut1); if (need_location_derivative) { location_derivative[gid] = d1 - d2; @@ -181,6 +172,7 @@ const kernel_cl ordered_logistic_glm("ordered_logistic_glm", {log1p_exp_device_function, log1m_exp_device_function, + inv_logit_device_function, ordered_logistic_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); diff --git a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp index b06a6d47a46..7b5dbddc169 100644 --- a/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp +++ b/stan/math/opencl/kernels/ordered_logistic_lpmf.hpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace stan { namespace math { @@ -83,20 +84,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY( if (need_lambda_derivative || need_cuts_derivative) { double exp_cuts_diff = exp(cut_y2 - cut_y1); - if (cut2 > 0) { - double exp_m_cut2 = exp(-cut2); - d1 = exp_m_cut2 / (1 + exp_m_cut2); - } else { - d1 = 1 / (1 + exp(cut2)); - } + d1 = inv_logit(-cut2); d1 -= exp_cuts_diff / (exp_cuts_diff - 1); d2 = 1 / (1 - exp_cuts_diff); - if (cut1 > 0) { - double exp_m_cut1 = exp(-cut1); - d2 -= exp_m_cut1 / (1 + exp_m_cut1); - } else { - d2 -= 1 / (1 + exp(cut1)); - } + d2 -= inv_logit(-cut1); if (need_lambda_derivative) { lambda_derivative[gid] = d1 - d2; @@ -175,7 +166,7 @@ const kernel_cl ordered_logistic("ordered_logistic", {log1p_exp_device_function, log1m_exp_device_function, - ordered_logistic_kernel_code}, + inv_logit_device_function, ordered_logistic_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); } // namespace opencl_kernels diff --git a/stan/math/opencl/kernels/tridiagonalization.hpp b/stan/math/opencl/kernels/tridiagonalization.hpp index b0bf3d3288a..0e0c11ca5ed 100644 --- a/stan/math/opencl/kernels/tridiagonalization.hpp +++ b/stan/math/opencl/kernels/tridiagonalization.hpp @@ -84,7 +84,7 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY( q = q_local[0]; alpha = q_local[1]; if (q != 0) { - double multi = sqrt(2.) / q; + double multi = M_SQRT2 / q; // normalize the Householder vector for (int i = lid + 1; i < P_span; i += lsize) { P[P_start + i] *= multi; @@ -92,7 +92,7 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY( } if (gid == 0) { P[P_rows * (k + j + 1) + k + j] - = P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha; + = P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha; } } // \cond @@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY( v[i] -= acc * u[i]; } if (gid == 0) { - P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0]; + P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0]; } } // \cond