Skip to content

Commit

Permalink
Add benchmarks and some more testing for wright-omega
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Nov 28, 2023
1 parent c9844b2 commit 83935b1
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
.vscode/

build*/

.DS_Store
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ Currently supported:
- exp/exp2/exp10
- log/log2/log10
- tanh
- sigmoid
- sigmoid
- Wright-Omega function
43 changes: 38 additions & 5 deletions include/math_approx/src/wright_omega_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ T wright_omega (T x)
using S = scalar_of_t<T>;
static constexpr auto E = (S) 2.7182818284590452354;

const auto x1 = [](T _x)
const auto x1 = [] (T _x)
{
const auto x_sq = _x * _x;
if constexpr (poly_order == 3)
Expand All @@ -31,14 +31,14 @@ T wright_omega (T x)
}
else
{
return T{};
return T {};
}
} (x);
}(x);
const auto x2 = x - log<log_order> (x) + (S) 0.32352057096397160124 * exp<exp_order> ((S) -0.029614177658043381316 * x);

auto y = select (x < (S) -3, T{}, select (x < (S) E, x1, x2));
auto y = select (x < (S) -3, T {}, select (x < (S) E, x1, x2));

const auto nr_update = [](T _x, T _y)
const auto nr_update = [] (T _x, T _y)
{
return _y - (_y - exp<exp_order> (_x - _y)) / (_y + (S) 1);
};
Expand All @@ -48,4 +48,37 @@ T wright_omega (T x)

return y;
}

/**
* Wright-Omega function using Stephano D'Angelo's derivation (https://www.dafx.de/paper-archive/2019/DAFx2019_paper_5.pdf)
* With `num_nr_iters == 0`, this is the fastest implementation, but the least accurate.
* With `num_nr_iters == 1`, this is faster than the other implementation with 0 iterations, and little bit more accurate.
* For more accuracy, use the other implementation with at least 1 NR iteration.
*/
template <int num_nr_iters, int log_order = 3, int exp_order = log_order, typename T>
T wright_omega_dangelo (T x)
{
using S = scalar_of_t<T>;

const auto x1 = [] (T _x)
{
const auto x_sq = _x * _x;
const auto y_2_3 = (S) 4.775931364975583e-2 + (S) -1.314293149877800e-3 * _x;
const auto y_0_1 = (S) 6.313183464296682e-1 + (S) 3.631952663804445e-1 * _x;
return y_0_1 + y_2_3 * x_sq;
}(x);
const auto x2 = x - log<log_order> (x);

auto y = select (x < (S) -3.341459552768620, T {}, select (x < (S) 8, x1, x2));

const auto nr_update = [] (T _x, T _y)
{
return _y - (_y - exp<exp_order> (_x - _y)) / (_y + (S) 1);
};

for (int i = 0; i < num_nr_iters; ++i)
y = nr_update (x, y);

return y;
}
} // namespace math_approx
3 changes: 3 additions & 0 deletions tools/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ target_link_libraries(pow_approx_bench PRIVATE benchmark::benchmark math_approx)

add_executable(log_approx_bench log_bench.cpp)
target_link_libraries(log_approx_bench PRIVATE benchmark::benchmark math_approx)

add_executable(wright_omega_approx_bench wright_omega_bench.cpp)
target_link_libraries(wright_omega_approx_bench PRIVATE benchmark::benchmark math_approx)
67 changes: 67 additions & 0 deletions tools/bench/wright_omega_bench.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include <math_approx/math_approx.hpp>
#include <benchmark/benchmark.h>

#include "../../test/src/reference/toms917.hpp"

static constexpr size_t N = 2000;
const auto data = []
{
std::vector<float> x;
x.resize (N, 0.0f);
for (size_t i = 0; i < N; ++i)
x[i] = -10.0f + 40.0f * (float) i / (float) N;
return x;
}();

#define WO_BENCH(name, func) \
void name (benchmark::State& state) \
{ \
for (auto _ : state) \
{ \
for (auto& x : data) \
{ \
auto y = func (x); \
benchmark::DoNotOptimize (y); \
} \
} \
} \
BENCHMARK (name);
WO_BENCH (wright_omega_toms917, toms917::wrightomega)
WO_BENCH (wright_omega_iter3_poly3_logexp5, (math_approx::wright_omega<3, 3, 5>))
WO_BENCH (wright_omega_iter3_poly3, (math_approx::wright_omega<3, 3>))
WO_BENCH (wright_omega_iter2_poly5, (math_approx::wright_omega<2, 5>))
WO_BENCH (wright_omega_iter2_poly3, (math_approx::wright_omega<2, 3>))
WO_BENCH (wright_omega_iter2_poly3_logexp3, (math_approx::wright_omega<2, 3, 3>))
WO_BENCH (wright_omega_iter1_poly5, (math_approx::wright_omega<1, 5>))
WO_BENCH (wright_omega_iter1_poly3, (math_approx::wright_omega<1, 3>))
WO_BENCH (wright_omega_iter0_poly5, (math_approx::wright_omega<0, 5>))
WO_BENCH (wright_omega_iter0_poly3, (math_approx::wright_omega<0, 3>))
WO_BENCH (wright_omega_dangelo2, (math_approx::wright_omega_dangelo<2>))
WO_BENCH (wright_omega_dangelo1, (math_approx::wright_omega_dangelo<1>))
WO_BENCH (wright_omega_dangelo0, (math_approx::wright_omega_dangelo<0>))

#define WO_SIMD_BENCH(name, func) \
void name (benchmark::State& state) \
{ \
for (auto _ : state) \
{ \
for (auto& x : data) \
{ \
auto y = func (xsimd::broadcast (x)); \
static_assert (std::is_same_v<xsimd::batch<float>, decltype(y)>); \
benchmark::DoNotOptimize (y); \
} \
} \
} \
BENCHMARK (name);
WO_SIMD_BENCH (wright_omega_simd_iter3_poly3_logexp5, (math_approx::wright_omega<3, 3, 5>))
WO_SIMD_BENCH (wright_omega_simd_iter3_poly3, (math_approx::wright_omega<3, 3>))
WO_SIMD_BENCH (wright_omega_simd_iter2_poly5, (math_approx::wright_omega<2, 5>))
WO_SIMD_BENCH (wright_omega_simd_iter2_poly3, (math_approx::wright_omega<2, 3>))
WO_SIMD_BENCH (wright_omega_simd_iter2_poly3_logexp3, (math_approx::wright_omega<2, 3, 3>))
WO_SIMD_BENCH (wright_omega_simd_iter1_poly5, (math_approx::wright_omega<1, 5>))
WO_SIMD_BENCH (wright_omega_simd_iter1_poly3, (math_approx::wright_omega<1, 3>))
WO_SIMD_BENCH (wright_omega_simd_iter0_poly5, (math_approx::wright_omega<0, 5>))
WO_SIMD_BENCH (wright_omega_simd_iter0_poly3, (math_approx::wright_omega<0, 3>))

BENCHMARK_MAIN();
9 changes: 7 additions & 2 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace plt = matplotlibcpp;

#include "../../test/src/test_helpers.hpp"
#include "../../test/src/reference/toms917.hpp"
#include "../../test/src/reference/dangelo_omega.hpp"
#include <math_approx/math_approx.hpp>

template <typename F_Approx>
Expand Down Expand Up @@ -66,8 +67,12 @@ int main()
const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC(toms917::wrightomega));

plot_ulp_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<3, 3, 4>)), "W-O 3-3-4");
plot_ulp_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<3, 3, 5>)), "W-O 3-3-5");
// plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<0>)), "W-O 0-3");
// plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<0, 5>)), "W-O 0-5");
plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega<1>)), "W-O 1-3");
// plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega_dangelo<0>)), "W-O D'Angelo 0");
plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega_dangelo<1>)), "W-O D'Angelo 1");
plot_error (all_floats, y_exact, FLOAT_FUNC((math_approx::wright_omega_dangelo<2>)), "W-O D'Angelo 2");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down

0 comments on commit 83935b1

Please sign in to comment.