diff --git a/cp-algo/math/cvector.hpp b/cp-algo/math/cvector.hpp new file mode 100644 index 0000000..203b1d1 --- /dev/null +++ b/cp-algo/math/cvector.hpp @@ -0,0 +1,133 @@ +#ifndef CP_ALGO_MATH_CVECTOR_HPP +#define CP_ALGO_MATH_CVECTOR_HPP +#include "../util/complex.hpp" +#include +namespace cp_algo::math::fft { + using ftype = double; + using point = complex; + using vftype = std::experimental::native_simd; + using vpoint = complex; + static constexpr size_t flen = vftype::size(); + + struct cvector { + static constexpr size_t pre_roots = 1 << 18; + std::vector x, y; + cvector(size_t n) { + n = std::max(flen, std::bit_ceil(n)); + x.resize(n / flen); + y.resize(n / flen); + } + template + void set(size_t k, pt t) { + if constexpr(std::is_same_v) { + x[k / flen][k % flen] = real(t); + y[k / flen][k % flen] = imag(t); + } else { + x[k / flen] = real(t); + y[k / flen] = imag(t); + } + } + template + pt get(size_t k) const { + if constexpr(std::is_same_v) { + return {x[k / flen][k % flen], y[k / flen][k % flen]}; + } else { + return {x[k / flen], y[k / flen]}; + } + } + vpoint vget(size_t k) const { + return get(k); + } + + size_t size() const { + return flen * std::size(x); + } + void dot(cvector const& t) { + size_t n = size(); + for(size_t k = 0; k < n; k += flen) { + set(k, get(k) * t.get(k)); + } + } + static const cvector roots; + template< bool precalc = false, class ft = point> + static auto root(size_t n, size_t k, ft &&arg) { + if(n < pre_roots && !precalc) { + return roots.get>(n + k); + } else { + return complex::polar(1., arg); + } + } + template + static void exec_on_roots(size_t n, size_t m, auto &&callback) { + ftype arg = std::numbers::pi / (ftype)n; + size_t step = sizeof(pt) / sizeof(point); + using ft = pt::value_type; + auto k = [&]() { + if constexpr(std::is_same_v) { + return ft{}; + } else { + return ft{[](auto i) {return i;}}; + } + }(); + for(size_t i = 0; i < m; i += step, k += (ftype)step) { + callback(i, root(n, i, arg * k)); + } + } + + void ifft() { + size_t n = size(); + for(size_t i = 1; i < n; i *= 2) { + for(size_t j = 0; j < n; j += 2 * i) { + auto butterfly = [&](size_t k, pt rt) { + k += j; + auto t = get(k + i) * conj(rt); + set(k + i, get(k) - t); + set(k, get(k) + t); + }; + if(i < flen) { + exec_on_roots(i, i, butterfly); + } else { + exec_on_roots(i, i, butterfly); + } + } + } + for(size_t k = 0; k < n; k += flen) { + set(k, get(k) /= (ftype)n); + } + } + void fft() { + size_t n = size(); + for(size_t i = n / 2; i >= 1; i /= 2) { + for(size_t j = 0; j < n; j += 2 * i) { + auto butterfly = [&](size_t k, pt rt) { + k += j; + auto A = get(k) + get(k + i); + auto B = get(k) - get(k + i); + set(k, A); + set(k + i, B * rt); + }; + if(i < flen) { + exec_on_roots(i, i, butterfly); + } else { + exec_on_roots(i, i, butterfly); + } + } + } + } + }; + const cvector cvector::roots = []() { + cvector res(pre_roots); + for(size_t n = 1; n < res.size(); n *= 2) { + auto propagate = [&](size_t k, auto rt) { + res.set(n + k, rt); + }; + if(n < flen) { + res.exec_on_roots(n, n, propagate); + } else { + res.exec_on_roots(n, n, propagate); + } + } + return res; + }(); +} +#endif // CP_ALGO_MATH_CVECTOR_HPP diff --git a/cp-algo/math/fft.hpp b/cp-algo/math/fft.hpp index c574cf2..2416327 100644 --- a/cp-algo/math/fft.hpp +++ b/cp-algo/math/fft.hpp @@ -1,142 +1,9 @@ #ifndef CP_ALGO_MATH_FFT_HPP #define CP_ALGO_MATH_FFT_HPP -#include "common.hpp" #include "../number_theory/modint.hpp" -#include "../util/complex.hpp" -#include -#include +#include "cvector.hpp" #include -#include -#include -#include namespace cp_algo::math::fft { - using ftype = double; - using point = complex; - using vftype = std::experimental::native_simd; - using vpoint = complex; - static constexpr size_t flen = vftype::size(); - - struct cvector { - static constexpr size_t pre_roots = 1 << 18; - std::vector x, y; - cvector(size_t n) { - n = std::max(flen, std::bit_ceil(n)); - x.resize(n / flen); - y.resize(n / flen); - } - template - void set(size_t k, pt t) { - if constexpr(std::is_same_v) { - x[k / flen][k % flen] = real(t); - y[k / flen][k % flen] = imag(t); - } else { - x[k / flen] = real(t); - y[k / flen] = imag(t); - } - } - template - pt get(size_t k) const { - if constexpr(std::is_same_v) { - return {x[k / flen][k % flen], y[k / flen][k % flen]}; - } else { - return {x[k / flen], y[k / flen]}; - } - } - vpoint vget(size_t k) const { - return get(k); - } - - size_t size() const { - return flen * std::size(x); - } - void dot(cvector const& t) { - size_t n = size(); - for(size_t k = 0; k < n; k += flen) { - set(k, get(k) * t.get(k)); - } - } - static const cvector roots; - template< bool precalc = false, class ft = point> - static auto root(size_t n, size_t k, ft &&arg) { - if(n < pre_roots && !precalc) { - return roots.get>(n + k); - } else { - return complex::polar(1., arg); - } - } - template - static void exec_on_roots(size_t n, size_t m, auto &&callback) { - ftype arg = std::numbers::pi / (ftype)n; - size_t step = sizeof(pt) / sizeof(point); - using ft = pt::value_type; - auto k = [&]() { - if constexpr(std::is_same_v) { - return ft{}; - } else { - return ft{[](auto i) {return i;}}; - } - }(); - for(size_t i = 0; i < m; i += step, k += (ftype)step) { - callback(i, root(n, i, arg * k)); - } - } - - void ifft() { - size_t n = size(); - for(size_t i = 1; i < n; i *= 2) { - for(size_t j = 0; j < n; j += 2 * i) { - auto butterfly = [&](size_t k, pt rt) { - k += j; - auto t = get(k + i) * conj(rt); - set(k + i, get(k) - t); - set(k, get(k) + t); - }; - if(i < flen) { - exec_on_roots(i, i, butterfly); - } else { - exec_on_roots(i, i, butterfly); - } - } - } - for(size_t k = 0; k < n; k += flen) { - set(k, get(k) /= (ftype)n); - } - } - void fft() { - size_t n = size(); - for(size_t i = n / 2; i >= 1; i /= 2) { - for(size_t j = 0; j < n; j += 2 * i) { - auto butterfly = [&](size_t k, pt rt) { - k += j; - auto A = get(k) + get(k + i); - auto B = get(k) - get(k + i); - set(k, A); - set(k + i, B * rt); - }; - if(i < flen) { - exec_on_roots(i, i, butterfly); - } else { - exec_on_roots(i, i, butterfly); - } - } - } - } - }; - const cvector cvector::roots = []() { - cvector res(pre_roots); - for(size_t n = 1; n < res.size(); n *= 2) { - auto propagate = [&](size_t k, auto rt) { - res.set(n + k, rt); - }; - if(n < flen) { - res.exec_on_roots(n, n, propagate); - } else { - res.exec_on_roots(n, n, propagate); - } - } - return res; - }(); - template struct dft { cvector A; diff --git a/verify/poly/wildcard.test.cpp b/verify/poly/wildcard.test.cpp index 040ee4f..996c783 100644 --- a/verify/poly/wildcard.test.cpp +++ b/verify/poly/wildcard.test.cpp @@ -1,7 +1,7 @@ // @brief Wildcard Pattern Matching #define PROBLEM "https://judge.yosupo.jp/problem/wildcard_pattern_matching" #pragma GCC optimize("Ofast,unroll-loops") -#include "cp-algo/math/fft.hpp" +#include "cp-algo/math/cvector.hpp" #include "cp-algo/random/rng.hpp" #include