From 449a1c5339945b52d1e19f62eff52a05c665ce8c Mon Sep 17 00:00:00 2001 From: jatin Date: Thu, 7 Dec 2023 20:06:02 -0800 Subject: [PATCH] Add asinh implementation --- include/math_approx/math_approx.hpp | 1 + include/math_approx/src/asinh_approx.hpp | 79 +++++++++++ include/math_approx/src/log_approx.hpp | 169 ++++++++++++----------- test/CMakeLists.txt | 1 + test/src/asinh_approx_test.cpp | 54 ++++++++ tools/bench/CMakeLists.txt | 3 + tools/bench/asinh_bench.cpp | 55 ++++++++ tools/plotter/plotter.cpp | 13 +- 8 files changed, 286 insertions(+), 89 deletions(-) create mode 100644 include/math_approx/src/asinh_approx.hpp create mode 100644 test/src/asinh_approx_test.cpp create mode 100644 tools/bench/asinh_bench.cpp diff --git a/include/math_approx/math_approx.hpp b/include/math_approx/math_approx.hpp index aa1b396..a1d49b0 100644 --- a/include/math_approx/math_approx.hpp +++ b/include/math_approx/math_approx.hpp @@ -11,6 +11,7 @@ namespace math_approx #include "src/log_approx.hpp" #include "src/tanh_approx.hpp" #include "src/sinh_cosh_approx.hpp" +#include "src/asinh_approx.hpp" #include "src/sigmoid_approx.hpp" #include "src/wright_omega_approx.hpp" #include "src/polylogarithm_approx.hpp" diff --git a/include/math_approx/src/asinh_approx.hpp b/include/math_approx/src/asinh_approx.hpp new file mode 100644 index 0000000..bcc67a6 --- /dev/null +++ b/include/math_approx/src/asinh_approx.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include "log_approx.hpp" + +namespace math_approx +{ +struct AsinhLog2Provider +{ + /** approximation for log2(x), optimized on the range [1, 2], to be used within an asinh(x) computation */ + template + static constexpr T log2_approx (T x) + { + static_assert (order >= 3 && order <= 5); + using S = scalar_of_t; + + const auto x_sq = x * x; + if constexpr (order == 3) + { + const auto x_2_3 = (S) -1.21535595794871 + (S) 0.194363894384581 * x; + const auto x_0_1 = (S) -2.26452854958994 + (S) 3.28552061315407 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.770443387059628 + (S) -0.102652345633016 * x; + const auto x_1_2 = (S) 4.33013912645867 + (S) -2.39448588379361 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) -2.60344428409168 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) -0.511946284688366 + (S) 0.0578217518982235 * x; + const auto x_2_3 = (S) -3.94632584968643 + (S) 1.90796087279737 * x; + const auto x_0_1 = (S) -2.87748189127908 + (S) 5.36997140095829 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else + { + return {}; + } + } +}; + +/** + * Approximation of asinh(x) in the full range, using identity + * asinh(x) = log(x + sqrt(x^2 + 1)). + * + * Orders 6 and 7 use an additional Newton-Raphson iteration, + * but for most cases the accuracy improvement is not worth + * the additional cost (when compared to the performance and + * accuracy achieved by the STL implementation). + */ +template +T asinh (T x) +{ + using S = scalar_of_t; + using std::abs; + using std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::abs; + using xsimd::sqrt; +#endif + + const auto sign = select (x > (S) 0, (T) (S) 1, select (x < (S) 0, (T) (S) -1, (T) (S) 0)); + x = abs (x); + + const auto log_arg = x + sqrt (x * x + (S) 1); + auto y = log>, std::min (order, 5), false, AsinhLog2Provider> (log_arg); + + if constexpr (order > 5) + { + const auto exp_y = math_approx::exp (y); + y -= (exp_y - log_arg) / exp_y; + } + + return sign * y; +} +} // namespace math_approx diff --git a/include/math_approx/src/log_approx.hpp b/include/math_approx/src/log_approx.hpp index 8571638..0f19b3c 100644 --- a/include/math_approx/src/log_approx.hpp +++ b/include/math_approx/src/log_approx.hpp @@ -7,89 +7,92 @@ namespace math_approx { namespace log_detail { - /** approximation for log2(x), optimized on the range [1, 2] */ - template - constexpr T log2_approx (T x) + struct Log2Provider { - static_assert (order >= 3 && order <= 6); - using S = scalar_of_t; - - const auto x_sq = x * x; - if constexpr (C1_continuous) + /** approximation for log2(x), optimized on the range [1, 2] */ + template + static constexpr T log2_approx (T x) { - if constexpr (order == 3) - { - const auto x_2_3 = (S) -1.09886528622 + (S) 0.164042561333 * x; - const auto x_0_1 = (S) -2.21347520444 + (S) 3.14829792933 * x; - return x_0_1 + x_2_3 * x_sq; - } - else if constexpr (order == 4) - { - const auto x_3_4 = (S) 0.671618567027 + (S) -0.0845960009489 * x; - const auto x_1_2 = (S) 4.16344994072 + (S) -2.19861329856 * x; - const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; - return (S) -2.55185920824 + x_1_2_3_4 * x; - } - else if constexpr (order == 5) - { - const auto x_4_5 = (S) -0.432338320780 + (S) 0.0464481811023 * x; - const auto x_2_3 = (S) -3.65368350361 + (S) 1.68976432066 * x; - const auto x_0_1 = (S) -2.82807214111 + (S) 5.17788146374 * x; - const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; - return x_0_1 + x_2_3_4_5 * x_sq; - } - else if constexpr (order == 6) - { - const auto x_5_6 = (S) 0.284794437502 + (S) -0.0265448504094 * x; - const auto x_3_4 = (S) 3.38542517475 + (S) -1.31007090775 * x; - const auto x_1_2 = (S) 6.19242937536 + (S) -5.46521465640 * x; - const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; - const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; - return (S) -3.06081857306 + x_1_2_3_4_5_6 * x; - } - else - { - return {}; - } - } - else - { - if constexpr (order == 3) - { - const auto x_2_3 = (S) -1.05974531422 + (S) 0.159220010975 * x; - const auto x_0_1 = (S) -2.16417056258 + (S) 3.06469586582 * x; - return x_0_1 + x_2_3 * x_sq; - } - else if constexpr (order == 4) - { - const auto x_3_4 = (S) 0.649709537672 + (S) -0.0821303550902 * x; - const auto x_1_2 = (S) 4.08637809379 + (S) -2.13412984371 * x; - const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; - return (S) -2.51982743265 + x_1_2_3_4 * x; - } - else if constexpr (order == 5) - { - const auto x_4_5 = (S) -0.419319345483 + (S) 0.0451488402558 * x; - const auto x_2_3 = (S) -3.56885211615 + (S) 1.64139451414 * x; - const auto x_0_1 = (S) -2.80534277658 + (S) 5.10697088382 * x; - const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; - return x_0_1 + x_2_3_4_5 * x_sq; - } - else if constexpr (order == 6) + static_assert (order >= 3 && order <= 6); + using S = scalar_of_t; + + const auto x_sq = x * x; + if constexpr (C1_continuous) { - const auto x_5_6 = (S) 0.276834061071 + (S) -0.0258400886535 * x; - const auto x_3_4 = (S) 3.30388341157 + (S) -1.27446900713 * x; - const auto x_1_2 = (S) 6.12708086513 + (S) -5.36371998242 * x; - const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; - const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; - return (S) -3.04376925958 + x_1_2_3_4_5_6 * x; + if constexpr (order == 3) + { + const auto x_2_3 = (S) -1.09886528622 + (S) 0.164042561333 * x; + const auto x_0_1 = (S) -2.21347520444 + (S) 3.14829792933 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.671618567027 + (S) -0.0845960009489 * x; + const auto x_1_2 = (S) 4.16344994072 + (S) -2.19861329856 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) -2.55185920824 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) -0.432338320780 + (S) 0.0464481811023 * x; + const auto x_2_3 = (S) -3.65368350361 + (S) 1.68976432066 * x; + const auto x_0_1 = (S) -2.82807214111 + (S) 5.17788146374 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else if constexpr (order == 6) + { + const auto x_5_6 = (S) 0.284794437502 + (S) -0.0265448504094 * x; + const auto x_3_4 = (S) 3.38542517475 + (S) -1.31007090775 * x; + const auto x_1_2 = (S) 6.19242937536 + (S) -5.46521465640 * x; + const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; + const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; + return (S) -3.06081857306 + x_1_2_3_4_5_6 * x; + } + else + { + return {}; + } } else { - return {}; + if constexpr (order == 3) + { + const auto x_2_3 = (S) -1.05974531422 + (S) 0.159220010975 * x; + const auto x_0_1 = (S) -2.16417056258 + (S) 3.06469586582 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.649709537672 + (S) -0.0821303550902 * x; + const auto x_1_2 = (S) 4.08637809379 + (S) -2.13412984371 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) -2.51982743265 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) -0.419319345483 + (S) 0.0451488402558 * x; + const auto x_2_3 = (S) -3.56885211615 + (S) 1.64139451414 * x; + const auto x_0_1 = (S) -2.80534277658 + (S) 5.10697088382 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else if constexpr (order == 6) + { + const auto x_5_6 = (S) 0.276834061071 + (S) -0.0258400886535 * x; + const auto x_3_4 = (S) 3.30388341157 + (S) -1.27446900713 * x; + const auto x_1_2 = (S) 6.12708086513 + (S) -5.36371998242 * x; + const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; + const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; + return (S) -3.04376925958 + x_1_2_3_4_5_6 * x; + } + else + { + return {}; + } } } - } + }; } #if defined(__GNUC__) @@ -99,7 +102,7 @@ namespace log_detail #endif /** approximation for log(Base, x) (32-bit) */ -template +template float log (float x) { const auto vi = reinterpret_cast (x); @@ -109,11 +112,11 @@ float log (float x) const auto vf = reinterpret_cast (vfi); static constexpr auto log2_base_r = 1.0f / Base::log2_base; - return log2_base_r * ((float) e + log_detail::log2_approx (vf)); + return log2_base_r * ((float) e + Log2ProviderType::template log2_approx (vf)); } /** approximation for log(x) (64-bit) */ -template +template double log (double x) { const auto vi = reinterpret_cast (x); @@ -123,12 +126,12 @@ double log (double x) const auto vf = reinterpret_cast (vfi); static constexpr auto log2_base_r = 1.0 / Base::log2_base; - return log2_base_r * ((double) e + log_detail::log2_approx (vf)); + return log2_base_r * ((double) e + Log2ProviderType::template log2_approx (vf)); } #if defined(XSIMD_HPP) /** approximation for pow(Base, x) (32-bit SIMD) */ -template +template xsimd::batch log (xsimd::batch x) { const auto vi = reinterpret_cast&> (x); // NOSONAR @@ -138,11 +141,11 @@ xsimd::batch log (xsimd::batch x) const auto vf = reinterpret_cast&> (vfi); // NOSONAR static constexpr auto log2_base_r = 1.0f / Base::log2_base; - return log2_base_r * (xsimd::to_float (e) + log_detail::log2_approx, order, C1_continuous> (vf)); + return log2_base_r * (xsimd::to_float (e) + Log2ProviderType::template log2_approx, order, C1_continuous> (vf)); } /** approximation for pow(Base, x) (64-bit SIMD) */ -template +template xsimd::batch log (xsimd::batch x) { const auto vi = reinterpret_cast&> (x); // NOSONAR @@ -152,7 +155,7 @@ xsimd::batch log (xsimd::batch x) const auto vf = reinterpret_cast&> (vfi); // NOSONAR static constexpr auto log2_base_r = 1.0 / Base::log2_base; - return log2_base_r * (xsimd::to_float (e) + log_detail::log2_approx, order, C1_continuous> (vf)); + return log2_base_r * (xsimd::to_float (e) + Log2ProviderType::template log2_approx, order, C1_continuous> (vf)); } #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 04d059f..a1221c7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,3 +36,4 @@ setup_catch_test(log_approx_test) setup_catch_test(wright_omega_approx_test) setup_catch_test(polylog_approx_test) setup_catch_test(sinh_cosh_approx_test) +setup_catch_test(asinh_approx_test) diff --git a/test/src/asinh_approx_test.cpp b/test/src/asinh_approx_test.cpp new file mode 100644 index 0000000..53320c4 --- /dev/null +++ b/test/src/asinh_approx_test.cpp @@ -0,0 +1,54 @@ +#include "test_helpers.hpp" +#include "catch2/catch_template_test_macros.hpp" + +#include +#include + +#include + +template +void test_approx (const auto& all_floats, const auto& y_exact, auto&& f_approx, float err_bound) +{ + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto max_error = test_helpers::abs_max (error); + + std::cout << max_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); +} + +TEMPLATE_TEST_CASE ("Asinh Approx Test", "", float, double) +{ +#if ! defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-2f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return std::asinh (x); }); + + SECTION ("6th-Order") + { + test_approx (all_floats, y_exact, [] (auto x) + { return math_approx::asinh<6> (x); }, + 5.0e-7f); + } + SECTION ("5th-Order") + { + test_approx (all_floats, y_exact, [] (auto x) + { return math_approx::asinh<5> (x); }, + 6.0e-5f); + } + SECTION ("4th-Order") + { + test_approx (all_floats, y_exact, [] (auto x) + { return math_approx::asinh<4> (x); }, + 3.5e-4f); + } + SECTION ("3th-Order") + { + test_approx (all_floats, y_exact, [] (auto x) + { return math_approx::asinh<3> (x); }, + 2.5e-3f); + } +} diff --git a/tools/bench/CMakeLists.txt b/tools/bench/CMakeLists.txt index 51df02a..dab6795 100644 --- a/tools/bench/CMakeLists.txt +++ b/tools/bench/CMakeLists.txt @@ -28,3 +28,6 @@ target_link_libraries(polylog_approx_bench PRIVATE benchmark::benchmark math_app add_executable(sinh_cosh_approx_bench sinh_cosh_bench.cpp) target_link_libraries(sinh_cosh_approx_bench PRIVATE benchmark::benchmark math_approx) + +add_executable(asinh_approx_bench asinh_bench.cpp) +target_link_libraries(asinh_approx_bench PRIVATE benchmark::benchmark math_approx) diff --git a/tools/bench/asinh_bench.cpp b/tools/bench/asinh_bench.cpp new file mode 100644 index 0000000..31c7685 --- /dev/null +++ b/tools/bench/asinh_bench.cpp @@ -0,0 +1,55 @@ +#include +#include + +static constexpr size_t N = 2000; +const auto data = [] +{ + std::vector x; + x.resize (N, 0.0f); + for (size_t i = 0; i < N; ++i) + x[i] = -10.0f + 20.0f * (float) i / (float) N; + return x; +}(); + +#define ASINH_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ +for (auto _ : state) \ +{ \ +for (auto& x : data) \ +{ \ +auto y = func (x); \ +benchmark::DoNotOptimize (y); \ +} \ +} \ +} \ +BENCHMARK (name); +ASINH_BENCH (asinh_std, std::asinh) +ASINH_BENCH (asinh_approx7, math_approx::asinh<7>) +ASINH_BENCH (asinh_approx6, math_approx::asinh<6>) +ASINH_BENCH (asinh_approx5, math_approx::asinh<5>) +ASINH_BENCH (asinh_approx4, math_approx::asinh<4>) +ASINH_BENCH (asinh_approx3, math_approx::asinh<3>) + +#define ASINH_SIMD_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ +for (auto _ : state) \ +{ \ +for (auto& x : data) \ +{ \ +auto y = func (xsimd::broadcast (x)); \ +static_assert (std::is_same_v, decltype(y)>); \ +benchmark::DoNotOptimize (y); \ +} \ +} \ +} \ +BENCHMARK (name); +ASINH_SIMD_BENCH (asinh_xsimd, xsimd::asinh) +ASINH_SIMD_BENCH (asinh_simd_approx7, math_approx::asinh<7>) +ASINH_SIMD_BENCH (asinh_simd_approx6, math_approx::asinh<6>) +ASINH_SIMD_BENCH (asinh_simd_approx5, math_approx::asinh<5>) +ASINH_SIMD_BENCH (asinh_simd_approx4, math_approx::asinh<4>) +ASINH_SIMD_BENCH (asinh_simd_approx3, math_approx::asinh<3>) + +BENCHMARK_MAIN(); diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp index 53cad1c..b1acf88 100644 --- a/tools/plotter/plotter.cpp +++ b/tools/plotter/plotter.cpp @@ -6,9 +6,9 @@ #include namespace plt = matplotlibcpp; -#include "../../test/src/test_helpers.hpp" -#include "../../test/src/reference/toms917.hpp" #include "../../test/src/reference/polylogarithm.hpp" +#include "../../test/src/reference/toms917.hpp" +#include "../../test/src/test_helpers.hpp" #include template @@ -61,13 +61,14 @@ void plot_function (std::span all_floats, int main() { plt::figure(); - const auto range = std::make_pair (-3.0f, 3.0f); + const auto range = std::make_pair (-1.0f, 1.0f); static constexpr auto tol = 1.0e-2f; const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol); - const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC(std::cosh)); - plot_ulp_error (all_floats, y_exact, FLOAT_FUNC((math_approx::cosh<5>)), "cosh-5"); - plot_ulp_error (all_floats, y_exact, FLOAT_FUNC((math_approx::cosh<6>)), "cosh-6"); + const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (std::asinh)); + // plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asinh<5>)), "asinh-5"); + plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asinh<6>) ), "asinh-6"); + plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asinh<7>) ), "asinh-7"); plt::legend ({ { "loc", "upper right" } }); plt::xlim (range.first, range.second);