diff --git a/include/math_approx/math_approx.hpp b/include/math_approx/math_approx.hpp
index a1d49b0..03aa4e0 100644
--- a/include/math_approx/math_approx.hpp
+++ b/include/math_approx/math_approx.hpp
@@ -7,6 +7,7 @@ namespace math_approx
 
 #include "src/basic_math.hpp"
 #include "src/trig_approx.hpp"
+#include "src/inverse_trig_approx.hpp"
 #include "src/pow_approx.hpp"
 #include "src/log_approx.hpp"
 #include "src/tanh_approx.hpp"
diff --git a/include/math_approx/src/inverse_trig_approx.hpp b/include/math_approx/src/inverse_trig_approx.hpp
new file mode 100644
index 0000000..89e6c56
--- /dev/null
+++ b/include/math_approx/src/inverse_trig_approx.hpp
@@ -0,0 +1,150 @@
+#pragma once
+
+#include "basic_math.hpp"
+
+namespace math_approx
+{
+namespace asin_detail
+{
+    /**
+     * Polynomial approximation of asin(x), used by asin() with x in [0, 1/sqrt(2)].
+     *
+     * `order` (3 to 11) selects one of the fixed coefficient sets below. Every
+     * variant evaluates x + x^2 * p(x); each y_a_b temporary is the partial sum
+     * that ends up supplying the x^a through x^b terms.
+     */
+    template <int order, typename T>
+    constexpr T asin_0_rsqrt2 (T x)
+    {
+        using S = scalar_of_t<T>;
+        static_assert (order >= 3 && order <= 11);
+
+        const auto x_sq = x * x;
+        if constexpr (order == 3)
+        {
+            const auto y_2_3 = (S) -0.0536922932754174 + (S) 0.297373838424192 * x;
+            return x + x_sq * y_2_3;
+        }
+        if constexpr (order == 4)
+        {
+            const auto y_3_4 = (S) -0.00535062837316264 + (S) 0.257252341545375 * x;
+            const auto y_2_3_4 = (S) 0.0317400592553864 + x * y_3_4;
+            return x + x_sq * y_2_3_4;
+        }
+        if constexpr (order == 5)
+        {
+            const auto y_4_5 = (S) -0.304640601352515 + (S) 0.353208342056560 * x;
+            const auto y_2_3 = (S) -0.0132122795426018 + (S) 0.278935718011026 * x;
+
+            const auto y_2_3_4_5 = y_2_3 + x_sq * y_4_5;
+            return x + x_sq * y_2_3_4_5;
+        }
+        if constexpr (order == 6)
+        {
+            const auto y_5_6 = (S) -0.516068582317285 + (S) 0.437544978265334 * x;
+            const auto y_3_4 = (S) 0.0946375652126262 + (S) 0.313911974469437 * x;
+
+            const auto y_3_4_5_6 = y_3_4 + x_sq * y_5_6;
+
+            const auto y_2_3_4_5_6 = (S) 0.00577946556085762 + x * y_3_4_5_6;
+            return x + x_sq * y_2_3_4_5_6;
+        }
+        if constexpr (order == 7)
+        {
+            const auto y_6_7 = (S) -0.993023225129115 + (S) 0.604213345541030 * x;
+            const auto y_4_5 = (S) -0.242568404368623 + (S) 0.780715776826480 * x;
+            const auto y_2_3 = (S) -0.00231743294930714 + (S) 0.205916081200406 * x;
+
+            const auto y_4_5_6_7 = y_4_5 + x_sq * y_6_7;
+
+            const auto y_2_3_4_5_6_7 = y_2_3 + x_sq * y_4_5_6_7;
+            return x + x_sq * y_2_3_4_5_6_7;
+        }
+        if constexpr (order == 8)
+        {
+            const auto y_7_8 = (S) -1.68181413527251 + (S) 0.833569228384441 * x;
+            const auto y_5_6 = (S) -0.614138628435564 + (S) 1.51390471735914 * x;
+            const auto y_3_4 = (S) 0.146440161696543 + (S) 0.167766328527588 * x;
+
+            const auto y_5_6_7_8 = y_5_6 + x_sq * y_7_8;
+            const auto y_3_4_5_6_7_8 = y_3_4 + x_sq * y_5_6_7_8;
+
+            const auto y_2_3_4_5_6_7_8 = (S) 0.000914775210828589 + x * y_3_4_5_6_7_8;
+            return x + x_sq * y_2_3_4_5_6_7_8;
+        }
+        if constexpr (order == 9)
+        {
+            const auto y_8_9 = (S) -2.8729113246543627191 + (S) 1.1910880616677141930 * x;
+            const auto y_6_7 = (S) -1.7250043993765906691 + (S) 3.0742940024198017746 * x;
+            const auto y_4_5 = (S) -0.10358468520191396745 + (S) 0.63814601829123507315 * x;
+            const auto y_2_3 = (S) -0.00034869418257434217938 + (S) 0.17639243703620430259 * x;
+
+            const auto y_6_7_8_9 = y_6_7 + x_sq * y_8_9;
+            const auto y_2_3_4_5 = y_2_3 + x_sq * y_4_5;
+
+            const auto y_2_3_4_5_6_7_8_9 = y_2_3_4_5 + (x_sq * x_sq) * y_6_7_8_9;
+            return x + x_sq * y_2_3_4_5_6_7_8_9;
+        }
+        if constexpr (order == 10)
+        {
+            const auto y_9_10 = (S) -4.7928604989214971255 + (S) 1.7203396621648587850 * x;
+            const auto y_7_8 = (S) -3.9882416242478972990 + (S) 5.9100541437570059955 * x;
+            const auto y_5_6 = (S) -0.33522250091818628359 + (S) 1.6520149306717599735 * x;
+            const auto y_3_4 = (S) 0.16219681678629885302 + (S) 0.059321118077451009953 * x;
+
+            const auto y_7_8_9_10 = y_7_8 + x_sq * y_9_10;
+            const auto y_3_4_5_6 = y_3_4 + x_sq * y_5_6;
+
+            const auto y_3_4_5_6_7_8_9_10 = y_3_4_5_6 + (x_sq * x_sq) * y_7_8_9_10;
+            const auto y_2_3_4_5_6_7_8_9_10 = (S) 0.00013025301412532343010 + x * y_3_4_5_6_7_8_9_10;
+            return x + x_sq * y_2_3_4_5_6_7_8_9_10;
+        }
+        if constexpr (order == 11)
+        {
+            const auto y_10_11 = (S) -7.8570177355488999282 + (S) 2.4911046380177723769 * x;
+            const auto y_8_9 = (S) -8.7527722722991097015 + (S) 11.027490029915364644 * x;
+            const auto y_6_7 = (S) -1.3541397019129590706 + (S) 4.3623048430828985644 * x;
+            const auto y_4_5 = (S) -0.031316768977302081312 + (S) 0.34169372825309551889 * x;
+            const auto y_2_3 = (S) -0.000047491279899706856284 + (S) 0.16861521810527382859 * x;
+
+            const auto y_8_9_10_11 = y_8_9 + x_sq * y_10_11;
+            const auto y_4_5_6_7 = y_4_5 + x_sq * y_6_7;
+
+            const auto y_4_5_6_7_8_9_10_11 = y_4_5_6_7 + (x_sq * x_sq) * y_8_9_10_11;
+            const auto y_2_3_4_5_6_7_8_9_10_11 = y_2_3 + x_sq * y_4_5_6_7_8_9_10_11;
+
+            return x + x_sq * y_2_3_4_5_6_7_8_9_10_11;
+        }
+        return {};
+    }
+} // namespace asin_detail
+
+/**
+ * Asin(x) approximation, valid on the range [-1, 1].
+ *
+ * |x| <= 1/sqrt(2) is passed to the polynomial directly. Larger magnitudes use
+ * asin(|x|) = pi/2 - asin(sqrt(1 - x^2)), so the polynomial argument never
+ * exceeds 1/sqrt(2). The sign of x is re-applied to the result at the end.
+ */
+template <int order, typename T>
+T asin (T x)
+{
+    using S = scalar_of_t<T>;
+    static constexpr auto rsqrt2 = (S) 0.707106781186547524400844362105;
+    // spelled out because M_PI_2 is not part of standard <cmath> (MSVC needs _USE_MATH_DEFINES)
+    static constexpr auto pi_2 = (S) 1.57079632679489661923132169164;
+
+    using std::abs, std::sqrt;
+#if defined(XSIMD_HPP)
+    using xsimd::abs, xsimd::sqrt;
+#endif
+    const auto abs_x = abs (x);
+
+    const auto reflect = abs_x > rsqrt2;
+    const auto poly_arg = select (reflect, sqrt ((S) 1 - abs_x * abs_x), abs_x);
+    const auto poly_res = asin_detail::asin_0_rsqrt2<order> (poly_arg);
+    const auto res = select (reflect, pi_2 - poly_res, poly_res);
+
+    return select (x > (S) 0, res, -res);
+}
+} // namespace math_approx
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index a1221c7..4f4420a 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -31,6 +31,7 @@ endfunction(setup_catch_test)
 setup_catch_test(tanh_approx_test)
 setup_catch_test(sigmoid_approx_test)
 setup_catch_test(trig_approx_test)
+setup_catch_test(inverse_trig_approx_test)
 setup_catch_test(pow_approx_test)
 setup_catch_test(log_approx_test)
 setup_catch_test(wright_omega_approx_test)
diff --git a/test/src/inverse_trig_approx_test.cpp b/test/src/inverse_trig_approx_test.cpp
new file mode 100644
index 0000000..3f30786
--- /dev/null
+++ b/test/src/inverse_trig_approx_test.cpp
@@ -0,0 +1,111 @@
+#include "test_helpers.hpp"
+#include <catch2/catch_test_macros.hpp>
+#include <iostream>
+
+#include <math_approx/math_approx.hpp>
+
+TEST_CASE ("Asin Approx Test")
+{
+#if ! defined(WIN32)
+    const auto all_floats = test_helpers::all_32_bit_floats (-1.0f, 1.0f, 1.0e-2f);
+#else
+    // asin is undefined outside [-1, 1], so keep the same range here and
+    // only change the third argument
+    const auto all_floats = test_helpers::all_32_bit_floats (-1.0f, 1.0f, 1.0e-1f);
+#endif
+    const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x)
+                                                    { return std::asin (x); });
+
+    const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound)
+    {
+        const auto y_approx = test_helpers::compute_all (all_floats, f_approx);
+
+        const auto error = test_helpers::compute_error (y_exact, y_approx);
+        const auto rel_error = test_helpers::compute_rel_error (y_exact, y_approx);
+        const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx);
+
+        const auto max_error = test_helpers::abs_max (error);
+        const auto max_rel_error = test_helpers::abs_max (rel_error);
+        const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end());
+
+        std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl;
+        REQUIRE (std::abs (max_error) < err_bound);
+        REQUIRE (std::abs (max_rel_error) < rel_err_bound);
+        // a ulp_bound of 0 skips the ULP check
+        if (ulp_bound > 0)
+            REQUIRE (max_ulp_error < ulp_bound);
+    };
+
+    SECTION ("11th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<11> (x); },
+                     2.5e-7f,
+                     5.0e-7f,
+                     8);
+    }
+    SECTION ("10th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<10> (x); },
+                     2.5e-7f,
+                     1.5e-6f,
+                     22);
+    }
+    SECTION ("9th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<9> (x); },
+                     4.0e-7f,
+                     4.5e-6f,
+                     72);
+    }
+    SECTION ("8th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<8> (x); },
+                     8.0e-7f,
+                     1.5e-5f,
+                     250);
+    }
+    SECTION ("7th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<7> (x); },
+                     3.0e-6f,
+                     5.0e-5f,
+                     0);
+    }
+    SECTION ("6th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<6> (x); },
+                     1.5e-5f,
+                     2.0e-4f,
+                     0);
+    }
+    SECTION ("5th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<5> (x); },
+                     5.5e-5f,
+                     5.0e-4f,
+                     0);
+    }
+    SECTION ("4th-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<4> (x); },
+                     3.0e-4f,
+                     2.0e-3f,
+                     0);
+    }
+    SECTION ("3rd-Order")
+    {
+        test_approx ([] (auto x)
+                     { return math_approx::asin<3> (x); },
+                     2.0e-3f,
+                     6.0e-3f,
+                     0);
+    }
+}
diff --git a/tools/bench/CMakeLists.txt b/tools/bench/CMakeLists.txt
index 1be164e..5b5b279 100644
--- a/tools/bench/CMakeLists.txt
+++ b/tools/bench/CMakeLists.txt
@@ -19,6 +19,9 @@ target_link_libraries(sigmoid_approx_bench PRIVATE benchmark::benchmark math_app
 add_executable(trig_approx_bench trig_bench.cpp)
 target_link_libraries(trig_approx_bench PRIVATE benchmark::benchmark math_approx)
 
+add_executable(inverse_trig_approx_bench inverse_trig_bench.cpp)
+target_link_libraries(inverse_trig_approx_bench PRIVATE benchmark::benchmark math_approx)
+
 add_executable(pow_approx_bench pow_bench.cpp)
 target_link_libraries(pow_approx_bench PRIVATE benchmark::benchmark math_approx)
 
diff --git a/tools/bench/inverse_trig_bench.cpp b/tools/bench/inverse_trig_bench.cpp
new file mode 100644
index 0000000..e403a9a
--- /dev/null
+++ b/tools/bench/inverse_trig_bench.cpp
@@ -0,0 +1,65 @@
+#include <math_approx/math_approx.hpp>
+#include <benchmark/benchmark.h>
+
+static constexpr size_t N = 1000;
+const auto data = []
+{
+    std::vector<float> x;
+    x.resize (N, 0.0f);
+    for (size_t i = 0; i < N; ++i)
+        x[i] = -1.0f + 2.0f * (float) i / (float) N;
+    return x;
+}();
+
+#define TRIG_BENCH(name, func) \
+void name (benchmark::State& state) \
+{ \
+    for (auto _ : state) \
+    { \
+        for (auto& x : data) \
+        { \
+            auto y = func (x); \
+            benchmark::DoNotOptimize (y); \
+        } \
+    } \
+} \
+BENCHMARK (name);
+
+// TRIG_BENCH (asin_std, std::asin)
+// TRIG_BENCH (asin_approx11, math_approx::asin<11>)
+// TRIG_BENCH (asin_approx10, math_approx::asin<10>)
+// TRIG_BENCH (asin_approx9, math_approx::asin<9>)
+// TRIG_BENCH (asin_approx8, math_approx::asin<8>)
+// TRIG_BENCH (asin_approx7, math_approx::asin<7>)
+// TRIG_BENCH (asin_approx6, math_approx::asin<6>)
+// TRIG_BENCH (asin_approx5, math_approx::asin<5>)
+// TRIG_BENCH (asin_approx4, math_approx::asin<4>)
+// TRIG_BENCH (asin_approx3, math_approx::asin<3>)
+
+#define TRIG_SIMD_BENCH(name, func) \
+void name (benchmark::State& state) \
+{ \
+    for (auto _ : state) \
+    { \
+        for (auto& x : data) \
+        { \
+            auto y = func (xsimd::broadcast (x)); \
+            static_assert (std::is_same_v<xsimd::batch<float>, decltype(y)>); \
+            benchmark::DoNotOptimize (y); \
+        } \
+    } \
+} \
+BENCHMARK (name);
+
+TRIG_SIMD_BENCH (asin_xsimd, xsimd::asin)
+TRIG_SIMD_BENCH (asin_simd_approx11, math_approx::asin<11>)
+TRIG_SIMD_BENCH (asin_simd_approx10, math_approx::asin<10>)
+TRIG_SIMD_BENCH (asin_simd_approx9, math_approx::asin<9>)
+TRIG_SIMD_BENCH (asin_simd_approx8, math_approx::asin<8>)
+TRIG_SIMD_BENCH (asin_simd_approx7, math_approx::asin<7>)
+TRIG_SIMD_BENCH (asin_simd_approx6, math_approx::asin<6>)
+TRIG_SIMD_BENCH (asin_simd_approx5, math_approx::asin<5>)
+TRIG_SIMD_BENCH (asin_simd_approx4, math_approx::asin<4>)
+TRIG_SIMD_BENCH (asin_simd_approx3, math_approx::asin<3>)
+
+BENCHMARK_MAIN();
diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp
index 97b7820..2690ae5 100644
--- a/tools/plotter/plotter.cpp
+++ b/tools/plotter/plotter.cpp
@@ -62,13 +62,11 @@ int main()
 {
     plt::figure();
     const auto range = std::make_pair (-1.0f, 1.0f);
-    static constexpr auto tol = 1.0e-3f;
+    static constexpr auto tol = 1.0e-2f;
 
     const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
-    const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (std::expm1));
-    // plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asinh<5>)), "asinh-5");
-    plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::expm1<5>) ), "expm1-5");
-    plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::expm1<6>) ), "expm1-6");
+    const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (std::asin));
+    plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asin<3>) ), "asin-3");
 
     plt::legend ({ { "loc", "upper right" } });
     plt::xlim (range.first, range.second);