Skip to content

Commit

Permalink
Adding scalar_of helper
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Nov 22, 2023
1 parent 19c2885 commit 9924fe0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 30 deletions.
19 changes: 17 additions & 2 deletions include/math_approx/src/basic_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@

namespace math_approx
{
template <typename T>
struct scalar_of
{
using type = T;
};
template <typename T>
using scalar_of_t = typename scalar_of<T>::type;

template <typename T>
T rsqrt (T x)
{
Expand All @@ -29,14 +37,21 @@ T rsqrt (T x)
}

#if defined(XSIMD_HPP)
template <typename T>
struct scalar_of<xsimd::batch<T>>
{
using type = T;
};

template <typename T>
xsimd::batch<T> rsqrt (xsimd::batch<T> x)
{
using S = scalar_of_t<T>;
auto r = xsimd::rsqrt (x);
x *= r;
x *= r;
x += -3.0f;
r *= -0.5f;
x += (S) -3;
r *= (S) -0.5;
return x * r;
}
#endif
Expand Down
28 changes: 16 additions & 12 deletions include/math_approx/src/sigmoid_approx.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include "basic_math.hpp"
// #include "pow_approx.hpp"

namespace math_approx
{
Expand All @@ -12,38 +11,42 @@ namespace sigmoid_detail
template <typename T>
T sig_poly_9 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_7_9 = (T) 1.50024356624e-6 + (T) 6.92468584642e-9 * x_sq;
const auto y_5_7_9 = (T) 0.000260923534301 + y_7_9 * x_sq;
const auto y_3_5_7_9 = (T) 0.0208320229264 + y_5_7_9 * x_sq;
const auto y_1_3_5_7_9 = (T) 0.5 + y_3_5_7_9 * x_sq;
const auto y_7_9 = (S) 1.50024356624e-6 + (S) 6.92468584642e-9 * x_sq;
const auto y_5_7_9 = (S) 0.000260923534301 + y_7_9 * x_sq;
const auto y_3_5_7_9 = (S) 0.0208320229264 + y_5_7_9 * x_sq;
const auto y_1_3_5_7_9 = (S) 0.5 + y_3_5_7_9 * x_sq;
return x * y_1_3_5_7_9;
}

template <typename T>
T sig_poly_7 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_5_7 = (T) 0.000255174491559 + (T) 1.90805380557e-6 * x_sq;
const auto y_3_5_7 = (T) 0.0208503675870 + y_5_7 * x_sq;
const auto y_1_3_5_7 = (T) 0.5 + y_3_5_7 * x_sq;
const auto y_5_7 = (S) 0.000255174491559 + (S) 1.90805380557e-6 * x_sq;
const auto y_3_5_7 = (S) 0.0208503675870 + y_5_7 * x_sq;
const auto y_1_3_5_7 = (S) 0.5 + y_3_5_7 * x_sq;
return x * y_1_3_5_7;
}

template <typename T>
T sig_poly_5 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_3_5 = (T) 0.0206108521251 + (T) 0.000307906311109 * x_sq;
const auto y_1_3_5 = (T) 0.5 + y_3_5 * x_sq;
const auto y_3_5 = (S) 0.0206108521251 + (S) 0.000307906311109 * x_sq;
const auto y_1_3_5 = (S) 0.5 + y_3_5 * x_sq;
return x * y_1_3_5;
}

template <typename T>
T sig_poly_3 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_1_3 = (T) 0.5 + (T) 0.0233402955195 * x_sq;
const auto y_1_3 = (S) 0.5 + (S) 0.0233402955195 * x_sq;
return x * y_1_3;
}
} // namespace sigmoid_detail
Expand All @@ -63,7 +66,8 @@ T sigmoid (T x)
else if constexpr (order == 3)
x_poly = sigmoid_detail::sig_poly_3 (x);

return (T) 0.5 * x_poly * rsqrt (x_poly * x_poly + (T) 1) + (T) 0.5;
using S = scalar_of_t<T>;
return (S) 0.5 * x_poly * rsqrt (x_poly * x_poly + (S) 1) + (S) 0.5;
}

// So far this has tested slower than the above approx (for equivalent error),
Expand Down
38 changes: 22 additions & 16 deletions include/math_approx/src/tanh_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,50 +11,55 @@ namespace tanh_detail
template <typename T>
T tanh_poly_11 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_9_11 = (T) 2.63661358122e-6 + (T) 3.33765558362e-8 * x_sq;
const auto y_7_9_11 = (T) 0.000199027336899 + y_9_11 * x_sq;
const auto y_5_7_9_11 = (T) 0.00833223857843 + y_7_9_11 * x_sq;
const auto y_3_5_7_9_11 = (T) 0.166667159320 + y_5_7_9_11 * x_sq;
const auto y_1_3_5_7_9_11 = (T) 1 + y_3_5_7_9_11 * x_sq;
const auto y_9_11 = (S) 2.63661358122e-6 + (S) 3.33765558362e-8 * x_sq;
const auto y_7_9_11 = (S) 0.000199027336899 + y_9_11 * x_sq;
const auto y_5_7_9_11 = (S) 0.00833223857843 + y_7_9_11 * x_sq;
const auto y_3_5_7_9_11 = (S) 0.166667159320 + y_5_7_9_11 * x_sq;
const auto y_1_3_5_7_9_11 = (S) 1 + y_3_5_7_9_11 * x_sq;
return x * y_1_3_5_7_9_11;
}

template <typename T>
T tanh_poly_9 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_7_9 = (T) 0.000192218110330 + (T) 3.54808622170e-6 * x_sq;
const auto y_5_7_9 = (T) 0.00834777254865 + y_7_9 * x_sq;
const auto y_3_5_7_9 = (T) 0.166658873283 + y_5_7_9 * x_sq;
const auto y_1_3_5_7_9 = (T) 1 + y_3_5_7_9 * x_sq;
const auto y_7_9 = (S) 0.000192218110330 + (S) 3.54808622170e-6 * x_sq;
const auto y_5_7_9 = (S) 0.00834777254865 + y_7_9 * x_sq;
const auto y_3_5_7_9 = (S) 0.166658873283 + y_5_7_9 * x_sq;
const auto y_1_3_5_7_9 = (S) 1 + y_3_5_7_9 * x_sq;
return x * y_1_3_5_7_9;
}

template <typename T>
T tanh_poly_7 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_5_7 = (T) 0.00818199927912 + (T) 0.000243153287690 * x_sq;
const auto y_3_5_7 = (T) 0.166769941467 + y_5_7 * x_sq;
const auto y_1_3_5_7 = (T) 1 + y_3_5_7 * x_sq;
const auto y_5_7 = (S) 0.00818199927912 + (S) 0.000243153287690 * x_sq;
const auto y_3_5_7 = (S) 0.166769941467 + y_5_7 * x_sq;
const auto y_1_3_5_7 = (S) 1 + y_3_5_7 * x_sq;
return x * y_1_3_5_7;
}

template <typename T>
T tanh_poly_5 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_3_5 = (T) 0.165326984031 + (T) 0.00970240200826 * x_sq;
const auto y_1_3_5 = (T) 1 + y_3_5 * x_sq;
const auto y_3_5 = (S) 0.165326984031 + (S) 0.00970240200826 * x_sq;
const auto y_1_3_5 = (S) 1 + y_3_5 * x_sq;
return x * y_1_3_5;
}

template <typename T>
T tanh_poly_3 (T x)
{
using S = scalar_of_t<T>;
const auto x_sq = x * x;
const auto y_1_3 = (T) 1 + (T) 0.183428244899 * x_sq;
const auto y_1_3 = (S) 1 + (S) 0.183428244899 * x_sq;
return x * y_1_3;
}
} // namespace tanh_detail
Expand All @@ -76,6 +81,7 @@ T tanh (T x)
else if constexpr (order == 3)
x_poly = tanh_detail::tanh_poly_3 (x);

return x_poly * rsqrt (x_poly * x_poly + (T) 1);
using S = scalar_of_t<T>;
return x_poly * rsqrt (x_poly * x_poly + (S) 1);
}
} // namespace math_approx

0 comments on commit 9924fe0

Please sign in to comment.