Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alternative sigmoid approximation #3

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions include/math_approx/src/sigmoid_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,18 @@ T sigmoid (T x)
return (S) 0.5 * x_poly * rsqrt (x_poly * x_poly + (S) 1) + (S) 0.5;
}

// So far this has tested slower than the above approx (for equivalent error),
// but maybe it will be useful for someone!
// template <int order, typename T>
// T sigmoid_exp (T x)
// {
// return (T) 1 / ((T) 1 + math_approx::exp<order> (-x));
// }

/**
* Approximation of sigmoid(x) := 1 / (1 + e^-x),
* using math_approx::exp (x).
*
* So far this has tested slower than the above approximation
* for similar absolute error, but has better relative error
* characteristics.
*/
template <int order, bool C1_continuous = false, typename T>
T sigmoid_exp (T x)
{
return (T) 1 / ((T) 1 + math_approx::exp<order, C1_continuous> (-x));
}
} // namespace math_approx
68 changes: 67 additions & 1 deletion test/src/sigmoid_approx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TEST_CASE ("Sigmoid Approx Test")
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all<float> (all_floats, [] (auto x)
{ return 1.0f / (1.0f + std::exp (-x)); });
{ return 1.0f / (1.0f + std::exp (-x)); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound)
{
Expand Down Expand Up @@ -50,3 +50,69 @@ TEST_CASE ("Sigmoid Approx Test")
2.0e-3f);
}
}

TEST_CASE ("Sigmoid (Exp) Approx Test")
{
#if ! defined(WIN32)
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f);
#else
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all<float> (all_floats, [] (auto x)
{ return 1.0f / (1.0f + std::exp (-x)); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound)
{
const auto y_approx = test_helpers::compute_all<float> (all_floats, f_approx);

const auto error = test_helpers::compute_error<float> (y_exact, y_approx);
const auto rel_error = test_helpers::compute_rel_error<float> (y_exact, y_approx);
const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx);

const auto max_error = test_helpers::abs_max<float> (error);
const auto max_rel_error = test_helpers::abs_max<float> (rel_error);
const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end());

std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl;
REQUIRE (std::abs (max_error) < err_bound);
REQUIRE (std::abs (max_rel_error) < rel_err_bound);
if (ulp_bound > 0)
REQUIRE (max_ulp_error < ulp_bound);
};

SECTION ("6th-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<6> (x); },
1.5e-7f,
6.5e-7f,
12);
}

SECTION ("5th-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<5> (x); },
1.5e-7f,
7.5e-7f,
12);
}

SECTION ("4th-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<4> (x); },
9.5e-7f,
4.5e-6f,
65);
}

SECTION ("3rd-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<3> (x); },
3.0e-4f,
1.5e-4f,
0);
}
}
6 changes: 6 additions & 0 deletions tools/bench/sigmoid_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ SIGMOID_BENCH (sigmoid_std, [] (auto x) { return 1.0f / (1.0f + std::exp (-x));
SIGMOID_BENCH (sigmoid_approx9, math_approx::sigmoid<9>)
SIGMOID_BENCH (sigmoid_approx7, math_approx::sigmoid<7>)
SIGMOID_BENCH (sigmoid_approx5, math_approx::sigmoid<5>)
SIGMOID_BENCH (sigmoid_exp_approx6, math_approx::sigmoid_exp<6>)
SIGMOID_BENCH (sigmoid_exp_approx5, math_approx::sigmoid_exp<5>)
SIGMOID_BENCH (sigmoid_exp_approx4, math_approx::sigmoid_exp<4>)

#define SIGMOID_SIMD_BENCH(name, func) \
void name (benchmark::State& state) \
Expand All @@ -47,5 +50,8 @@ SIGMOID_SIMD_BENCH (sigmoid_xsimd, [] (auto x) { return 1.0f / (1.0f + xsimd::ex
SIGMOID_SIMD_BENCH (sigmoid_simd_approx9, math_approx::tanh<9>)
SIGMOID_SIMD_BENCH (sigmoid_simd_approx7, math_approx::tanh<7>)
SIGMOID_SIMD_BENCH (sigmoid_simd_approx5, math_approx::tanh<5>)
SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx6, math_approx::sigmoid_exp<6>)
SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx5, math_approx::sigmoid_exp<5>)
SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx4, math_approx::sigmoid_exp<4>)

BENCHMARK_MAIN();
13 changes: 10 additions & 3 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,24 @@ void plot_function (std::span<const float> all_floats,
plt::named_plot<float, float> (name, all_floats, y_approx);
}

template <typename T>
T sigmoid_ref (T x)
{
return (T) 1 / ((T) 1 + std::exp (-x));
}

#define FLOAT_FUNC(func) [] (float x) { return func (x); }

int main()
{
plt::figure();
const auto range = std::make_pair (1.0f, 10.0f);
const auto range = std::make_pair (-10.0f, 10.0f);
static constexpr auto tol = 1.0e-2f;

const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (std::acosh));
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::acosh<5>) ), "acosh-5");
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (sigmoid_ref));
plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::sigmoid_exp<5, true>) ), "sigmoid_exp-5_c1");
plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::sigmoid_exp<6, true>) ), "sigmoid_exp-6_c1");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down
Loading