Skip to content

Commit

Permalink
Add asin approximation
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Jan 5, 2024
1 parent 89af9be commit 14f2a67
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 5 deletions.
1 change: 1 addition & 0 deletions include/math_approx/math_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace math_approx
#include "src/basic_math.hpp"

#include "src/trig_approx.hpp"
#include "src/inverse_trig_approx.hpp"
#include "src/pow_approx.hpp"
#include "src/log_approx.hpp"
#include "src/tanh_approx.hpp"
Expand Down
135 changes: 135 additions & 0 deletions include/math_approx/src/inverse_trig_approx.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#pragma once

#include "basic_math.hpp"

namespace math_approx
{
namespace asin_detail
{
template <int order, typename T>
constexpr T asin_0_rsqrt2 (T x)
{
using S = scalar_of_t<T>;
static_assert (order >= 3 && order <= 11);

const auto x_sq = x * x;
if constexpr (order == 3)
{
const auto y_2_3 = (S) -0.0536922932754174 + (S) 0.297373838424192 * x;
return x + x_sq * y_2_3;
}
if constexpr (order == 4)
{
const auto y_3_4 = (S) -0.00535062837316264 + (S) 0.257252341545375 * x;
const auto y_2_3_4 = (S) 0.0317400592553864 + x * y_3_4;
return x + x_sq * y_2_3_4;
}
if constexpr (order == 5)
{
const auto y_4_5 = (S) -0.304640601352515 + (S) 0.353208342056560 * x;
const auto y_2_3 = (S) -0.0132122795426018 + (S) 0.278935718011026 * x;

const auto y_2_3_4_5 = y_2_3 + x_sq * y_4_5;
return x + x_sq * y_2_3_4_5;
}
if constexpr (order == 6)
{
const auto y_5_6 = (S) -0.516068582317285 + (S) 0.437544978265334 * x;
const auto y_3_4 = (S) 0.0946375652126262 + (S) 0.313911974469437 * x;

const auto y_3_4_5_6 = y_3_4 + x_sq * y_5_6;

const auto y_2_3_4_5_6 = (S) 0.00577946556085762 + x * y_3_4_5_6;
return x + x_sq * y_2_3_4_5_6;
}
if constexpr (order == 7)
{
const auto y_6_7 = (S) -0.993023225129115 + (S) 0.604213345541030 * x;
const auto y_4_5 = (S) -0.242568404368623 + (S) 0.780715776826480 * x;
const auto y_2_3 = (S) -0.00231743294930714 + (S) 0.205916081200406 * x;

const auto y_4_5_6_7 = y_4_5 + x_sq * y_6_7;

const auto y_2_3_4_5_6_7 = y_2_3 + x_sq * y_4_5_6_7;
return x + x_sq * y_2_3_4_5_6_7;
}
if constexpr (order == 8)
{
const auto y_7_8 = (S) -1.68181413527251 + (S) 0.833569228384441 * x;
const auto y_5_6 = (S) -0.614138628435564 + (S) 1.51390471735914 * x;
const auto y_3_4 = (S) 0.146440161696543 + (S) 0.167766328527588 * x;

const auto y_5_6_7_8 = y_5_6 + x_sq * y_7_8;
const auto y_3_4_5_6_7_8 = y_3_4 + x_sq * y_5_6_7_8;

const auto y_2_3_4_5_6_7_8 = (S) 0.000914775210828589 + x * y_3_4_5_6_7_8;
return x + x_sq * y_2_3_4_5_6_7_8;
}
if constexpr (order == 9)
{
const auto y_8_9 = (S) -2.8729113246543627191 + (S) 1.1910880616677141930 * x;
const auto y_6_7 = (S) -1.7250043993765906691 + (S) 3.0742940024198017746 * x;
const auto y_4_5 = (S) -0.10358468520191396745 + (S) 0.63814601829123507315 * x;
const auto y_2_3 = (S) -0.00034869418257434217938 + (S) 0.17639243703620430259 * x;

const auto y_6_7_8_9 = y_6_7 + x_sq * y_8_9;
const auto y_2_3_4_5 = y_2_3 + x_sq * y_4_5;

const auto y_2_3_4_5_6_7_8_9 = y_2_3_4_5 + (x_sq * x_sq) * y_6_7_8_9;
return x + x_sq * y_2_3_4_5_6_7_8_9;
}
if constexpr (order == 10)
{
const auto y_9_10 = (S) -4.7928604989214971255 + (S) 1.7203396621648587850 * x;
const auto y_7_8 = (S) -3.9882416242478972990 + (S) 5.9100541437570059955 * x;
const auto y_5_6 = (S) -0.33522250091818628359 + (S) 1.6520149306717599735 * x;
const auto y_3_4 = (S) 0.16219681678629885302 + (S) 0.059321118077451009953 * x;

const auto y_7_8_9_10 = y_7_8 + x_sq * y_9_10;
const auto y_3_4_5_6 = y_3_4 + x_sq * y_5_6;

const auto y_3_4_5_6_7_8_9_10 = y_3_4_5_6 + (x_sq * x_sq) * y_7_8_9_10;
const auto y_2_3_4_5_6_7_8_9_10 = (S) 0.00013025301412532343010 + x * y_3_4_5_6_7_8_9_10;
return x + x_sq * y_2_3_4_5_6_7_8_9_10;
}
if constexpr (order == 11)
{
const auto y_10_11 = (S) -7.8570177355488999282 + (S) 2.4911046380177723769 * x;
const auto y_8_9 = (S) -8.7527722722991097015 + (S) 11.027490029915364644 * x;
const auto y_6_7 = (S) -1.3541397019129590706 + (S) 4.3623048430828985644 * x;
const auto y_4_5 = (S) -0.031316768977302081312 + (S) 0.34169372825309551889 * x;
const auto y_2_3 = (S) -0.000047491279899706856284 + (S) 0.16861521810527382859 * x;

const auto y_8_9_10_11 = y_8_9 + x_sq * y_10_11;
const auto y_4_5_6_7 = y_4_5 + x_sq * y_6_7;

const auto y_4_5_6_7_8_9_10_11 = y_4_5_6_7 + (x_sq * x_sq) * y_8_9_10_11;
const auto y_2_3_4_5_6_7_8_9_10_11 = y_2_3 + x_sq * y_4_5_6_7_8_9_10_11;

return x + x_sq * y_2_3_4_5_6_7_8_9_10_11;
}
return {};
}
} // namespace asin_detail

/** Asin(x) approximation, valid on the range [-1, 1] */
template <int order, typename T>
T asin (T x)
{
using S = scalar_of_t<T>;
static constexpr auto rsqrt2 = (S) 0.707106781186547524400844362105;

using std::abs, std::sqrt;
#if defined(XSIMD_HPP)
using xsimd::abs, xsimd::sqrt;
#endif
const auto abs_x = abs (x);

const auto reflect = abs_x > rsqrt2;
const auto poly_arg = select (reflect, sqrt ((S) 1 - abs_x * abs_x), abs_x);
const auto poly_res = asin_detail::asin_0_rsqrt2<order> (poly_arg);
const auto res = select (reflect, (S) M_PI_2 - poly_res, poly_res);

return select (x > (S) 0, res, -res);
}
} // namespace math_approx
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ endfunction(setup_catch_test)
setup_catch_test(tanh_approx_test)
setup_catch_test(sigmoid_approx_test)
setup_catch_test(trig_approx_test)
setup_catch_test(inverse_trig_approx_test)
setup_catch_test(pow_approx_test)
setup_catch_test(log_approx_test)
setup_catch_test(wright_omega_approx_test)
Expand Down
108 changes: 108 additions & 0 deletions test/src/inverse_trig_approx_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include "test_helpers.hpp"
#include <catch2/catch_test_macros.hpp>
#include <iostream>

#include <math_approx/math_approx.hpp>

TEST_CASE ("Asin Approx Test")
{
#if ! defined(WIN32)
const auto all_floats = test_helpers::all_32_bit_floats (-1.0f, 1.0f, 1.0e-2f);
#else
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all<float> (all_floats, [] (auto x)
{ return std::asin (x); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound)
{
const auto y_approx = test_helpers::compute_all<float> (all_floats, f_approx);

const auto error = test_helpers::compute_error<float> (y_exact, y_approx);
const auto rel_error = test_helpers::compute_rel_error<float> (y_exact, y_approx);
const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx);

const auto max_error = test_helpers::abs_max<float> (error);
const auto max_rel_error = test_helpers::abs_max<float> (rel_error);
const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end());

std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl;
REQUIRE (std::abs (max_error) < err_bound);
REQUIRE (std::abs (max_rel_error) < rel_err_bound);
if (ulp_bound > 0)
REQUIRE (max_ulp_error < ulp_bound);
};

SECTION ("11th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<11> (x); },
2.5e-7f,
5.0e-7f,
8);
}
SECTION ("10th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<10> (x); },
2.5e-7f,
1.5e-6f,
22);
}
SECTION ("9th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<9> (x); },
4.0e-7f,
4.5e-6f,
72);
}
SECTION ("8th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<8> (x); },
8.0e-7f,
1.5e-5f,
250);
}
SECTION ("7th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<7> (x); },
3.0e-6f,
5.0e-5f,
0);
}
SECTION ("6th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<6> (x); },
1.5e-5f,
2.0e-4f,
0);
}
SECTION ("5th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<5> (x); },
5.5e-5f,
5.0e-4f,
0);
}
SECTION ("4th-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<4> (x); },
3.0e-4f,
2.0e-3f,
0);
}
SECTION ("3rd-Order")
{
test_approx ([] (auto x)
{ return math_approx::asin<3> (x); },
2.0e-3f,
6.0e-3f,
0);
}
}
3 changes: 3 additions & 0 deletions tools/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ target_link_libraries(sigmoid_approx_bench PRIVATE benchmark::benchmark math_app
add_executable(trig_approx_bench trig_bench.cpp)
target_link_libraries(trig_approx_bench PRIVATE benchmark::benchmark math_approx)

add_executable(inverse_trig_approx_bench inverse_trig_bench.cpp)
target_link_libraries(inverse_trig_approx_bench PRIVATE benchmark::benchmark math_approx)

add_executable(pow_approx_bench pow_bench.cpp)
target_link_libraries(pow_approx_bench PRIVATE benchmark::benchmark math_approx)

Expand Down
65 changes: 65 additions & 0 deletions tools/bench/inverse_trig_bench.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include <math_approx/math_approx.hpp>
#include <benchmark/benchmark.h>

static constexpr size_t N = 1000;
const auto data = []
{
std::vector<float> x;
x.resize (N, 0.0f);
for (size_t i = 0; i < N; ++i)
x[i] = -1.0f + 2.0f * (float) i / (float) N;
return x;
}();

#define TRIG_BENCH(name, func) \
void name (benchmark::State& state) \
{ \
for (auto _ : state) \
{ \
for (auto& x : data) \
{ \
auto y = func (x); \
benchmark::DoNotOptimize (y); \
} \
} \
} \
BENCHMARK (name);

// TRIG_BENCH (asin_std, std::asin)
// TRIG_BENCH (asin_approx11, math_approx::asin<11>)
// TRIG_BENCH (asin_approx10, math_approx::asin<10>)
// TRIG_BENCH (asin_approx9, math_approx::asin<9>)
// TRIG_BENCH (asin_approx8, math_approx::asin<8>)
// TRIG_BENCH (asin_approx7, math_approx::asin<7>)
// TRIG_BENCH (asin_approx6, math_approx::asin<6>)
// TRIG_BENCH (asin_approx5, math_approx::asin<5>)
// TRIG_BENCH (asin_approx4, math_approx::asin<4>)
// TRIG_BENCH (asin_approx3, math_approx::asin<3>)

#define TRIG_SIMD_BENCH(name, func) \
void name (benchmark::State& state) \
{ \
for (auto _ : state) \
{ \
for (auto& x : data) \
{ \
auto y = func (xsimd::broadcast (x)); \
static_assert (std::is_same_v<xsimd::batch<float>, decltype(y)>); \
benchmark::DoNotOptimize (y); \
} \
} \
} \
BENCHMARK (name);

TRIG_SIMD_BENCH (asin_xsimd, xsimd::asin)
TRIG_SIMD_BENCH (asin_simd_approx11, math_approx::asin<11>)
TRIG_SIMD_BENCH (asin_simd_approx10, math_approx::asin<10>)
TRIG_SIMD_BENCH (asin_simd_approx9, math_approx::asin<9>)
TRIG_SIMD_BENCH (asin_simd_approx8, math_approx::asin<8>)
TRIG_SIMD_BENCH (asin_simd_approx7, math_approx::asin<7>)
TRIG_SIMD_BENCH (asin_simd_approx6, math_approx::asin<6>)
TRIG_SIMD_BENCH (asin_simd_approx5, math_approx::asin<5>)
TRIG_SIMD_BENCH (asin_simd_approx4, math_approx::asin<4>)
TRIG_SIMD_BENCH (asin_simd_approx3, math_approx::asin<3>)

BENCHMARK_MAIN();
8 changes: 3 additions & 5 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@ int main()
{
plt::figure();
const auto range = std::make_pair (-1.0f, 1.0f);
static constexpr auto tol = 1.0e-3f;
static constexpr auto tol = 1.0e-2f;

const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (std::expm1));
// plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asinh<5>)), "asinh-5");
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::expm1<5>) ), "expm1-5");
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::expm1<6>) ), "expm1-6");
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (std::asin));
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::asin<3>) ), "asin-3");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down

0 comments on commit 14f2a67

Please sign in to comment.