From 0c68d4d17242d707ba07fa7f1901692b7ed72d58 Mon Sep 17 00:00:00 2001 From: jatinchowdhury18 Date: Fri, 19 Jan 2024 00:12:05 -0800 Subject: [PATCH] Add alternative sigmoid approximation (#3) * Add sigmoid_exp approximation * Undo comments * Tweaking error bounds --- include/math_approx/src/sigmoid_approx.hpp | 21 ++++--- test/src/sigmoid_approx_test.cpp | 68 +++++++++++++++++++++- tools/bench/sigmoid_bench.cpp | 6 ++ tools/plotter/plotter.cpp | 13 ++++- 4 files changed, 97 insertions(+), 11 deletions(-) diff --git a/include/math_approx/src/sigmoid_approx.hpp b/include/math_approx/src/sigmoid_approx.hpp index 7605f57..68a9ffe 100644 --- a/include/math_approx/src/sigmoid_approx.hpp +++ b/include/math_approx/src/sigmoid_approx.hpp @@ -75,11 +75,18 @@ T sigmoid (T x) return (S) 0.5 * x_poly * rsqrt (x_poly * x_poly + (S) 1) + (S) 0.5; } -// So far this has tested slower than the above approx (for equivalent error), -// but maybe it will be useful for someone! -// template -// T sigmoid_exp (T x) -// { -// return (T) 1 / ((T) 1 + math_approx::exp (-x)); -// } + +/** + * Approximation of sigmoid(x) := 1 / (1 + e^-x), + * using math_approx::exp (x). + * + * So far this has tested slower than the above approximation + * for similar absolute error, but has better relative error + * characteristics. + */ +template +T sigmoid_exp (T x) +{ + return (T) 1 / ((T) 1 + math_approx::exp (-x)); +} } // namespace math_approx diff --git a/test/src/sigmoid_approx_test.cpp b/test/src/sigmoid_approx_test.cpp index b444325..2cddbdd 100644 --- a/test/src/sigmoid_approx_test.cpp +++ b/test/src/sigmoid_approx_test.cpp @@ -12,7 +12,7 @@ TEST_CASE ("Sigmoid Approx Test") const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); #endif const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) - { return 1.0f / (1.0f + std::exp (-x)); }); + { return 1.0f / (1.0f + std::exp (-x)); }); const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound) { @@ -50,3 +50,69 @@ TEST_CASE ("Sigmoid Approx Test") 2.0e-3f); } } + +TEST_CASE ("Sigmoid (Exp) Approx Test") +{ +#if ! defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return 1.0f / (1.0f + std::exp (-x)); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto rel_error = test_helpers::compute_rel_error (y_exact, y_approx); + const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx); + + const auto max_error = test_helpers::abs_max (error); + const auto max_rel_error = test_helpers::abs_max (rel_error); + const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end()); + + std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + REQUIRE (std::abs (max_rel_error) < rel_err_bound); + if (ulp_bound > 0) + REQUIRE (max_ulp_error < ulp_bound); + }; + + SECTION ("6th-Order (Exp)") + { + test_approx ([] (auto x) + { return math_approx::sigmoid_exp<6> (x); }, + 1.5e-7f, + 6.5e-7f, + 12); + } + + SECTION ("5th-Order (Exp)") + { + test_approx ([] (auto x) + { return math_approx::sigmoid_exp<5> (x); }, + 1.5e-7f, + 7.5e-7f, + 12); + } + + SECTION ("4th-Order (Exp)") + { + test_approx ([] (auto x) + { return math_approx::sigmoid_exp<4> (x); }, + 9.5e-7f, + 4.5e-6f, + 65); + } + + SECTION ("3rd-Order (Exp)") + { + test_approx ([] (auto x) + { return math_approx::sigmoid_exp<3> (x); }, + 3.0e-4f, + 1.5e-4f, + 0); + } +} diff --git a/tools/bench/sigmoid_bench.cpp b/tools/bench/sigmoid_bench.cpp index 5f868f5..606321e 100644 --- a/tools/bench/sigmoid_bench.cpp +++ b/tools/bench/sigmoid_bench.cpp @@ -28,6 +28,9 @@ SIGMOID_BENCH (sigmoid_std, [] (auto x) { return 1.0f / (1.0f + std::exp (-x)); SIGMOID_BENCH (sigmoid_approx9, math_approx::sigmoid<9>) SIGMOID_BENCH (sigmoid_approx7, math_approx::sigmoid<7>) SIGMOID_BENCH (sigmoid_approx5, math_approx::sigmoid<5>) +SIGMOID_BENCH (sigmoid_exp_approx6, math_approx::sigmoid_exp<6>) +SIGMOID_BENCH (sigmoid_exp_approx5, math_approx::sigmoid_exp<5>) +SIGMOID_BENCH (sigmoid_exp_approx4, math_approx::sigmoid_exp<4>) #define SIGMOID_SIMD_BENCH(name, func) \ void name (benchmark::State& state) \ @@ -47,5 +50,8 @@ SIGMOID_SIMD_BENCH (sigmoid_xsimd, [] (auto x) { return 1.0f / (1.0f + xsimd::ex SIGMOID_SIMD_BENCH (sigmoid_simd_approx9, math_approx::tanh<9>) SIGMOID_SIMD_BENCH (sigmoid_simd_approx7, math_approx::tanh<7>) SIGMOID_SIMD_BENCH (sigmoid_simd_approx5, math_approx::tanh<5>) +SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx6, math_approx::sigmoid_exp<6>) +SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx5, math_approx::sigmoid_exp<5>) +SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx4, math_approx::sigmoid_exp<4>) BENCHMARK_MAIN(); diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp index be727cb..1021eef 100644 --- a/tools/plotter/plotter.cpp +++ b/tools/plotter/plotter.cpp @@ -56,17 +56,24 @@ void plot_function (std::span all_floats, plt::named_plot (name, all_floats, y_approx); } +template +T sigmoid_ref (T x) +{ + return (T) 1 / ((T) 1 + std::exp (-x)); +} + #define FLOAT_FUNC(func) [] (float x) { return func (x); } int main() { plt::figure(); - const auto range = std::make_pair (1.0f, 10.0f); + const auto range = std::make_pair (-10.0f, 10.0f); static constexpr auto tol = 1.0e-2f; const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol); - const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (std::acosh)); - plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::acosh<5>) ), "acosh-5"); + const auto y_exact = test_helpers::compute_all (all_floats, FLOAT_FUNC (sigmoid_ref)); + plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::sigmoid_exp<5, true>) ), "sigmoid_exp-5_c1"); + plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::sigmoid_exp<6, true>) ), "sigmoid_exp-6_c1"); plt::legend ({ { "loc", "upper right" } }); plt::xlim (range.first, range.second);