Skip to content

Commit e2e1b10

Browse files
committed
refactor: replace boolean template parameters with enum-based operation dispatch
- Replace IS_COS boolean template parameters with Operation enum class (kSin/kCos) - Introduce NPSR_INTRIN macro for consistent intrinsic function attributes - Update trigonometric function implementations to use new enum-based dispatch - Improve code readability and maintainability by eliminating boolean parameter confusion - Update function signatures and documentation to reflect new API design - Replace HWY_ATTR with HWY_INLINE for better consistency in LUT implementation
1 parent aa6dd90 commit e2e1b10

File tree

6 files changed

+67
-52
lines changed

6 files changed

+67
-52
lines changed

npsr/hwy.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
#define NPSR_HWY_H_
33

44
#include <hwy/highway.h>
5-
5+
// This macro is used to define intrinsics that are:
6+
// NOTE: equals to HWY_API.
7+
// - always inlined
8+
// - flattened (no separate stack frame)
9+
// - marked maybe unused to suppress warnings when they are not used
10+
// NOTE: we do not need to use HWY_ATTR because we wrap Highway intrinsics in
11+
// HWY_BEFORE_NAMESPACE()/HWY_AFTER_NAMESPACE()
12+
// which implies the nessessary target attributes via #pargma.
13+
#define NPSR_INTRIN static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED
614
#endif // NPSR_HWY_H_
715

816
#if defined(NPSR_HWY_FOREACH_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT

npsr/lut-inl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Lut {
7676
* @note Index values must be in range [0, kCols)
7777
*/
7878
template <typename VU, typename... OutV>
79-
HWY_ATTR void Load(VU idx, OutV &...out) const {
79+
HWY_INLINE void Load(VU idx, OutV &...out) const {
8080
static_assert(sizeof...(OutV) == kRows,
8181
"Number of output vectors must match number of rows in LUT");
8282
using namespace hn;
@@ -102,7 +102,7 @@ class Lut {
102102

103103
/// Dispatch to optimal row-load implementation based on vector/LUT size
104104
template <size_t Off = 0, typename VU, typename... OutV>
105-
HWY_ATTR void LoadRow_(VU idx, OutV &...out) const {
105+
HWY_INLINE void LoadRow_(VU idx, OutV &...out) const {
106106
using namespace hn;
107107
using DU = DFromV<VU>;
108108
const DU du;
@@ -126,7 +126,7 @@ class Lut {
126126

127127
// Load using single table lookup (vector size == table width)
128128
template <size_t Off = 0, typename VInd, typename OutV0, typename... OutV>
129-
HWY_ATTR void LoadX1_(const VInd &ind, OutV0 &out0, OutV &...out) const {
129+
HWY_INLINE void LoadX1_(const VInd &ind, OutV0 &out0, OutV &...out) const {
130130
using namespace hn;
131131
using D = DFromV<OutV0>;
132132
const D d;
@@ -141,7 +141,7 @@ class Lut {
141141

142142
// Load using two table lookups (vector size == table width / 2)
143143
template <size_t Off = 0, typename VInd, typename OutV0, typename... OutV>
144-
HWY_ATTR void LoadX2_(const VInd &ind, OutV0 &out0, OutV &...out) const {
144+
HWY_INLINE void LoadX2_(const VInd &ind, OutV0 &out0, OutV &...out) const {
145145
using namespace hn;
146146
using D = DFromV<OutV0>;
147147
const D d;
@@ -158,7 +158,7 @@ class Lut {
158158

159159
// General fallback using gather instructions
160160
template <size_t Off = 0, typename VU, typename OutV0, typename... OutV>
161-
HWY_ATTR void LoadGather_(const VU &idx, OutV0 &out0, OutV &...out) const {
161+
HWY_INLINE void LoadGather_(const VU &idx, OutV0 &out0, OutV &...out) const {
162162
using namespace hn;
163163
using D = DFromV<OutV0>;
164164
const D d;

npsr/trig/extended-inl.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
#include "npsr/hwy.h"
99
#include "npsr/trig/data/data.h"
10+
#include "npsr/trig/low-inl.h" // Operation
1011

1112
HWY_BEFORE_NAMESPACE();
1213
namespace npsr::HWY_NAMESPACE::trig {
1314

14-
template <bool IS_COS, class V>
15-
HWY_API V Extended(V x) {
15+
template <Operation OP, class V>
16+
NPSR_INTRIN V Extended(V x) {
1617
using namespace hn;
1718
namespace data = ::npsr::trig::data;
1819
using hwy::ExponentBits;
@@ -143,9 +144,8 @@ HWY_API V Extended(V x) {
143144
// PHASE 5: Extract Quotient and Fractional Parts
144145
// =============================================================================
145146

146-
// Extract integer quotient
147-
constexpr int kQuotientShift =
148-
ExponentBits<T>() + 1; // 9 for F32, 12 for F64
147+
// Extract integer quotient. 9 for F32, 12 for F64
148+
constexpr int kQuotientShift = ExponentBits<T>() + 1;
149149
VU u_shifted_n = ShiftRight<kQuotientShift>(u_n_hi);
150150

151151
// fractional shifts derived from magic constants
@@ -255,8 +255,8 @@ HWY_API V Extended(V x) {
255255
// =============================================================================
256256

257257
// Generated by npsr/trig/data/approx.h.sol
258-
const T *table_base =
259-
IS_COS ? data::kCosApproxTable<T> : data::kSinApproxTable<T>;
258+
const T *table_base = OP == Operation::kCos ? data::kCosApproxTable<T>
259+
: data::kSinApproxTable<T>;
260260

261261
// Calculate table index
262262
VU u_n_mask = Set(du, kIsSingle ? 0xFF : 0x1FF);

npsr/trig/high-inl.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
#include "npsr/hwy.h"
99
#include "npsr/lut-inl.h"
1010
#include "npsr/trig/data/data.h"
11+
#include "npsr/trig/low-inl.h" // Operation
1112

1213
HWY_BEFORE_NAMESPACE();
1314

1415
namespace npsr::HWY_NAMESPACE::trig {
1516

16-
template <bool IS_COS, typename V, HWY_IF_F32(TFromV<V>)>
17-
HWY_INLINE V High(V x) {
17+
template <Operation OP, typename V, HWY_IF_F32(TFromV<V>)>
18+
NPSR_INTRIN V High(V x) {
1819
using namespace hn;
1920
namespace data = ::npsr::trig::data;
2021

@@ -37,7 +38,7 @@ HWY_INLINE V High(V x) {
3738
// Transform cosine to sine using identity: cos(x) = sin(x + π/2)
3839
const V half_pi = Set(d, data::kHalfPi<T>);
3940
V x_trans = x_abs;
40-
if constexpr (IS_COS) {
41+
if constexpr (OP == Operation::kCos) {
4142
x_trans = Add(x_abs, half_pi);
4243
}
4344
// check zero input/subnormal for cosine (cos(~0) = 1)
@@ -49,7 +50,7 @@ HWY_INLINE V High(V x) {
4950
V n = Sub(n_biased, magic_round);
5051

5152
// Adjust quotient for cosine (accounts for π/2 phase shift)
52-
if constexpr (IS_COS) {
53+
if constexpr (OP == Operation::kCos) {
5354
// For cosine, we computed N = round((x + π/2)/π) but need N' for x:
5455
// N = round((x + π/2)/π) = round(x/π + 0.5)
5556
// This is often 1 more than round(x/π), so we subtract 0.5:
@@ -83,7 +84,7 @@ HWY_INLINE V High(V x) {
8384
// Extract octant sign information from quotient and flip the sign bit
8485
poly = Xor(poly,
8586
BitCast(d, ShiftLeft<sizeof(T) * 8 - 1>(BitCast(du, n_biased))));
86-
if constexpr (IS_COS) {
87+
if constexpr (OP == Operation::kCos) {
8788
poly = IfThenElse(is_cos_near_zero, Set(d, 1.0f), poly);
8889
} else {
8990
// Restore original sign for sine (odd function)
@@ -113,8 +114,8 @@ HWY_INLINE V High(V x) {
113114
* - cos(x) = cos(n*π/16 + r) = cos(n*π/16)*cos(r) - sin(n*π/16)*sin(r)
114115
*
115116
*/
116-
template <bool IS_COS, typename V, HWY_IF_F64(TFromV<V>)>
117-
HWY_INLINE V High(V x) {
117+
template <Operation OP, typename V, HWY_IF_F64(TFromV<V>)>
118+
NPSR_INTRIN V High(V x) {
118119
using namespace hn;
119120
namespace data = ::npsr::trig::data;
120121
using T = TFromV<V>;
@@ -206,7 +207,7 @@ HWY_INLINE V High(V x) {
206207
// sin(n*π/16 + r) = sin_table + cos_table*remainder (+ corrections)
207208
// cos(n*π/16 + r) = cos_table - sin_table*remainder (+ corrections)
208209
V result;
209-
if constexpr (IS_COS) {
210+
if constexpr (OP == Operation::kCos) {
210211
// Cosine reconstruction: cos_table - sin_table*remainder
211212
// Equivalent to: cos(a)*cos(r) - sin(a)*sin(r) but more efficient
212213
V res_hi = NegMulAdd(r, sin_hi, cos_hi); // cos_hi - r*sin_hi

npsr/trig/inl.h

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
HWY_BEFORE_NAMESPACE();
2424

2525
namespace npsr::HWY_NAMESPACE::trig {
26-
2726
/**
2827
* @brief Unified sine/cosine implementation with configurable precision
2928
*
@@ -32,13 +31,13 @@ namespace npsr::HWY_NAMESPACE::trig {
3231
* - Input magnitude (standard vs extended precision for large arguments)
3332
* - Special case handling (NaN, Inf)
3433
*
35-
* @tparam IS_COS true for cosine, false for sine
34+
* @tparam OP Operation type: kSin or kCos
3635
* @tparam Prec Precise configuration class with accuracy/feature flags
3736
* @tparam V Highway vector type
3837
*
3938
* @param prec Precise object controlling FP environment and exceptions
4039
* @param x Input vector
41-
* @return sin(x) or cos(x) depending on IS_COS
40+
* @return sin(x) or cos(x) depending on OP
4241
*
4342
* Algorithm selection:
4443
* 1. If kLowAccuracy: Use Low<> (Cody-Waite with minimal polynomial)
@@ -49,8 +48,8 @@ namespace npsr::HWY_NAMESPACE::trig {
4948
* - Float: |x| > 10,000 (empirically chosen for accuracy)
5049
* - Double: |x| > 2^24 (16,777,216 - where 53-bit mantissa loses precision)
5150
*/
52-
template <bool IS_COS, typename Prec, typename V>
53-
HWY_API V SinCos(Prec &prec, V x) {
51+
template <Operation OP, typename Prec, typename V>
52+
NPSR_INTRIN V Trig(Prec &prec, V x) {
5453
using namespace hwy::HWY_NAMESPACE;
5554
constexpr bool kIsSingle = std::is_same_v<TFromV<V>, float>;
5655
const DFromV<V> d;
@@ -59,19 +58,19 @@ HWY_API V SinCos(Prec &prec, V x) {
5958
if constexpr (Prec::kLowAccuracy) {
6059
// Low precision: Cody-Waite reduction with degree-9 polynomial
6160
// Error: ~2 ULP and 3~ for non-fma
62-
ret = Low<IS_COS>(x);
61+
ret = Low<OP>(x);
6362
} else {
6463
// High precision: π/16 reduction with table lookup + polynomial
6564
// Error: ~1 ULP
66-
ret = High<IS_COS>(x);
65+
ret = High<OP>(x);
6766
}
6867
// Step 2: Handle special cases (NaN, Inf) if enabled
6968
auto is_finite = IsFinite(x);
7069
if constexpr (Prec::kSpecialCases) {
7170
// IEEE 754 requires: sin(±∞) = NaN, cos(±∞) = NaN
7271
ret = IfThenElse(is_finite, ret, NaN(d));
7372
// -0.0 should return -0.0 for sine
74-
if constexpr (!IS_COS) {
73+
if constexpr (OP == Operation::kSin) {
7574
ret = IfThenElse(Eq(x, Set(d, 0.0)), x, ret);
7675
}
7776
}
@@ -89,7 +88,7 @@ HWY_API V SinCos(Prec &prec, V x) {
8988
if (HWY_UNLIKELY(!AllFalse(d, has_large_arg))) {
9089
// Payne-Hanek reduction: Uses ~96-bit (float) or ~192-bit (double)
9190
// precision for 4/π to maintain accuracy for huge arguments
92-
ret = IfThenElse(has_large_arg, Extended<IS_COS>(x), ret);
91+
ret = IfThenElse(has_large_arg, Extended<OP>(x), ret);
9392
}
9493
}
9594
// Step 4: Raise invalid operation exception for infinity inputs
@@ -107,41 +106,45 @@ namespace npsr::HWY_NAMESPACE {
107106
/**
108107
* @brief Compute sine of vector elements with configurable precision
109108
*
110-
* @tparam Prec Precise configuration (e.g., Precise<kLowAccuracy>)
109+
* @tparam Prec Precise configuration (e.g., Precise{kLowAccuracy})
111110
* @tparam V Highway vector type
112111
* @param prec Precise object managing FP environment
113112
* @param x Input vector
114113
* @return sin(x) for each element
115114
*
116115
* @example
117116
* ```cpp
118-
* Precise<kHighAccuracy> prec;
117+
* Precise prec{
118+
* kLowAccuracy, kNoLargeArgument, kNoExceptions, kNoSpecialCases
119+
* };
119120
* auto result = Sin(prec, input_vector);
120121
* ```
121122
*/
122123
template <typename Prec, typename V>
123-
HWY_API V Sin(Prec &prec, V x) {
124-
return trig::SinCos<false>(prec, x);
124+
NPSR_INTRIN V Sin(Prec &prec, V x) {
125+
return trig::Trig<trig::Operation::kSin>(prec, x);
125126
}
126127

127128
/**
128129
* @brief Compute cosine of vector elements with configurable precision
129130
*
130-
* @tparam Prec Precise configuration (e.g., Precise<kLowAccuracy>)
131+
* @tparam Prec Precise configuration (e.g., Precise{kLowAccuracy})
131132
* @tparam V Highway vector type
132133
* @param prec Precise object managing FP environment
133134
* @param x Input vector
134135
* @return cos(x) for each element
135136
*
136137
* @example
137138
* ```cpp
138-
* Precise<kNoLargeArgument, kNoSpecialCases> prec;
139+
* Precise prec{
140+
* kLowAccuracy, kNoLargeArgument, kNoExceptions, kNoSpecialCases
141+
* };
139142
* auto result = Cos(prec, input_vector);
140143
* ```
141144
*/
142145
template <typename Prec, typename V>
143-
HWY_API V Cos(Prec &prec, V x) {
144-
return trig::SinCos<true>(prec, x);
146+
NPSR_INTRIN V Cos(Prec &prec, V x) {
147+
return trig::Trig<trig::Operation::kCos>(prec, x);
145148
}
146149

147150
} // namespace npsr::HWY_NAMESPACE

npsr/trig/low-inl.h

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,22 @@ HWY_BEFORE_NAMESPACE();
1212

1313
namespace npsr::HWY_NAMESPACE::trig {
1414

15-
template <bool IS_COS, typename V, HWY_IF_F32(TFromV<V>)>
16-
HWY_API V PolyLow(V r, V r2) {
15+
enum class Operation { kSin = 0, kCos = 1 };
16+
17+
template <Operation OP, typename V, HWY_IF_F32(TFromV<V>)>
18+
NPSR_INTRIN V PolyLow(V r, V r2) {
1719
using namespace hn;
1820

1921
const DFromV<V> d;
20-
const V c9 = Set(d, IS_COS ? 0x1.5d866ap-19f : 0x1.5dbdfp-19f);
21-
const V c7 = Set(d, IS_COS ? -0x1.9f6d9ep-13 : -0x1.9f6ffep-13f);
22-
const V c5 = Set(d, IS_COS ? 0x1.110ec8p-7 : 0x1.110eccp-7f);
22+
constexpr bool kCos = OP == Operation::kCos;
23+
const V c9 = Set(d, kCos ? 0x1.5d866ap-19f : 0x1.5dbdfp-19f);
24+
const V c7 = Set(d, kCos ? -0x1.9f6d9ep-13 : -0x1.9f6ffep-13f);
25+
const V c5 = Set(d, kCos ? 0x1.110ec8p-7 : 0x1.110eccp-7f);
2326
const V c3 = Set(d, -0x1.55554cp-3f);
2427
V poly = MulAdd(c9, r2, c7);
2528
poly = MulAdd(r2, poly, c5);
2629
poly = MulAdd(r2, poly, c3);
27-
if constexpr (IS_COS) {
30+
if constexpr (OP == Operation::kCos) {
2831
// Although this path handles cosine, we have already transformed the
2932
// input using the identity: cos(x) = sin(x + π/2) This means we're no
3033
// longer directly evaluating a cosine Taylor series; instead, we evaluate
@@ -48,8 +51,8 @@ HWY_API V PolyLow(V r, V r2) {
4851
return poly;
4952
}
5053

51-
template <bool IS_COS, typename V, HWY_IF_F64(TFromV<V>)>
52-
HWY_API V PolyLow(V r, V r2) {
54+
template <Operation OP, typename V, HWY_IF_F64(TFromV<V>)>
55+
NPSR_INTRIN V PolyLow(V r, V r2) {
5356
using namespace hn;
5457

5558
const DFromV<V> d;
@@ -69,8 +72,8 @@ HWY_API V PolyLow(V r, V r2) {
6972
return poly;
7073
}
7174

72-
template <bool IS_COS, typename V>
73-
HWY_API V Low(V x) {
75+
template <Operation OP, typename V>
76+
NPSR_INTRIN V Low(V x) {
7477
using namespace hn;
7578
using hwy::SignMask;
7679
namespace data = ::npsr::trig::data;
@@ -87,7 +90,7 @@ HWY_API V Low(V x) {
8790
// Transform cosine to sine using identity: cos(x) = sin(x + π/2)
8891
const V half_pi = Set(d, data::kHalfPi<T>);
8992
V x_trans = x_abs;
90-
if constexpr (IS_COS) {
93+
if constexpr (OP == Operation::kCos) {
9194
x_trans = Add(x_abs, half_pi);
9295
}
9396
// check zero input/subnormal for cosine (cos(~0) = 1)
@@ -100,7 +103,7 @@ HWY_API V Low(V x) {
100103
V n = Sub(n_biased, magic_round);
101104

102105
// Adjust quotient for cosine (accounts for π/2 phase shift)
103-
if constexpr (IS_COS) {
106+
if constexpr (OP == Operation::kCos) {
104107
// For cosine, we computed N = round((x + π/2)/π) but need N' for x:
105108
// N = round((x + π/2)/π) = round(x/π + 0.5)
106109
// This is often 1 more than round(x/π), so we subtract 0.5:
@@ -124,7 +127,7 @@ HWY_API V Low(V x) {
124127
r = r_lo;
125128
}
126129
V r2 = Mul(r, r);
127-
V poly = PolyLow<IS_COS>(r, r2);
130+
V poly = PolyLow<OP>(r, r2);
128131

129132
if constexpr (!kIsSingle) {
130133
V r2_corr = Mul(r2, r_lo);
@@ -134,7 +137,7 @@ HWY_API V Low(V x) {
134137
// Extract octant sign information from quotient and flip the sign bit
135138
poly = Xor(poly,
136139
BitCast(d, ShiftLeft<sizeof(T) * 8 - 1>(BitCast(du, n_biased))));
137-
if constexpr (IS_COS) {
140+
if constexpr (OP == Operation::kCos) {
138141
poly = IfThenElse(is_cos_near_zero, Set(d, static_cast<T>(1.0)), poly);
139142
} else {
140143
// Restore original sign for sine (odd function)

0 commit comments

Comments
 (0)