From 1063b55c5d33072753af3f8b01bf1828dcd37d6b Mon Sep 17 00:00:00 2001 From: jatin Date: Mon, 4 Dec 2023 01:17:16 -0800 Subject: [PATCH] Add polylogarithm function implementations --- README.md | 1 + include/math_approx/math_approx.hpp | 1 + .../math_approx/src/polylogarithm_approx.hpp | 221 ++++++++++++++++++ test/CMakeLists.txt | 1 + test/src/polylog_approx_test.cpp | 70 ++++++ test/src/reference/polylogarithm.hpp | 98 ++++++++ tools/bench/CMakeLists.txt | 3 + tools/bench/polylog_bench.cpp | 53 +++++ tools/plotter/plotter.cpp | 16 +- 9 files changed, 453 insertions(+), 11 deletions(-) create mode 100644 include/math_approx/src/polylogarithm_approx.hpp create mode 100644 test/src/polylog_approx_test.cpp create mode 100644 test/src/reference/polylogarithm.hpp create mode 100644 tools/bench/polylog_bench.cpp diff --git a/README.md b/README.md index 228783a..8498fc4 100644 --- a/README.md +++ b/README.md @@ -10,3 +10,4 @@ Currently supported: - tanh - sigmoid - Wright-Omega function +- Dilogarithm function diff --git a/include/math_approx/math_approx.hpp b/include/math_approx/math_approx.hpp index c68ad8f..bcedaf2 100644 --- a/include/math_approx/math_approx.hpp +++ b/include/math_approx/math_approx.hpp @@ -12,3 +12,4 @@ namespace math_approx #include "src/pow_approx.hpp" #include "src/log_approx.hpp" #include "src/wright_omega_approx.hpp" +#include "src/polylogarithm_approx.hpp" diff --git a/include/math_approx/src/polylogarithm_approx.hpp b/include/math_approx/src/polylogarithm_approx.hpp new file mode 100644 index 0000000..c3dbe82 --- /dev/null +++ b/include/math_approx/src/polylogarithm_approx.hpp @@ -0,0 +1,221 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +/** + * Approximation of the "dilogarithm" function for inputs + * in the range [0, 1/2]. This method does not do any + * bounds-checking. Every order returns x times a rational function of x.
+ * + * Orders higher than 3 are generally not recommended for + * single-precision floating-point types, since they don't + * improve the accuracy very much. + */ +template +T li2_0_half (T x) +{ + static_assert (order >= 1 && order <= 6); + using S = scalar_of_t; + + if constexpr (order == 1) + { + const auto n_0 = (S) 0.996460629617; + const auto d_0_1 = (S) 1 + (S) -0.288575624121 * x; + return x * n_0 / d_0_1; + } + else if constexpr (order == 2) + { + const auto n_0_1 = (S) 0.999994847641 + (S) -0.546961998015 * x; + const auto d_1_2 = (S) -0.797206910618 + (S) 0.0899936224040 * x; + const auto d_0_1_2 = (S) 1 + d_1_2 * x; + return x * n_0_1 / d_0_1_2; + } + else if constexpr (order == 3) + { + const auto x_sq = x * x; + const auto n_0_2 = (S) 0.999999991192 + (S) 0.231155739205 * x_sq; + const auto n_0_1_2 = n_0_2 + (S) -1.07612533343 * x; + const auto d_2_3 = (S) 0.451592861555 + (S) -0.0281544399023 * x; + const auto d_0_1 = (S) 1 + (S) -1.32612627824 * x; + const auto d_0_1_2_3 = d_0_1 + d_2_3 * x_sq; + return x * n_0_1_2 / d_0_1_2_3; + } + else if constexpr (order == 4) + { + const auto x_sq = x * x; + const auto n_2_3 = (S) 0.74425269014090502911555775982556365472 + (S) -0.08749607277005140673532964399704145939 * x; + const auto n_0_1 = (S) 0.99999999998544094594795118478024862055 + (S) -1.6098648159028159794757437744309391591 * x; + const auto n_0_1_2_3 = n_0_1 + n_2_3 * x_sq; + const auto d_3_4 = (S) -0.21787247785577362691148412819704459614 + (S) 0.00870385570778120787932426702624346169 * x; + const auto d_1_2 = (S) -1.85986481869406218896935179306183665107 + (S) 1.09810787318601772062220747277929300408 * x; + const auto d_1_2_3_4 = d_1_2 + d_3_4 * x_sq; + const auto d_0_1_2_3_4 = (S) 1 + d_1_2_3_4 * x; + return x * n_0_1_2_3 / d_0_1_2_3_4; + } + else if constexpr (order == 5) + { + const auto x_sq = x * x; + + const auto n_3_4 = (S) -0.41945653857264507277532555842378439927 + (S) 0.03140351694981020435408321943912212079 * x; + const auto n_1_2 = (S) 
-2.14843104749890205674150618938194330623 + (S) 1.54956546570292751217524363072830456069 * x; + const auto n_1_2_3_4 = n_1_2 + n_3_4 * x_sq; + const auto n_0_1_2_3_4 = (S) 0.99999999999997312289180148636206726177 + n_1_2_3_4 * x; + + const auto d_4_5 = (S) 0.09609912057603552016206051904306797162 + (S) -0.00269129500193871901659324657805482418 * x; + const auto d_2_3 = (S) 2.03806211686824385201410542913121040892 + (S) -0.72497973694183708484311198715866984035 * x; + const auto d_0_1 = (S) 1 + (S) -2.398431047506893407956406025441134862 * x; + const auto d_2_3_4_5 = d_2_3 + d_4_5 * x_sq; + const auto d_0_1_2_3_4_5 = d_0_1 + d_2_3_4_5 * x_sq; + + return x * n_0_1_2_3_4 / d_0_1_2_3_4_5; + } + else if constexpr (order == 6) + { + const auto x_sq = x * x; + + const auto n_4_5 = (S) 0.20885966267164674441979654645138181067 + (S) -0.01085968986663512120143497781484214416 * x; + const auto n_2_3 = (S) 2.64771686149306717256638234054408732899 + (S) -1.15385196641292513334184445301529897694 * x; + const auto n_0_1 = (S) 0.99999999999999995022522902211061062582 + (S) -2.6883902117841251600624689886592808124 * x; + const auto n_2_3_4_5 = n_2_3 + n_4_5 * x_sq; + const auto n_0_1_2_3_4_5 = n_0_1 + n_2_3_4_5 * x_sq; + + const auto d_5_6 = (S) -0.03980108270103465616851961097089502921 + (S) 0.00082742905522813187941384917520432493 * x; + const auto d_3_4 = (S) -1.70766499097900947314107956633154245176 + (S) 0.41595826557420951684124942212799147948 * x; + const auto d_1_2 = (S) -2.93839021178414636324893816529360171731 + (S) 3.27120330332951521662427278605230451458 * x; + const auto d_3_4_5_6 = d_3_4 + d_5_6 * x_sq; + const auto d_0_1_2 = (S) 1 + d_1_2 * x; + const auto d_0_1_2_3_4_5_6 = d_0_1_2 + d_3_4_5_6 * x_sq * x; + + return x * n_0_1_2_3_4_5 / d_0_1_2_3_4_5_6; + } + else + { + return {}; + } +} + +/** + * Scalar approximation of the "dilogarithm" function for all inputs: x is reduced to y in [0, 1/2], which is then passed to li2_0_half().
+ * + * Orders higher than 3 are generally not recommended for + * single-precision floating-point types, since they don't + * improve the accuracy very much. + */ +template = 5), typename T> +T li2 (T x) +{ + const auto x_r = (T) 1 / x; + const auto x_r1 = (T) 1 / (x - (T) 1); + + static constexpr auto pisq_o_6 = (T) M_PI * (T) M_PI / (T) 6; + static constexpr auto pisq_o_3 = (T) M_PI * (T) M_PI / (T) 3; + + T y, r; + bool sign = true; + if (x < (T) -1) + { + y = -x_r1; + const auto l = log ((T) 1 - x); + r = -pisq_o_6 + l * ((T) 0.5 * l - log (-x)); + } + else if (x < (T) 0) + { + y = x * x_r1; + const auto l = log ((T) 1 - x); + r = (T) -0.5 * l * l; + sign = false; + } + else if (x < (T) 0.5) + { + y = x; + r = {}; + } + else if (x < (T) 1) + { + y = (T) 1 - x; + r = pisq_o_6 - log (x) * log (y); + sign = false; + } + else if (x < (T) 2) + { + y = (T) 1 - x_r; + const auto l = log (x); + r = pisq_o_6 - l * (log (y) + (T) 0.5 * l); + } + else + { + y = x_r; + const auto l = log (x); + r = pisq_o_3 - (T) 0.5 * l * l; + sign = false; + } + + const auto li2_reduce = li2_0_half (y); + return r + select (sign, li2_reduce, -li2_reduce); +} + +/** + * SIMD approximation of the "dilogarithm" function for all inputs: every range-reduction branch is computed and the per-lane result is picked with select().
+ * + * Orders higher than 3 are generally not recommended for + * single-precision floating-point types, since they don't + * improve the accuracy very much. + */ +template = 5), typename T> +xsimd::batch li2 (const xsimd::batch& x) +{ + // x < -1: + // - log(-x) -> [1, inf] + // - log(1-x) -> [2, inf] + // x < 0: + // - NOP + // - log(1-x) -> [1, 2] + // x < 1/2: + // - NOP + // - NOP + // x < 1: + // - log(x) -> [1/2, 1] + // - log(1-x) -> [0, 1/2] + // x < 2: + // - log(x) -> [1, 2] + // - log(1-1/x) -> [0, 1/2] + // x >= 2: + // - log(x) -> [2, inf] + // - NOP + + const auto x_r = (T) 1 / x; + const auto x_r1 = (T) 1 / (x - (T) 1); + const auto log_arg1 = select (x < (T) -1, -x, select (x < (T) 0.5, xsimd::broadcast ((T) 1), x)); + const auto log_arg2 = select (x < (T) 1, (T) 1 - x, (T) 1 - x_r); + + const auto log1 = log (log_arg1); + const auto log2 = log (log_arg2); + + // clang-format off + const auto y = select (x < (T) -1, (T) -1 * x_r1, + select (x < (T) 0, x * x_r1, + select (x < (T) 0.5, x, + select (x < (T) 1, (T) 1 - x, + select (x < (T) 2, (T) 1 - x_r, + x_r))))); + const auto sign = x < (T) -1 || (x >= (T) 0 && x < (T) 0.5) || (x >= (T) 1 && x < (T) 2); + + static constexpr auto pisq_o_6 = (T) M_PI * (T) M_PI / (T) 6; + static constexpr auto pisq_o_3 = (T) M_PI * (T) M_PI / (T) 3; + const auto log1_log2 = log1 * log2; + const auto half_log1_sq = (T) 0.5 * log1 * log1; + const auto half_log2_sq = (T) 0.5 * log2 * log2; + const auto r = select (x < (T) -1, -pisq_o_6 + half_log2_sq - log1_log2, + select (x < (T) 0, -half_log2_sq, + select (x < (T) 0.5, xsimd::broadcast ((T) 0), + select (x < (T) 1, pisq_o_6 - log1_log2, + select (x < (T) 2, pisq_o_6 - log1_log2 - half_log1_sq, + pisq_o_3 - half_log1_sq))))); + // clang-format on + + const auto li2_reduce = li2_0_half (y); + return r + select (sign, li2_reduce, -li2_reduce); +} +} // namespace math_approx diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8dc0101..fff15d2 100644 --- a/test/CMakeLists.txt 
+++ b/test/CMakeLists.txt @@ -34,3 +34,4 @@ setup_catch_test(trig_approx_test) setup_catch_test(pow_approx_test) setup_catch_test(log_approx_test)
 setup_catch_test(wright_omega_approx_test) +setup_catch_test(polylog_approx_test) diff --git a/test/src/polylog_approx_test.cpp b/test/src/polylog_approx_test.cpp new file mode 100644 index 0000000..d3feb92 --- /dev/null +++ b/test/src/polylog_approx_test.cpp @@ -0,0 +1,70 @@ +#include "test_helpers.hpp" +#include +#include + +#include + +#include "reference/polylogarithm.hpp" + +TEST_CASE ("Li2 Approx Test") +{ +#if ! defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-2f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return polylogarithm::Li2 (x); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto rel_error = test_helpers::compute_rel_error (y_exact, y_approx); + const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx); + + const auto max_error = test_helpers::abs_max (error); + const auto max_rel_error = test_helpers::abs_max (rel_error); + const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end()); + + std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + REQUIRE (std::abs (max_rel_error) < rel_err_bound); + if (ulp_bound > 0) + REQUIRE (max_ulp_error < ulp_bound); + }; + + SECTION ("3rd-Order_Log-6") + { + test_approx ([] (auto x) + { return math_approx::li2<3, 6> (x); }, + 2.5e-5f, + 1.5e-5f, + 200); + } + SECTION ("3rd-Order") + { + test_approx ([] (auto x) + { return math_approx::li2<3> (x); }, + 8.0e-5f, + 1.5e-4f, + 0); + } + SECTION ("2nd-Order") + { + test_approx ([] (auto x) + { return math_approx::li2<2> (x); }, + 
3.0e-4f, + 3.0e-4f, + 0); + } + SECTION ("1st-Order") + { + test_approx ([] (auto x) + { return math_approx::li2<1> (x); }, + 2.5e-3f, + 4.0e-3f, + 0); + } +} diff --git a/test/src/reference/polylogarithm.hpp b/test/src/reference/polylogarithm.hpp new file mode 100644 index 0000000..dadb219 --- /dev/null +++ b/test/src/reference/polylogarithm.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include + +/** + * Implementations of polylogarithm functions. + * + * Based on the implementations found at: https://github.com/Expander/polylogarithm + */ +namespace polylogarithm +{ + /** Real polylogarithm of order n=2 (dilogarithm); the approximation test uses it as the reference value. */ + template + inline T Li2 (T x) noexcept + { + constexpr auto PI_ = static_cast (M_PI); + constexpr T P[] = { + (T) 0.9999999999999999502e+0, + (T) -2.6883926818565423430e+0, + (T) 2.6477222699473109692e+0, + (T) -1.1538559607887416355e+0, + (T) 2.0886077795020607837e-1, + (T) -1.0859777134152463084e-2 + }; + constexpr T Q[] = { + (T) 1.0000000000000000000e+0, + (T) -2.9383926818565635485e+0, + (T) 3.2712093293018635389e+0, + (T) -1.7076702173954289421e+0, + (T) 4.1596017228400603836e-1, + (T) -3.9801343754084482956e-2, + (T) 8.2743668974466659035e-4 + }; + + T y = 0, r = 0, s = 1; + + // reduce x to y in [0, 1/2]; r is the additive correction term and s the sign applied to the rational approximation + if (x < (T) -1) + { + const auto l = std::log ((T) 1 - x); + y = (T) 1 / ((T) 1 - x); + r = -PI_ * PI_ / (T) 6 + l * ((T) 0.5 * l - std::log (-x)); + s = (T) 1; + } + else if (x == (T) -1) + { + return -PI_ * PI_ / (T) 12; + } + else if (x < (T) 0) + { + const auto l = std::log1p (-x); + y = x / (x - (T) 1); + r = (T) -0.5 * l * l; + s = (T) -1; + } + else if (x == (T) 0) + { + return (T) 0; + } + else if (x < (T) 0.5) + { + y = x; + r = (T) 0; + s = (T) 1; + } + else if (x < (T) 1) + { + y = (T) 1 - x; + r = PI_ * PI_ / (T) 6 - std::log (x) * std::log (y); + s = (T) -1; + } + else if (x == (T) 1) + { + return PI_ * 
PI_ / (T) 6; + } + else if (x < (T) 2) + { + const auto l = std::log (x); + y = (T) 1 - (T) 1 / x; + r = PI_ * PI_ / (T) 6 - l * 
(std::log (y) + (T) 0.5 * l); + s = (T) 1; + } + else + { + const auto l = std::log (x); + y = (T) 1 / x; + r = PI_ * PI_ / (T) 3 - (T) 0.5 * l * l; + s = (T) -1; + } + + const auto y2 = y * y; + const auto y4 = y2 * y2; + const auto p = P[0] + y * P[1] + y2 * (P[2] + y * P[3]) + y4 * (P[4] + y * P[5]); + const auto q = Q[0] + y * Q[1] + y2 * (Q[2] + y * Q[3]) + y4 * (Q[4] + y * Q[5] + y2 * Q[6]); + + return r + s * y * p / q; + } +} // namespace polylogarithm diff --git a/tools/bench/CMakeLists.txt b/tools/bench/CMakeLists.txt index c96dff6..0c06c5d 100644 --- a/tools/bench/CMakeLists.txt +++ b/tools/bench/CMakeLists.txt @@ -22,3 +22,6 @@ target_link_libraries(log_approx_bench PRIVATE benchmark::benchmark math_approx) add_executable(wright_omega_approx_bench wright_omega_bench.cpp) target_link_libraries(wright_omega_approx_bench PRIVATE benchmark::benchmark math_approx) + +add_executable(polylog_approx_bench polylog_bench.cpp) +target_link_libraries(polylog_approx_bench PRIVATE benchmark::benchmark math_approx) diff --git a/tools/bench/polylog_bench.cpp b/tools/bench/polylog_bench.cpp new file mode 100644 index 0000000..48e1bf8 --- /dev/null +++ b/tools/bench/polylog_bench.cpp @@ -0,0 +1,53 @@ +#include +#include +#include "../test/src/reference/polylogarithm.hpp" + +static constexpr size_t N = 2000; +const auto data = [] +{ + std::vector x; + x.resize (N, 0.0f); + for (size_t i = 0; i < N; ++i) + x[i] = -10.0f + 20.0f * (float) i / (float) N; + return x; +}(); + +#define POLYLOG_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ +for (auto _ : state) \ +{ \ +for (auto& x : data) \ +{ \ +auto y = func (x); \ +benchmark::DoNotOptimize (y); \ +} \ +} \ +} \ +BENCHMARK (name); +POLYLOG_BENCH (li2_ref, polylogarithm::Li2) +POLYLOG_BENCH (li2_approx3_log6, (math_approx::li2<3,6>)) +POLYLOG_BENCH (li2_approx3, math_approx::li2<3>) +POLYLOG_BENCH (li2_approx2, math_approx::li2<2>) +POLYLOG_BENCH (li2_approx1, math_approx::li2<1>) + +#define 
POLYLOG_SIMD_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ +for (auto _ : state) \ +{ \ +for (auto& x : data) \ +{ \ +auto y = func (xsimd::broadcast (x)); \ +static_assert (std::is_same_v, decltype(y)>); \ +benchmark::DoNotOptimize (y); \ +} \ +} \ +} \ +BENCHMARK (name); +POLYLOG_SIMD_BENCH (li2_simd_approx3_log6, (math_approx::li2<3,6>)) +POLYLOG_SIMD_BENCH (li2_simd_approx3, math_approx::li2<3>) +POLYLOG_SIMD_BENCH (li2_simd_approx2, math_approx::li2<2>) +POLYLOG_SIMD_BENCH (li2_simd_approx1, math_approx::li2<1>) + +BENCHMARK_MAIN(); diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp index 9d25d44..927da1b 100644 --- a/tools/plotter/plotter.cpp +++ b/tools/plotter/plotter.cpp @@ -8,7 +8,7 @@ namespace plt = matplotlibcpp; #include "../../test/src/test_helpers.hpp" #include "../../test/src/reference/toms917.hpp" -#include "../../test/src/reference/dangelo_omega.hpp" +#include "../../test/src/reference/polylogarithm.hpp" #include template @@ -61,18 +61,12 @@ void plot_function (std::span all_floats, int main() { plt::figure(); - const auto range = std::make_pair (-10.0f, 30.0f); - static constexpr auto tol = 1.0e-1f; + const auto range = std::make_pair (-5.0f, 5.0f); + static constexpr auto tol = 1.0e-2f; const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol); - const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC(toms917::wrightomega)); - - // plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<0>)), "W-O 0-3"); - // plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<0, 5>)), "W-O 0-5"); - plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<1>)), "W-O 1-3"); - // plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega_dangelo<0>)), "W-O D'Angelo 0"); - plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega_dangelo<1>)), "W-O D'Angelo 1"); - plot_error (all_floats, y_exact, 
FLOAT_FUNC((math_approx::wright_omega_dangelo<2>)), "W-O D'Angelo 2"); + const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC(polylogarithm::Li2)); + plot_ulp_error (all_floats, y_exact, FLOAT_FUNC((math_approx::li2<3,6>)), "Li2-3"); plt::legend ({ { "loc", "upper right" } }); plt::xlim (range.first, range.second);