Skip to content

Commit

Permalink
Merge pull request #2934 from stan-dev/opencl-kernel-cleanups
Browse files Browse the repository at this point in the history
Minor cleanup of numerical stability, constants, and compound functions in OpenCL kernels
  • Loading branch information
andrjohns authored Sep 18, 2023
2 parents 4cf25de + 8c010a0 commit c001c71
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 40 deletions.
3 changes: 2 additions & 1 deletion stan/math/opencl/kernel_generator/elt_function_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square,
opencl_kernels::inv_square_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit,
opencl_kernels::inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function,
opencl_kernels::logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx,
opencl_kernels::inv_logit_device_function,
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ static const char* phi_device_function
if (x < -37.5) {
return 0;
} else if (x < -5.0) {
return 0.5 * erfc(-1.0 / sqrt(2.0) * x);
return 0.5 * erfc(-M_SQRT1_2 * x);
} else if (x > 8.25) {
return 1;
} else {
return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x));
return 0.5 * (1.0 + erf(M_SQRT1_2 * x));
}
}
// \cond
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static const char* inv_logit_device_function
*/
double inv_logit(double x) {
if (x < 0) {
if (x < log(2.2204460492503131E-16)) {
if (x < log(DBL_EPSILON)) {
return exp(x);
}
return exp(x) / (1 + exp(x));
Expand Down
8 changes: 5 additions & 3 deletions stan/math/opencl/kernels/device_functions/lbeta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,23 @@ static const char* lbeta_device_function
return lgamma(x) + lgamma(y) - lgamma(x + y);
}
double x_over_xy = x / (x + y);
double log_xpy = log(x + y);
if (x < LGAMMA_STIRLING_DIFF_USEFUL) {
// y large, x small
double stirling_diff
= lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y);
double stirling
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y));
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy);
return stirling + lgamma(x) + stirling_diff;
}

// both large
double stirling_diff = lgamma_stirling_diff(x)
+ lgamma_stirling_diff(y)
- lgamma_stirling_diff(x + y);
double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy)
+ 0.5 * log(2.0 * M_PI) - 0.5 * log(y);
double stirling = (x - 0.5) * (log(x) - log_xpy)
+ y * log1p(-x_over_xy)
+ 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y);
return stirling + stirling_diff;
}
// \cond
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function
* @return Stirling's approximation to lgamma(x).
*/
double lgamma_stirling(double x) {
return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x;
return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x;
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static const char* logit_device_function
* @param x argument
* @return log odds of argument
*/
double logit(double x) { return log(x / (1 - x)); }
double logit(double x) { return log(x) - log1m(x); }
// \cond
) "\n#endif\n"; // NOLINT
// \endcond
Expand Down
7 changes: 4 additions & 3 deletions stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/digamma.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -92,9 +93,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
double log_phi = log(phi);
double logsumexp_theta_logphi;
if (theta > log_phi) {
logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta));
logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta);
} else {
logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi));
logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi);
}
double y_plus_phi = y + phi;
if (need_logp1) {
Expand Down Expand Up @@ -196,7 +197,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
int, int, int, int, int, int, int, int, int>
neg_binomial_2_log_glm("neg_binomial_2_log_glm",
{digamma_device_function,
{digamma_device_function, log1p_exp_device_function,
neg_binomial_2_log_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

Expand Down
16 changes: 4 additions & 12 deletions stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -87,20 +88,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY(

if (need_location_derivative || need_cuts_derivative) {
double exp_cuts_diff = exp(cut_y2 - cut_y1);
if (cut2 > 0) {
double exp_m_cut2 = exp(-cut2);
d1 = exp_m_cut2 / (1 + exp_m_cut2);
} else {
d1 = 1 / (1 + exp(cut2));
}
d1 = inv_logit(-cut2);
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
d2 = 1 / (1 - exp_cuts_diff);
if (cut1 > 0) {
double exp_m_cut1 = exp(-cut1);
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
} else {
d2 -= 1 / (1 + exp(cut1));
}
d2 -= inv_logit(-cut1);

if (need_location_derivative) {
location_derivative[gid] = d1 - d2;
Expand Down Expand Up @@ -181,6 +172,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
ordered_logistic_glm("ordered_logistic_glm",
{log1p_exp_device_function, log1m_exp_device_function,
inv_logit_device_function,
ordered_logistic_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

Expand Down
17 changes: 4 additions & 13 deletions stan/math/opencl/kernels/ordered_logistic_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -83,20 +84,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY(

if (need_lambda_derivative || need_cuts_derivative) {
double exp_cuts_diff = exp(cut_y2 - cut_y1);
if (cut2 > 0) {
double exp_m_cut2 = exp(-cut2);
d1 = exp_m_cut2 / (1 + exp_m_cut2);
} else {
d1 = 1 / (1 + exp(cut2));
}
d1 = inv_logit(-cut2);
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
d2 = 1 / (1 - exp_cuts_diff);
if (cut1 > 0) {
double exp_m_cut1 = exp(-cut1);
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
} else {
d2 -= 1 / (1 + exp(cut1));
}
d2 -= inv_logit(-cut1);

if (need_lambda_derivative) {
lambda_derivative[gid] = d1 - d2;
Expand Down Expand Up @@ -175,7 +166,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer,
in_buffer, int, int, int, int, int, int>
ordered_logistic("ordered_logistic",
{log1p_exp_device_function, log1m_exp_device_function,
ordered_logistic_kernel_code},
inv_logit_device_function, ordered_logistic_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

} // namespace opencl_kernels
Expand Down
6 changes: 3 additions & 3 deletions stan/math/opencl/kernels/tridiagonalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY(
q = q_local[0];
alpha = q_local[1];
if (q != 0) {
double multi = sqrt(2.) / q;
double multi = M_SQRT2 / q;
// normalize the Householder vector
for (int i = lid + 1; i < P_span; i += lsize) {
P[P_start + i] *= multi;
}
}
if (gid == 0) {
P[P_rows * (k + j + 1) + k + j]
= P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha;
= P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
}
}
// \cond
Expand Down Expand Up @@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY(
v[i] -= acc * u[i];
}
if (gid == 0) {
P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0];
P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
}
}
// \cond
Expand Down

0 comments on commit c001c71

Please sign in to comment.