Skip to content

Commit

Permalink
Trig approx improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Nov 22, 2023
1 parent d06887a commit a1590f8
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ jobs:
- name: CMake Test
run: |
ctest --test-dir build -C RelWithDebInfo --show-only
ctest --test-dir build -C RelWithDebInfo -j4 --output-on-failure
ctest --test-dir build -C RelWithDebInfo -j2 --output-on-failure
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ endif()

add_library(math_approx INTERFACE)
target_include_directories(math_approx INTERFACE include)
if(MSVC)
target_compile_definitions(math_approx INTERFACE _USE_MATH_DEFINES=1)
endif()
if (TARGET xsimd)
message(STATUS "math_approx -- Linking with XSIMD...")
target_link_libraries(math_approx INTERFACE xsimd)
Expand Down
2 changes: 1 addition & 1 deletion include/math_approx/math_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ namespace math_approx

#include "src/tanh_approx.hpp"
#include "src/sigmoid_approx.hpp"
#include "src/sin_approx.hpp"
#include "src/trig_approx.hpp"
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,16 @@ T cos_mpi_pi (T x)

using S = scalar_of_t<T>;
static constexpr auto pi = static_cast<S> (M_PI);
static constexpr auto pi_o_2 = pi * (S) 0.5;;
static constexpr auto pi_sq = pi * pi;
static constexpr auto pi_o_2 = pi * (S) 0.5;

using std::abs;
#if defined(XSIMD_HPP)
using xsimd::abs;
#endif
x = abs (x);

const auto hpmx = (x > (S) 0 ? (S) 1 : (S) -1) * pi_o_2 - x;
const auto thpmx = (x > (S) 0 ? (S) 3 : (S) -3) * pi_o_2 - x;
const auto nhpmx = (x > (S) 0 ? (S) -1 : (S) 1) * pi_o_2 - x;
const auto hpmx = pi_o_2 - x;
const auto hpmx_sq = hpmx * hpmx;

T x_poly {};
Expand All @@ -114,7 +119,7 @@ T cos_mpi_pi (T x)
else if constexpr (order == 5)
x_poly = sin_detail::sin_poly_5 (hpmx, hpmx_sq);

return thpmx * nhpmx * (x > (S) 0 ? (S) -1 : (S) 1) * x_poly;
return (pi_sq - hpmx_sq) * x_poly;
}

template <int order, typename T>
Expand Down
3 changes: 1 addition & 2 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@ endfunction(setup_catch_test)

setup_catch_test(tanh_approx_test)
setup_catch_test(sigmoid_approx_test)
setup_catch_test(sin_approx_test)
setup_catch_test(cos_approx_test)
setup_catch_test(trig_approx_test)
46 changes: 0 additions & 46 deletions test/src/cos_approx_test.cpp

This file was deleted.

41 changes: 41 additions & 0 deletions test/src/sin_approx_test.cpp → test/src/trig_approx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,44 @@ TEST_CASE ("Sine Approx Test")
7.5e-4f);
}
}

TEST_CASE ("Cosine Approx Test")
{
#if ! defined(WIN32)
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f);
#else
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x)
{ return std::cos (x); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound)
{
const auto y_approx = test_helpers::compute_all (all_floats, f_approx);

const auto error = test_helpers::compute_error (y_exact, y_approx);
const auto max_error = test_helpers::abs_max (error);

// std::cout << max_error << std::endl;
REQUIRE (std::abs (max_error) < err_bound);
};

SECTION ("9th-Order")
{
test_approx ([] (auto x)
{ return math_approx::cos<9> (x); },
7.5e-7f);
}
SECTION ("7th-Order")
{
test_approx ([] (auto x)
{ return math_approx::cos<7> (x); },
1.8e-5f);
}
SECTION ("5th-Order")
{
test_approx ([] (auto x)
{ return math_approx::cos<5> (x); },
7.5e-4f);
}
}
3 changes: 3 additions & 0 deletions tools/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ target_link_libraries(tanh_approx_bench PRIVATE benchmark::benchmark math_approx

add_executable(sigmoid_approx_bench sigmoid_bench.cpp)
target_link_libraries(sigmoid_approx_bench PRIVATE benchmark::benchmark math_approx)

add_executable(trig_approx_bench trig_bench.cpp)
target_link_libraries(trig_approx_bench PRIVATE benchmark::benchmark math_approx)
63 changes: 63 additions & 0 deletions tools/bench/trig_bench.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <math_approx/math_approx.hpp>
#include <benchmark/benchmark.h>

static constexpr size_t N = 2000;
const auto data = []
{
std::vector<float> x;
x.resize (N, 0.0f);
for (size_t i = 0; i < N; ++i)
x[i] = -10.0f + 20.0f * (float) i / (float) N;
return x;
}();

#define TRIG_BENCH(name, func) \
void name (benchmark::State& state) \
{ \
for (auto _ : state) \
{ \
for (auto& x : data) \
{ \
auto y = func (x); \
benchmark::DoNotOptimize (y); \
} \
} \
} \
BENCHMARK (name);

TRIG_BENCH (cos_std, std::cos)
TRIG_BENCH (cos_approx9, math_approx::cos<9>)
TRIG_BENCH (cos_approx7, math_approx::cos<7>)
TRIG_BENCH (cos_approx5, math_approx::cos<5>)

TRIG_BENCH (sin_std, std::sin)
TRIG_BENCH (sin_approx9, math_approx::sin<9>)
TRIG_BENCH (sin_approx7, math_approx::sin<7>)
TRIG_BENCH (sin_approx5, math_approx::sin<5>)

#define TRIG_SIMD_BENCH(name, func) \
void name (benchmark::State& state) \
{ \
for (auto _ : state) \
{ \
for (auto& x : data) \
{ \
auto y = func (xsimd::broadcast (x)); \
static_assert (std::is_same_v<xsimd::batch<float>, decltype(y)>); \
benchmark::DoNotOptimize (y); \
} \
} \
} \
BENCHMARK (name);

TRIG_SIMD_BENCH (sin_xsimd, xsimd::sin)
TRIG_SIMD_BENCH (sin_simd_approx9, math_approx::sin<9>)
TRIG_SIMD_BENCH (sin_simd_approx7, math_approx::sin<7>)
TRIG_SIMD_BENCH (sin_simd_approx5, math_approx::sin<5>)

TRIG_SIMD_BENCH (cos_xsimd, xsimd::cos)
TRIG_SIMD_BENCH (cos_simd_approx9, math_approx::cos<9>)
TRIG_SIMD_BENCH (cos_simd_approx7, math_approx::cos<7>)
TRIG_SIMD_BENCH (cos_simd_approx5, math_approx::cos<5>)

BENCHMARK_MAIN();
4 changes: 2 additions & 2 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ int main()

// // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<5> (x); }, "Sin-5");
// // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<7> (x); }, "Sin-7");
plot_ulp_error (all_floats, y_exact, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9");
// plot_function (all_floats, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9");
// plot_ulp_error (all_floats, y_exact, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9");
plot_function (all_floats, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down

0 comments on commit a1590f8

Please sign in to comment.