Skip to content

Commit

Permalink
Add lots of comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Jan 12, 2024
1 parent a3c3ef0 commit b7f9f74
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 12 deletions.
10 changes: 9 additions & 1 deletion include/math_approx/src/basic_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ struct scalar_of
{
using type = T;
};

/**
* Convenience alias for scalar_of<T>::type.
*
* When T is a scalar floating-point type, scalar_of_t<T> is T.
* When T is a SIMD floating-point type (e.g. xsimd::batch<S>),
* scalar_of_t<T> is the corresponding scalar type (S).
*/
template <typename T>
using scalar_of_t = typename scalar_of<T>::type;

/** Inverse square root */
template <typename T>
T rsqrt (T x)
{
Expand All @@ -40,6 +46,7 @@ T rsqrt (T x)
// return x * r;
}

/** Function interface for the ternary operator. */
template <typename T>
T select (bool q, T t, T f)
{
Expand All @@ -53,6 +60,7 @@ struct scalar_of<xsimd::batch<T>>
using type = T;
};

/** Inverse square root */
template <typename T>
xsimd::batch<T> rsqrt (xsimd::batch<T> x)
{
Expand All @@ -65,6 +73,7 @@ xsimd::batch<T> rsqrt (xsimd::batch<T> x)
return x * r;
}

/** Function interface for the ternary operator. */
template <typename T>
xsimd::batch<T> select (xsimd::batch_bool<T> q, xsimd::batch<T> t, xsimd::batch<T> f)
{
Expand Down Expand Up @@ -94,5 +103,4 @@ inline typename std::enable_if<is_bitwise_castable<From, To>::value, To>::type b
#else
using std::bit_cast;
#endif

} // namespace math_approx
6 changes: 5 additions & 1 deletion include/math_approx/src/hyperbolic_trig_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ constexpr auto sinh_cosh (T x)

namespace tanh_detail
{
// These polynomial fits were generated from: https://www.wolframcloud.com/obj/chowdsp/Published/tanh_approx.nb
// See notebooks/tanh_approx.nb for the derivation of these polynomials

template <typename T>
constexpr T tanh_poly_11 (T x)
Expand Down Expand Up @@ -111,6 +111,10 @@ namespace tanh_detail
}
} // namespace tanh_detail

/**
* Approximation of tanh(x), using tanh(x) ≈ p(x) / (p(x)^2 + 1),
* where p(x) is an odd polynomial fit to minimize the maximum relative error.
*/
template <int order, typename T>
T tanh (T x)
{
Expand Down
17 changes: 13 additions & 4 deletions include/math_approx/src/inverse_hyperbolic_trig_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ namespace math_approx
{
struct AsinhLog2Provider
{
// for polynomial derivations, see notebooks/asinh_approx.nb

/** approximation for log2(x), optimized on the range [1, 2], to be used within an asinh(x) computation */
template <typename T, int order, bool /*C1_continuous*/>
static constexpr T log2_approx (T x)
Expand Down Expand Up @@ -76,6 +78,10 @@ constexpr T asinh (T x)
return sign * y;
}

/**
* Approximation of acosh(x) in the full range, using identity
* acosh(x) = log(x + sqrt(x^2 - 1)).
*/
template <int order, typename T>
constexpr T acosh (T x)
{
Expand All @@ -85,15 +91,18 @@ constexpr T acosh (T x)
using xsimd::sqrt;
#endif

const auto z0 = x - (S) 1;
const auto z1 = z0 + sqrt (z0 + z0 + z0 * z0);
return log1p<order> (z1);
const auto z1 = x + sqrt (x * x - (S) 1);
return log<order> (z1);
}

/**
* Approximation of atanh(x), using identity
* atanh(x) = (1/2) log((1 + x) / (1 - x)).
*/
template <int order, typename T>
constexpr T atanh (T x)
{
using S = scalar_of_t<T>;
return (S) 0.5 * log<order> (((S) 1 + x) / ((S) 1 - x));
}
}
} // namespace math_approx
17 changes: 17 additions & 0 deletions include/math_approx/src/inverse_trig_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace math_approx
{
namespace inv_trig_detail
{
// for polynomial derivations, see notebooks/asin_acos_approx.nb

template <int order, typename T>
constexpr T asin_kernel (T x)
{
Expand Down Expand Up @@ -66,6 +68,8 @@ namespace inv_trig_detail
}
}

// for polynomial derivations, see notebooks/arctan_approx.nb

template <int order, typename T>
constexpr T atan_kernel (T x)
{
Expand Down Expand Up @@ -100,6 +104,11 @@ namespace inv_trig_detail
}
} // namespace inv_trig_detail

/**
* Approximation of asin(x) using asin(x) ≈ p(x^2) * x^3 + x for x in [0, 0.5],
* and asin(x) ≈ pi/2 - p((1-x)/2) * ((1-x)/2)^(3/2) + ((1-x)/2)^(1/2) for x in [0.5, 1],
* where p(x) is a polynomial fit to achieve the minimum absolute error.
*/
template <int order, typename T>
T asin (T x)
{
Expand All @@ -123,6 +132,10 @@ T asin (T x)
return select (x > (S) 0, res, -res);
}

/**
* Approximation of acos(x) using the same approach as asin(x),
* but with a different polynomial fit.
*/
template <int order, typename T>
T acos (T x)
{
Expand All @@ -146,6 +159,10 @@ T acos (T x)
return (S) M_PI_2 - select (x > (S) 0, res, -res);
}

/**
* Approximation of atan(x) using a polynomial approximation of arctan(x) on [0, 1],
* and atan(x) = pi/2 - arctan(1/x) for x > 1.
*/
template <int order, typename T>
T atan (T x)
{
Expand Down
15 changes: 15 additions & 0 deletions include/math_approx/src/log_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace log_detail
{
struct Log2Provider
{
// for polynomial derivations, see notebooks/log_approx.nb

/** approximation for log2(x), optimized on the range [1, 2] */
template <typename T, int order, bool C1_continuous>
static constexpr T log2_approx (T x)
Expand Down Expand Up @@ -163,24 +165,37 @@ xsimd::batch<double> log (xsimd::batch<double> x)
#pragma GCC diagnostic pop // end ignore strict-aliasing warnings
#endif

/**
* Approximation of log(x), using
* log(x) = (1 / log2(e)) * (Exponent(x) + log2(1 + Mantissa(x)))
*/
template <int order, bool C1_continuous = false, typename T>
constexpr T log (T x)
{
return log<pow_detail::BaseE<scalar_of_t<T>>, order, C1_continuous> (x);
}

/**
* Approximation of log2(x), using
* log2(x) = Exponent(x) + log2(1 + Mantissa(x))
*/
template <int order, bool C1_continuous = false, typename T>
constexpr T log2 (T x)
{
return log<pow_detail::Base2<scalar_of_t<T>>, order, C1_continuous> (x);
}

/**
* Approximation of log10(x), using
* log10(x) = (1 / log2(10)) * (Exponent(x) + log2(1 + Mantissa(x)))
*/
template <int order, bool C1_continuous = false, typename T>
constexpr T log10 (T x)
{
return log<pow_detail::Base10<scalar_of_t<T>>, order, C1_continuous> (x);
}

/** Approximation of log(1 + x), using math_approx::log(x) */
template <int order, bool C1_continuous = false, typename T>
constexpr T log1p (T x)
{
Expand Down
2 changes: 2 additions & 0 deletions include/math_approx/src/polylogarithm_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace math_approx
* Orders higher than 3 are generally not recommended for
* single-precision floating-point types, since they don't
* improve the accuracy very much.
*
* For derivations, see notebooks/li2_approx.nb
*/
template <int order, typename T>
constexpr T li2_0_half (T x)
Expand Down
6 changes: 6 additions & 0 deletions include/math_approx/src/pow_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace math_approx
{
namespace pow_detail
{
// for polynomial derivations, see notebooks/exp_approx.nb

/** approximation for 2^x, optimized on the range [0, 1] */
template <typename T, int order, bool C1_continuous>
constexpr T pow2_approx (T x)
Expand Down Expand Up @@ -199,24 +201,28 @@ xsimd::batch<double> pow (xsimd::batch<double> x)
#pragma GCC diagnostic pop // end ignore strict-aliasing warnings
#endif

/** Approximation of exp(x), using exp(x) = 2^floor(x * log2(e)) * 2^frac(x * log2(e)) */
template <int order, bool C1_continuous = false, typename T>
constexpr T exp (T x)
{
return pow<pow_detail::BaseE<scalar_of_t<T>>, order, C1_continuous> (x);
}

/** Approximation of exp2(x), using exp2(x) = 2^floor(x) * 2^frac(x) */
template <int order, bool C1_continuous = false, typename T>
constexpr T exp2 (T x)
{
return pow<pow_detail::Base2<scalar_of_t<T>>, order, C1_continuous> (x);
}

/** Approximation of exp10(x), using exp10(x) = 2^floor(x * log2(10)) * 2^frac(x * log2(10)) */
template <int order, bool C1_continuous = false, typename T>
constexpr T exp10 (T x)
{
return pow<pow_detail::Base10<scalar_of_t<T>>, order, C1_continuous> (x);
}

/** Approximation of exp(x) - 1, using math_approx::exp(x) */
template <int order, bool C1_continuous = false, typename T>
constexpr T expm1 (T x)
{
Expand Down
7 changes: 6 additions & 1 deletion include/math_approx/src/sigmoid_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace math_approx
{
namespace sigmoid_detail
{
// These polynomial fits were generated from: https://www.wolframcloud.com/obj/chowdsp/Published/sigmoid_approx.nb
// for polynomial derivations, see notebooks/sigmoid_approx.nb

template <typename T>
constexpr T sig_poly_9 (T x)
Expand Down Expand Up @@ -51,6 +51,11 @@ namespace sigmoid_detail
}
} // namespace sigmoid_detail

/**
* Approximation of sigmoid(x) := 1 / (1 + e^-x),
* using sigmoid(x) ≈ (1/2) p(x) / (p(x)^2 + 1) + (1/2),
* where p(x) is an odd polynomial fit to minimize the maximum relative error.
*/
template <int order, typename T>
T sigmoid (T x)
{
Expand Down
16 changes: 14 additions & 2 deletions include/math_approx/src/trig_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ namespace trig_detail
return select (x >= (T) 0, mod, mod + pi) - half_pi;
}

// for polynomial derivations, see notebooks/sin_approx.nb

template <typename T>
constexpr T sin_poly_9 (T x, T x_sq)
{
Expand Down Expand Up @@ -79,6 +81,7 @@ namespace trig_detail
}
} // namespace sin_detail

/** Polynomial approximation of sin(x) on the range [-pi, pi] */
template <int order, typename T>
constexpr T sin_mpi_pi (T x)
{
Expand All @@ -100,12 +103,17 @@ constexpr T sin_mpi_pi (T x)
return (pi_sq - x_sq) * x_poly;
}

/** Full range approximation of sin(x) */
template <int order, typename T>
constexpr T sin (T x)
{
    // reduce the argument with fast_mod_mpi_pi, then evaluate the bounded-range approximation
    const auto x_reduced = trig_detail::fast_mod_mpi_pi (x);
    return sin_mpi_pi<order, T> (x_reduced);
}

/**
* Polynomial approximation of cos(x) on the range [-pi, pi],
* using a range-shifted approximation of sin(x).
*/
template <int order, typename T>
constexpr T cos_mpi_pi (T x)
{
Expand Down Expand Up @@ -136,18 +144,21 @@ constexpr T cos_mpi_pi (T x)
return (pi_sq - hpmx_sq) * x_poly;
}

/** Full range approximation of cos(x) */
template <int order, typename T>
constexpr T cos (T x)
{
    // reduce the argument with fast_mod_mpi_pi, then evaluate the bounded-range approximation
    const auto x_reduced = trig_detail::fast_mod_mpi_pi (x);
    return cos_mpi_pi<order, T> (x_reduced);
}

/** Approximation of tan(x) on the range [-pi/4, pi/4] */
/** Polynomial approximation of tan(x) on the range [-pi/4, pi/4] */
template <int order, typename T>
constexpr T tan_mquarterpi_quarterpi (T x)
{
static_assert (order % 2 == 1 && order >= 3 && order <= 15, "Order must be an odd number within [3, 15]");

// for polynomial derivation, see notebooks/tan_approx.nb

using S = scalar_of_t<T>;
const auto x_sq = x * x;
if constexpr (order == 3)
Expand Down Expand Up @@ -217,7 +228,8 @@ constexpr T tan_mquarterpi_quarterpi (T x)
}

/**
* Approximation of tan(x) on the range [-pi/2, pi/2]
* Approximation of tan(x) on the range [-pi/2, pi/2],
* using the tangent half-angle formula.
*
* Accuracy may suffer as x approaches ±pi/2.
*/
Expand Down
10 changes: 10 additions & 0 deletions include/math_approx/src/wright_omega_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@

namespace math_approx
{
/**
* Approximation of the Wright-Omega function, using
* w(x) ≈ 0 for x < -3
* w(x) ≈ p(x) for -3 <= x < e
* w(x) ≈ x - log(x) + alpha * exp(-beta * x) for x >= e,
* where p(x) is a polynomial, and alpha and beta are coefficients,
* all fit to minimize the maximum absolute error.
*
* The above fit is optionally followed by some number of Newton-Raphson iterations.
*/
template <int num_nr_iters, int poly_order = 3, int log_order = (num_nr_iters <= 1 ? 3 : 4), int exp_order = log_order, typename T>
constexpr T wright_omega (T x)
{
Expand Down
6 changes: 3 additions & 3 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ void plot_function (std::span<const float> all_floats,
int main()
{
plt::figure();
const auto range = std::make_pair (-0.99f, 0.99f);
const auto range = std::make_pair (1.0f, 10.0f);
static constexpr auto tol = 1.0e-2f;

const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (std::atanh));
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::atanh<5>) ), "atanh_log-5");
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (std::acosh));
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::acosh<5>) ), "acosh-5");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down

0 comments on commit b7f9f74

Please sign in to comment.