Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor cleanup of numerical stability, constants, compound functions in OpenCL kernels #2934

Merged
merged 9 commits into from
Sep 18, 2023
3 changes: 2 additions & 1 deletion stan/math/opencl/kernel_generator/elt_function_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square,
opencl_kernels::inv_square_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit,
opencl_kernels::inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function,
opencl_kernels::logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx,
opencl_kernels::inv_logit_device_function,
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ static const char* phi_device_function
if (x < -37.5) {
return 0;
} else if (x < -5.0) {
return 0.5 * erfc(-1.0 / sqrt(2.0) * x);
return 0.5 * erfc(-M_SQRT1_2 * x);
} else if (x > 8.25) {
return 1;
} else {
return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x));
return 0.5 * (1.0 + erf(M_SQRT1_2 * x));
}
}
// \cond
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static const char* inv_logit_device_function
*/
double inv_logit(double x) {
if (x < 0) {
if (x < log(2.2204460492503131E-16)) {
if (x < log(DBL_EPSILON)) {
return exp(x);
}
return exp(x) / (1 + exp(x));
Expand Down
8 changes: 5 additions & 3 deletions stan/math/opencl/kernels/device_functions/lbeta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,23 @@ static const char* lbeta_device_function
return lgamma(x) + lgamma(y) - lgamma(x + y);
}
double x_over_xy = x / (x + y);
double log_xpy = log(x + y);
if (x < LGAMMA_STIRLING_DIFF_USEFUL) {
// y large, x small
double stirling_diff
= lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y);
double stirling
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y));
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy);
return stirling + lgamma(x) + stirling_diff;
}

// both large
double stirling_diff = lgamma_stirling_diff(x)
+ lgamma_stirling_diff(y)
- lgamma_stirling_diff(x + y);
double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy)
+ 0.5 * log(2.0 * M_PI) - 0.5 * log(y);
double stirling = (x - 0.5) * (log(x) - log_xpy)
+ y * log1p(-x_over_xy)
+ 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y);
[Review thread on the line above: andrjohns marked this conversation as resolved.]
return stirling + stirling_diff;
}
// \cond
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function
* @return Stirling's approximation to lgamma(x).
*/
double lgamma_stirling(double x) {
return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x;
return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x;
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static const char* logit_device_function
* @param x argument
* @return log odds of argument
*/
double logit(double x) { return log(x / (1 - x)); }
double logit(double x) { return log(x) - log1m(x); }
// \cond
) "\n#endif\n"; // NOLINT
// \endcond
Expand Down
7 changes: 4 additions & 3 deletions stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/digamma.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -92,9 +93,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
double log_phi = log(phi);
double logsumexp_theta_logphi;
if (theta > log_phi) {
logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta));
logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta);
} else {
logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi));
logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi);
}
double y_plus_phi = y + phi;
if (need_logp1) {
Expand Down Expand Up @@ -196,7 +197,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
int, int, int, int, int, int, int, int, int>
neg_binomial_2_log_glm("neg_binomial_2_log_glm",
{digamma_device_function,
{digamma_device_function, log1p_exp_device_function,
neg_binomial_2_log_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

Expand Down
16 changes: 4 additions & 12 deletions stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -87,20 +88,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY(

if (need_location_derivative || need_cuts_derivative) {
double exp_cuts_diff = exp(cut_y2 - cut_y1);
if (cut2 > 0) {
double exp_m_cut2 = exp(-cut2);
d1 = exp_m_cut2 / (1 + exp_m_cut2);
} else {
d1 = 1 / (1 + exp(cut2));
}
d1 = inv_logit(-cut2);
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
d2 = 1 / (1 - exp_cuts_diff);
if (cut1 > 0) {
double exp_m_cut1 = exp(-cut1);
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
} else {
d2 -= 1 / (1 + exp(cut1));
}
d2 -= inv_logit(-cut1);

if (need_location_derivative) {
location_derivative[gid] = d1 - d2;
Expand Down Expand Up @@ -181,6 +172,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
ordered_logistic_glm("ordered_logistic_glm",
{log1p_exp_device_function, log1m_exp_device_function,
inv_logit_device_function,
ordered_logistic_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

Expand Down
17 changes: 4 additions & 13 deletions stan/math/opencl/kernels/ordered_logistic_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -83,20 +84,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY(

if (need_lambda_derivative || need_cuts_derivative) {
double exp_cuts_diff = exp(cut_y2 - cut_y1);
if (cut2 > 0) {
double exp_m_cut2 = exp(-cut2);
d1 = exp_m_cut2 / (1 + exp_m_cut2);
} else {
d1 = 1 / (1 + exp(cut2));
}
d1 = inv_logit(-cut2);
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
d2 = 1 / (1 - exp_cuts_diff);
if (cut1 > 0) {
double exp_m_cut1 = exp(-cut1);
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
} else {
d2 -= 1 / (1 + exp(cut1));
}
d2 -= inv_logit(-cut1);

if (need_lambda_derivative) {
lambda_derivative[gid] = d1 - d2;
Expand Down Expand Up @@ -175,7 +166,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer,
in_buffer, int, int, int, int, int, int>
ordered_logistic("ordered_logistic",
{log1p_exp_device_function, log1m_exp_device_function,
ordered_logistic_kernel_code},
inv_logit_device_function, ordered_logistic_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

} // namespace opencl_kernels
Expand Down
6 changes: 3 additions & 3 deletions stan/math/opencl/kernels/tridiagonalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY(
q = q_local[0];
alpha = q_local[1];
if (q != 0) {
double multi = sqrt(2.) / q;
double multi = M_SQRT2 / q;
// normalize the Householder vector
for (int i = lid + 1; i < P_span; i += lsize) {
P[P_start + i] *= multi;
}
}
if (gid == 0) {
P[P_rows * (k + j + 1) + k + j]
= P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha;
= P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
}
}
// \cond
Expand Down Expand Up @@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY(
v[i] -= acc * u[i];
}
if (gid == 0) {
P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0];
P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
}
}
// \cond
Expand Down