|  | 
|  | 1 | +#if defined(NPSR_LUT_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT | 
|  | 2 | +#ifdef NPSR_LUT_INL_H_ | 
|  | 3 | +#undef NPSR_LUT_INL_H_ | 
|  | 4 | +#else | 
|  | 5 | +#define NPSR_LUT_INL_H_ | 
|  | 6 | +#endif | 
|  | 7 | + | 
|  | 8 | +#include "npsr/hwy.h" | 
|  | 9 | + | 
|  | 10 | +HWY_BEFORE_NAMESPACE(); | 
|  | 11 | + | 
|  | 12 | +namespace npsr::HWY_NAMESPACE { | 
|  | 13 | + | 
|  | 14 | +/** | 
|  | 15 | + * @brief SIMD-optimized lookup table implementation | 
|  | 16 | + * | 
|  | 17 | + * This class provides an efficient lookup table. | 
|  | 18 | + * It stores data in both row-major and column-major | 
|  | 19 | + * formats to optimize different access patterns. | 
|  | 20 | + * | 
|  | 21 | + * @tparam T Element type (must match SIMD vector element type) | 
|  | 22 | + * @tparam kRows Number of rows in the lookup table | 
|  | 23 | + * @tparam kCols Number of columns in the lookup table | 
|  | 24 | + * | 
|  | 25 | + * Example usage: | 
|  | 26 | + * @code | 
|  | 27 | + *   // Create a 2x4 lookup table | 
|  | 28 | + *   constexpr Lut lut{{1.0f, 2.0f, 3.0f, 4.0f}, {5.0f, 6.0f, 7.0f, 8.0f}}; | 
|  | 29 | + *   // Load values using SIMD indices | 
|  | 30 | + *   auto indices = Set(d, 2);  // SIMD vector of indices | 
|  | 31 | + *   Vec<D> out0, out1; | 
|  | 32 | + *   lut.Load(indices, out0, out1); | 
|  | 33 | + * @endcode | 
|  | 34 | + */ | 
|  | 35 | +template <typename T, size_t kRows, size_t kCols> | 
|  | 36 | +class Lut { | 
|  | 37 | + public: | 
|  | 38 | +  static constexpr size_t kLength = kRows * kCols; | 
|  | 39 | + | 
|  | 40 | +  /** | 
|  | 41 | +   * @brief Construct a lookup table from row arrays | 
|  | 42 | +   * | 
|  | 43 | +   * @tparam ColSizes Size of each row array (deduced) | 
|  | 44 | +   * @param rows Variable number of arrays, each representing a row | 
|  | 45 | +   * | 
|  | 46 | +   * @note All rows must have exactly kCols elements | 
|  | 47 | +   * @note The constructor is constexpr for compile-time initialization | 
|  | 48 | +   */ | 
|  | 49 | +  template <size_t... ColSizes> | 
|  | 50 | +  constexpr Lut(const T (&...rows)[ColSizes]) : row_{} { | 
|  | 51 | +    // Check that we have the right number of rows | 
|  | 52 | +    static_assert(sizeof...(rows) == kRows, | 
|  | 53 | +                  "Number of rows doesn't match template parameter"); | 
|  | 54 | +    // Check that all rows have the same number of columns | 
|  | 55 | +    static_assert(((ColSizes == kCols) && ...), | 
|  | 56 | +                  "All rows must have the same number of columns"); | 
|  | 57 | + | 
|  | 58 | +    // Copy data using recursive template approach | 
|  | 59 | +    ToRowMajor_<0>(rows...); | 
|  | 60 | +  } | 
|  | 61 | + | 
|  | 62 | +  /** | 
|  | 63 | +   * @brief Load values from the LUT using SIMD indices | 
|  | 64 | +   * | 
|  | 65 | +   * This method performs efficient SIMD lookups by selecting the optimal | 
|  | 66 | +   * implementation based on the vector size and LUT dimensions. | 
|  | 67 | +   * | 
|  | 68 | +   * @tparam VU SIMD vector type for indices | 
|  | 69 | +   * @tparam OutV Output SIMD vector types (must match number of rows) | 
|  | 70 | +   * @param idx SIMD vector of column indices | 
|  | 71 | +   * @param out Output vectors (one per row) | 
|  | 72 | +   * | 
|  | 73 | +   * @note The number of output vectors must exactly match kRows | 
|  | 74 | +   * @note Index values must be in range [0, kCols) | 
|  | 75 | +   */ | 
|  | 76 | +  template <typename VU, typename... OutV> | 
|  | 77 | +  HWY_ATTR void Load(VU idx, OutV &...out) const { | 
|  | 78 | +    static_assert(sizeof...(OutV) == kRows, | 
|  | 79 | +                  "Number of output vectors must match number of rows in LUT"); | 
|  | 80 | +    using namespace hn; | 
|  | 81 | +    using TU = TFromV<VU>; | 
|  | 82 | +    static_assert(sizeof(TU) == sizeof(T), | 
|  | 83 | +                  "Index type must match LUT element type"); | 
|  | 84 | +    // Row-major based optimization | 
|  | 85 | +    LoadRow_(idx, out...); | 
|  | 86 | +  } | 
|  | 87 | + | 
|  | 88 | + private: | 
|  | 89 | +  /// Convert input rows to row-major storage format | 
|  | 90 | +  template <size_t RowIDX, size_t... ColSizes> | 
|  | 91 | +  constexpr void ToRowMajor_(const T (&...rows)[ColSizes]) { | 
|  | 92 | +    if constexpr (RowIDX < kRows) { | 
|  | 93 | +      auto row_array = std::get<RowIDX>(std::make_tuple(rows...)); | 
|  | 94 | +      for (size_t col = 0; col < kCols; ++col) { | 
|  | 95 | +        row_[RowIDX * kCols + col] = row_array[col]; | 
|  | 96 | +      } | 
|  | 97 | +      ToRowMajor_<RowIDX + 1>(rows...); | 
|  | 98 | +    } | 
|  | 99 | +  } | 
|  | 100 | + | 
|  | 101 | +  /// Dispatch to optimal row-load implementation based on vector/LUT size | 
|  | 102 | +  template <size_t Off = 0, typename VU, typename... OutV> | 
|  | 103 | +  HWY_ATTR void LoadRow_(VU idx, OutV &...out) const { | 
|  | 104 | +    using namespace hn; | 
|  | 105 | +    using DU = DFromV<VU>; | 
|  | 106 | +    const DU du; | 
|  | 107 | +    using D = Rebind<T, DU>; | 
|  | 108 | +    const D d; | 
|  | 109 | + | 
|  | 110 | +#if !HWY_HAVE_SCALABLE | 
|  | 111 | +    constexpr size_t kLanes = Lanes(du); | 
|  | 112 | +    if constexpr (kLanes == kCols) { | 
|  | 113 | +      // Vector size matches table width - use single table lookup | 
|  | 114 | +      const auto ind = IndicesFromVec(d, idx); | 
|  | 115 | +      LoadX1_<Off>(ind, out...); | 
|  | 116 | +    } else if constexpr (kLanes * 2 == kCols) { | 
|  | 117 | +      // Vector size is half table width - use two table lookup | 
|  | 118 | +      const auto ind = IndicesFromVec(d, idx); | 
|  | 119 | +      LoadX2_<Off>(ind, out...); | 
|  | 120 | +    } | 
|  | 121 | +#else | 
|  | 122 | +    if constexpr (0) { | 
|  | 123 | +    } | 
|  | 124 | +#endif | 
|  | 125 | +    else { | 
|  | 126 | +      // Fallback to gather for other configurations | 
|  | 127 | +      LoadGather_<Off>(idx, out...); | 
|  | 128 | +    } | 
|  | 129 | +  } | 
|  | 130 | + | 
|  | 131 | +  // Load using single table lookup (vector size == table width) | 
|  | 132 | +  template <size_t Off = 0, typename VInd, typename OutV0, typename... OutV> | 
|  | 133 | +  HWY_ATTR void LoadX1_(const VInd &ind, OutV0 &out0, OutV &...out) const { | 
|  | 134 | +    using namespace hn; | 
|  | 135 | +    using D = DFromV<OutV0>; | 
|  | 136 | +    const D d; | 
|  | 137 | + | 
|  | 138 | +    const OutV0 lut0 = Load(d, row_ + Off); | 
|  | 139 | +    out0 = TableLookupLanes(d, lut0, ind); | 
|  | 140 | + | 
|  | 141 | +    if constexpr (sizeof...(OutV) > 0) { | 
|  | 142 | +      LoadX1_<Off + kCols>(ind, out...); | 
|  | 143 | +    } | 
|  | 144 | +  } | 
|  | 145 | + | 
|  | 146 | +  // Load using two table lookups (vector size == table width / 2) | 
|  | 147 | +  template <size_t Off = 0, typename VInd, typename OutV0, typename... OutV> | 
|  | 148 | +  HWY_ATTR void LoadX2_(const VInd &ind, OutV0 &out0, OutV &...out) const { | 
|  | 149 | +    using namespace hn; | 
|  | 150 | +    using D = DFromV<OutV0>; | 
|  | 151 | +    const D d; | 
|  | 152 | + | 
|  | 153 | +    constexpr size_t kLanes = kCols / 2; | 
|  | 154 | +    const OutV0 lut0 = LoadU(d, row_ + Off); | 
|  | 155 | +    const OutV0 lut1 = LoadU(d, row_ + Off + kLanes); | 
|  | 156 | +    out0 = TwoTablesLookupLanes(d, lut0, lut1, ind); | 
|  | 157 | + | 
|  | 158 | +    if constexpr (sizeof...(OutV) > 0) { | 
|  | 159 | +      LoadX2_<Off + kCols>(ind, out...); | 
|  | 160 | +    } | 
|  | 161 | +  } | 
|  | 162 | + | 
|  | 163 | +  //  General fallback using gather instructions | 
|  | 164 | +  template <size_t Off = 0, typename VU, typename OutV0, typename... OutV> | 
|  | 165 | +  HWY_ATTR void LoadGather_(const VU &idx, OutV0 &out0, OutV &...out) const { | 
|  | 166 | +    using namespace hn; | 
|  | 167 | +    using D = DFromV<OutV0>; | 
|  | 168 | +    const D d; | 
|  | 169 | +    out0 = GatherIndex(d, row_ + Off, BitCast(RebindToSigned<D>(), idx)); | 
|  | 170 | +    if constexpr (sizeof...(OutV) > 0) { | 
|  | 171 | +      LoadGather_<Off + kCols>(idx, out...); | 
|  | 172 | +    } | 
|  | 173 | +  } | 
|  | 174 | + | 
|  | 175 | +  // Row-major | 
|  | 176 | +  HWY_ALIGN T row_[kLength]; | 
|  | 177 | +}; | 
|  | 178 | + | 
|  | 179 | +/** | 
|  | 180 | + * @brief Deduction guide for automatic dimension detection | 
|  | 181 | + * | 
|  | 182 | + * Allows constructing a Lut without explicitly specifying dimensions: | 
|  | 183 | + * @code | 
|  | 184 | + *   Lut lut{row0, row1, row2};  // Dimensions deduced from arrays | 
|  | 185 | + * @endcode | 
|  | 186 | + */ | 
|  | 187 | +template <typename T, size_t First, size_t... Rest> | 
|  | 188 | +Lut(const T (&first)[First], const T (&...rest)[Rest]) | 
|  | 189 | +    -> Lut<T, 1 + sizeof...(Rest), First>; | 
|  | 190 | + | 
|  | 191 | +/** | 
|  | 192 | + * @brief Factory function that requires explicit type specification | 
|  | 193 | + * | 
|  | 194 | + * This approach forces users to specify the type T explicitly while | 
|  | 195 | + * automatically deducing the dimensions from the array arguments. | 
|  | 196 | + * | 
|  | 197 | + * Note: We use MakeLut since partial deduction guides (e.g., Lut<float>{...}) | 
|  | 198 | + * require C++20, but this codebase targets C++17. | 
|  | 199 | + * | 
|  | 200 | + * @tparam T Element type (must be explicitly specified) | 
|  | 201 | + * @param first First row array | 
|  | 202 | + * @param rest Additional row arrays | 
|  | 203 | + * @return Lut with deduced dimensions | 
|  | 204 | + * | 
|  | 205 | + * Usage: | 
|  | 206 | + * @code | 
|  | 207 | + *   auto lut = MakeLut<float>(row0, row1, row2);  // T explicit, dimensions | 
|  | 208 | + * deduced | 
|  | 209 | + * @endcode | 
|  | 210 | + */ | 
|  | 211 | +template <typename T, size_t First, size_t... Rest> | 
|  | 212 | +constexpr auto MakeLut(const T (&first)[First], const T (&...rest)[Rest]) { | 
|  | 213 | +  return Lut<T, 1 + sizeof...(Rest), First>{first, rest...}; | 
|  | 214 | +} | 
|  | 215 | + | 
|  | 216 | +}  // namespace npsr::HWY_NAMESPACE | 
|  | 217 | + | 
|  | 218 | +HWY_AFTER_NAMESPACE(); | 
|  | 219 | + | 
|  | 220 | +#endif  // NPSR_LUT_INL_H_ | 
0 commit comments