From 2db6631de1ebab9a7e1500eb0e7a03b662022c75 Mon Sep 17 00:00:00 2001 From: jatin Date: Sat, 6 Jan 2024 02:56:19 -0800 Subject: [PATCH] Add arctan --- README.md | 2 +- .../math_approx/src/inverse_trig_approx.hpp | 64 ++++++++++++++++++- test/src/inverse_trig_approx_test.cpp | 57 ++++++++++++++++- tools/bench/inverse_trig_bench.cpp | 10 +++ tools/plotter/plotter.cpp | 8 +-- 5 files changed, 133 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 5d4a772..ddf9828 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Currently supported: - sin/cos/tan -- arcsin/arccos +- arcsin/arccos/arctan - exp/exp2/exp10/expm1 - log/log2/log10/log1p - sinh/cosh/tanh diff --git a/include/math_approx/src/inverse_trig_approx.hpp b/include/math_approx/src/inverse_trig_approx.hpp index 1e34924..531f888 100644 --- a/include/math_approx/src/inverse_trig_approx.hpp +++ b/include/math_approx/src/inverse_trig_approx.hpp @@ -65,7 +65,49 @@ namespace inv_trig_detail return {}; } } -} // namespace asin_detail + + template + constexpr T atan_kernel (T x) + { + using S = scalar_of_t; + static_assert (order >= 4 && order <= 7); + + if constexpr (order == 4) + { + const auto x_sq = x * x; + + const auto num = x + x_sq * (S) 0.498001992540; + const auto den = (S) 1 + x * (S) 0.481844539675 + x_sq * (S) 0.425470835319; + + return num / den; + } + else if constexpr (order == 5 || order == 6) + { + const auto x_sq = x * x; + + const auto num = (S) 0.177801521472 + x * (S) 0.116983970701; + const auto den = (S) 1 + x * (S) 0.174763903018 + x_sq * (S) 0.473808187566; + + return (x + x_sq * num) / den; + } + else if constexpr (order == 7) + { + const auto x_sq = x * x; + + const auto num = (S) 0.274959104817 + (S) 0.351814748865 * x + (S) -0.0395798531406 * x_sq; + const auto den = (S) 1 + x * ((S) 0.275079063405 + x * ((S) 0.683311392128 + x * (S) 0.0624877111229)); + + return (x + x_sq * num) / den; + + // an -> 0.274959104817, ad -> 0.275079063405, bn -> 0.351814748865, bd \ 
-> 0.683311392128, cn -> -0.0395798531406, cd -> 0.0624877111229 + } + else + { + return {}; + } + } +} // namespace inv_trig_detail template T asin (T x) @@ -112,4 +154,24 @@ T acos (T x) auto res = select (reflect, (S) M_PI_2 - (z2 + z2), z2); return (S) M_PI_2 - select (x > (S) 0, res, -res); } + +template +T atan (T x) +{ + using S = scalar_of_t; + + using std::abs, std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::abs, xsimd::sqrt; +#endif + + const auto abs_x = abs (x); + const auto reflect = abs_x > (S) 1; + + const auto z = select (reflect, (S) 1 / abs_x, abs_x); + const auto atan_01 = inv_trig_detail::atan_kernel (z); + + const auto res = select (reflect, (S) M_PI_2 - atan_01, atan_01); + return select (x > (S) 0, res, -res); +} } // namespace math_approx diff --git a/test/src/inverse_trig_approx_test.cpp b/test/src/inverse_trig_approx_test.cpp index ee76848..d07946d 100644 --- a/test/src/inverse_trig_approx_test.cpp +++ b/test/src/inverse_trig_approx_test.cpp @@ -3,7 +3,7 @@ #include #include - +/* TEST_CASE ("Asin Approx Test") { #if ! defined(WIN32) @@ -120,3 +120,58 @@ TEST_CASE ("Acos Approx Test") 5.0e-3f); } } +*/ +TEST_CASE ("Atan Approx Test") +{ +#if ! 
defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-2f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return std::atan (x); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto rel_error = test_helpers::compute_rel_error (y_exact, y_approx); + const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx); + + const auto max_error = test_helpers::abs_max (error); + const auto max_rel_error = test_helpers::abs_max (rel_error); + const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end()); + + std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + REQUIRE (std::abs (max_rel_error) < rel_err_bound); + if (ulp_bound > 0) + REQUIRE (max_ulp_error < ulp_bound); + }; + + SECTION ("7th-Order") + { + test_approx ([] (auto x) + { return math_approx::atan<7> (x); }, + 4.0e-7f, + 3.0e-6f, + 45); + } + SECTION ("5th-Order") + { + test_approx ([] (auto x) + { return math_approx::atan<5> (x); }, + 2.0e-5f, + 1.5e-4f, + 0); + } + SECTION ("4th-Order") + { + test_approx ([] (auto x) + { return math_approx::atan<4> (x); }, + 1.5e-4f, + 8.5e-4f, + 0); + } +} diff --git a/tools/bench/inverse_trig_bench.cpp b/tools/bench/inverse_trig_bench.cpp index 2e19032..f5705d8 100644 --- a/tools/bench/inverse_trig_bench.cpp +++ b/tools/bench/inverse_trig_bench.cpp @@ -38,6 +38,11 @@ TRIG_BENCH (acos_approx3, math_approx::acos<3>) TRIG_BENCH (acos_approx2, math_approx::acos<2>) TRIG_BENCH (acos_approx1, math_approx::acos<1>) +TRIG_BENCH (atan_std, std::atan) +TRIG_BENCH 
(atan_approx7, math_approx::atan<7>) +TRIG_BENCH (atan_approx5, math_approx::atan<5>) +TRIG_BENCH (atan_approx4, math_approx::atan<4>) + #define TRIG_SIMD_BENCH(name, func) \ void name (benchmark::State& state) \ { \ @@ -66,4 +71,9 @@ TRIG_SIMD_BENCH (acos_simd_approx3, math_approx::acos<3>) TRIG_SIMD_BENCH (acos_simd_approx2, math_approx::acos<2>) TRIG_SIMD_BENCH (acos_simd_approx1, math_approx::acos<1>) +TRIG_SIMD_BENCH (atan_xsimd, xsimd::atan) +TRIG_SIMD_BENCH (atan_simd_approx7, math_approx::atan<7>) +TRIG_SIMD_BENCH (atan_simd_approx5, math_approx::atan<5>) +TRIG_SIMD_BENCH (atan_simd_approx4, math_approx::atan<4>) + BENCHMARK_MAIN(); diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp index d249631..d939c12 100644 --- a/tools/plotter/plotter.cpp +++ b/tools/plotter/plotter.cpp @@ -61,14 +61,12 @@ void plot_function (std::span all_floats, int main() { plt::figure(); - const auto range = std::make_pair (-1.0f, 1.0f); + const auto range = std::make_pair (-10.0f, 10.0f); static constexpr auto tol = 1.0e-2f; const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol); - const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (std::acos)); - // plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((acos_xsimd) ), "acos-xsimd"); - plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::acos<4>) ), "acos-4"); - // plot_function (all_floats, FLOAT_FUNC ((math_approx::acos<4>) ), "acos-4"); + const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (std::atan)); + plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::atan<4>) ), "atan-4"); plt::legend ({ { "loc", "upper right" } }); plt::xlim (range.first, range.second);