From 2868c3ad85367df402a0f3d844afbacf51d667a3 Mon Sep 17 00:00:00 2001 From: Oleksandr Kulkov Date: Tue, 3 Dec 2024 02:25:17 +0100 Subject: [PATCH] Increase FFT precision, get rid of manual vectorization, use partial fft --- cp-algo/math/cvector.hpp | 197 +++++++++++++++++----------------- cp-algo/math/fft.hpp | 58 +++++----- verify/poly/wildcard.test.cpp | 12 +-- 3 files changed, 131 insertions(+), 136 deletions(-) diff --git a/cp-algo/math/cvector.hpp b/cp-algo/math/cvector.hpp index 203b1d1..f5933e3 100644 --- a/cp-algo/math/cvector.hpp +++ b/cp-algo/math/cvector.hpp @@ -1,133 +1,128 @@ #ifndef CP_ALGO_MATH_CVECTOR_HPP #define CP_ALGO_MATH_CVECTOR_HPP -#include "../util/complex.hpp" -#include +#include +#include +#include +#include namespace cp_algo::math::fft { using ftype = double; - using point = complex; - using vftype = std::experimental::native_simd; - using vpoint = complex; - static constexpr size_t flen = vftype::size(); + using point = std::complex; - struct cvector { - static constexpr size_t pre_roots = 1 << 18; - std::vector x, y; - cvector(size_t n) { - n = std::max(flen, std::bit_ceil(n)); - x.resize(n / flen); - y.resize(n / flen); + struct ftvec: std::vector { + static constexpr size_t pre_roots = 1 << 16; + static constexpr size_t threshold = 32; + ftvec(size_t n) { + this->resize(std::max(threshold, std::bit_ceil(n))); } - template - void set(size_t k, pt t) { - if constexpr(std::is_same_v) { - x[k / flen][k % flen] = real(t); - y[k / flen][k % flen] = imag(t); - } else { - x[k / flen] = real(t); - y[k / flen] = imag(t); + static auto dot_block(size_t k, ftvec const& A, ftvec const& B) { + static std::array r; + std::ranges::fill(r, point(0)); + for(size_t i = 0; i < threshold; i++) { + for(size_t j = 0; j < threshold; j++) { + r[i + j] += A[k + i] * B[k + j]; + } } - } - template - pt get(size_t k) const { - if constexpr(std::is_same_v) { - return {x[k / flen][k % flen], y[k / flen][k % flen]}; - } else { - return {x[k / flen], y[k / flen]}; + auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2); + static std::array res; + for(size_t i = 0; i < threshold; i++) { + res[i] = r[i] + r[i + threshold] * rt; } - } - vpoint vget(size_t k) const { - return get(k); + return res; } - size_t size() const { - return flen * std::size(x); + void dot(ftvec const& t) { + size_t n = this->size(); + for(size_t k = 0; k < n; k += threshold) { + std::ranges::copy(dot_block(k, *this, t), this->begin() + k); + } } - void dot(cvector const& t) { - size_t n = size(); - for(size_t k = 0; k < n; k += flen) { - set(k, get(k) * t.get(k)); + static std::array roots, evalp; + static std::array eval_args; + static point root(size_t n, size_t k) { + if(n + k < pre_roots && roots[n + k] != point{}) { + return roots[n + k]; } + auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n)); + if(n + k < pre_roots) { + roots[n + k] = res; + } + return res; } - static const cvector roots; - template< bool precalc = false, class ft = point> - static auto root(size_t n, size_t k, ft &&arg) { - if(n < pre_roots && !precalc) { - return roots.get>(n + k); - } else { - return complex::polar(1., arg); + static size_t eval_arg(size_t n) { + if(n < pre_roots && eval_args[n]) { + return eval_args[n]; + } else if(n == 0) { + return 0; + } + auto res = eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1); + if(n < pre_roots) { + eval_args[n] = res; } + return res; + } + static point eval_point(size_t n) { + if(n < pre_roots && evalp[n] != point{}) { + return evalp[n]; + } else if(n == 0) { + return point(1); + } + auto res = root(2 * std::bit_floor(n), eval_arg(n)); + if(n < pre_roots) { + evalp[n] = res; + } + return res; } - template static void exec_on_roots(size_t n, size_t m, auto &&callback) { - ftype arg = std::numbers::pi / (ftype)n; - size_t step = sizeof(pt) / sizeof(point); - using ft = pt::value_type; - auto k = [&]() { - if constexpr(std::is_same_v) { - return ft{}; - } else { - return ft{[](auto i) {return i;}}; + auto step = root(n, 1); + auto rt = point(1); + for(size_t i = 0; i < m; i++) { + if(i % threshold == 0) { + rt = root(n / threshold, i / threshold); } - }(); - for(size_t i = 0; i < m; i += step, k += (ftype)step) { - callback(i, root(n, i, arg * k)); + callback(i, rt); + rt *= step; + } + } + static void exec_on_evals(size_t n, auto &&callback) { + for(size_t i = 0; i < n; i++) { + callback(i, eval_point(i)); } } void ifft() { - size_t n = size(); - for(size_t i = 1; i < n; i *= 2) { - for(size_t j = 0; j < n; j += 2 * i) { - auto butterfly = [&](size_t k, pt rt) { - k += j; - auto t = get(k + i) * conj(rt); - set(k + i, get(k) - t); - set(k, get(k) + t); - }; - if(i < flen) { - exec_on_roots(i, i, butterfly); - } else { - exec_on_roots(i, i, butterfly); + size_t n = this->size(); + for(size_t half = threshold; half <= n / 2; half *= 2) { + exec_on_evals(n / (2 * half), [&](size_t k, point rt) { + k *= 2 * half; + for(size_t j = k; j < k + half; j++) { + auto A = this->at(j) + this->at(j + half); + auto B = this->at(j) - this->at(j + half); + this->at(j) = A; + this->at(j + half) = B * conj(rt); } - } + }); } - for(size_t k = 0; k < n; k += flen) { - set(k, get(k) /= (ftype)n); + point ni = point(int(threshold)) / point(int(n)); + for(auto &it: *this) { + it *= ni; } } void fft() { - size_t n = size(); - for(size_t i = n / 2; i >= 1; i /= 2) { - for(size_t j = 0; j < n; j += 2 * i) { - auto butterfly = [&](size_t k, pt rt) { - k += j; - auto A = get(k) + get(k + i); - auto B = get(k) - get(k + i); - set(k, A); - set(k + i, B * rt); - }; - if(i < flen) { - exec_on_roots(i, i, butterfly); - } else { - exec_on_roots(i, i, butterfly); + size_t n = this->size(); + for(size_t half = n / 2; half >= threshold; half /= 2) { + exec_on_evals(n / (2 * half), [&](size_t k, point rt) { + k *= 2 * half; + for(size_t j = k; j < k + half; j++) { + auto t = this->at(j + half) * rt; + this->at(j + half) = this->at(j) - t; + this->at(j) += t; } - } + }); } } }; - const cvector cvector::roots = []() { - cvector res(pre_roots); - for(size_t n = 1; n < res.size(); n *= 2) { - auto propagate = [&](size_t k, auto rt) { - res.set(n + k, rt); - }; - if(n < flen) { - res.exec_on_roots(n, n, propagate); - } else { - res.exec_on_roots(n, n, propagate); - } - } - return res; - }(); + std::array ftvec::roots = {}; + std::array ftvec::evalp = {}; + std::array ftvec::eval_args = {}; } #endif // CP_ALGO_MATH_CVECTOR_HPP diff --git a/cp-algo/math/fft.hpp b/cp-algo/math/fft.hpp index 2416327..202ce4d 100644 --- a/cp-algo/math/fft.hpp +++ b/cp-algo/math/fft.hpp @@ -2,15 +2,14 @@ #define CP_ALGO_MATH_FFT_HPP #include "../number_theory/modint.hpp" #include "cvector.hpp" -#include namespace cp_algo::math::fft { template struct dft { - cvector A; + ftvec A; dft(std::vector const& a, size_t n): A(n) { for(size_t i = 0; i < std::min(n, a.size()); i++) { - A.set(i, a[i]); + A[i] = a[i]; } if(n) { A.fft(); @@ -27,7 +26,7 @@ namespace cp_algo::math::fft { A.ifft(); std::vector res(n); for(size_t k = 0; k < n; k++) { - res[k] = A.get(k); + res[k] = A[k]; } return res; } @@ -35,22 +34,22 @@ namespace cp_algo::math::fft { auto operator * (dft const& B) const { return dft(*this) *= B; } - - point operator [](int i) const {return A.get(i);} }; template struct dft { int split; - cvector A, B; + ftvec A, B; dft(auto const& a, size_t n): A(n), B(n) { - split = int(std::sqrt(base::mod())); - cvector::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) { + n = size(A); + split = int(std::sqrt(base::mod())) + 1; + ftvec::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) { size_t ti = std::min(i, i - n); - A.set(ti, A.get(ti) + ftype(a[i].rem() % split) * rt); - B.set(ti, B.get(ti) + ftype(a[i].rem() / split) * rt); - + auto rem = std::remainder(a[i].rem(), split); + auto quo = (a[i].rem() - rem) / split; + A[ti] += rem * rt; + B[ti] += quo * rt; }); if(n) { A.fft(); @@ -65,21 +64,26 @@ namespace cp_algo::math::fft { res = {}; return; } - for(size_t i = 0; i < n; i += flen) { - auto tmp = A.vget(i) * D.vget(i) + B.vget(i) * C.vget(i); - A.set(i, A.vget(i) * C.vget(i)); - B.set(i, B.vget(i) * D.vget(i)); - C.set(i, tmp); + for(size_t i = 0; i < n; i += ftvec::threshold) { + auto AC = ftvec::dot_block(i, A, C); + auto AD = ftvec::dot_block(i, A, D); + auto BC = ftvec::dot_block(i, B, C); + auto BD = ftvec::dot_block(i, B, D); + for(size_t j = 0; j < ftvec::threshold; j++) { + A[i + j] = AC[j]; + C[i + j] = AD[j] + BC[j]; + B[i + j] = BD[j]; + } } A.ifft(); B.ifft(); C.ifft(); auto splitsplit = (base(split) * split).rem(); - cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) { + ftvec::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) { rt = conj(rt); - auto Ai = A.get(i) * rt; - auto Bi = B.get(i) * rt; - auto Ci = C.get(i) * rt; + auto Ai = A[i] * rt; + auto Bi = B[i] * rt; + auto Ci = C[i] * rt; int64_t A0 = llround(real(Ai)); int64_t A1 = llround(real(Ci)); int64_t A2 = llround(real(Bi)); @@ -97,7 +101,7 @@ namespace cp_algo::math::fft { mul(B.A, B.B, res, k); } void mul(auto const& B, auto& res, size_t k) { - mul(cvector(B.A), B.B, res, k); + mul(ftvec(B.A), B.B, res, k); } std::vector operator *= (dft &B) { std::vector res(2 * A.size()); @@ -112,8 +116,6 @@ namespace cp_algo::math::fft { auto operator * (dft const& B) const { return dft(*this) *= B; } - - point operator [](int i) const {return A.get(i);} }; void mul_slow(auto &a, auto const& b, size_t k) { @@ -135,17 +137,15 @@ namespace cp_algo::math::fft { if(!as || !bs) { return 0; } - return std::max(flen, std::bit_ceil(as + bs - 1) / 2); + return std::bit_ceil(as + bs - 1) / 2; } void mul_truncate(auto &a, auto const& b, size_t k) { using base = std::decay_t; - if(std::min({k, size(a), size(b)}) < magic) { + if(std::min({k, size(a), size(b)}) < 1) { mul_slow(a, b, k); return; } - auto n = std::max(flen, std::bit_ceil( - std::min(k, size(a)) + std::min(k, size(b)) - 1 - ) / 2); + auto n = std::bit_ceil(std::min(k, size(a)) + std::min(k, size(b)) - 1) / 2; a.resize(k); auto A = dft(a, n); if(&a == &b) { diff --git a/verify/poly/wildcard.test.cpp b/verify/poly/wildcard.test.cpp index 996c783..5579da4 100644 --- a/verify/poly/wildcard.test.cpp +++ b/verify/poly/wildcard.test.cpp @@ -8,9 +8,9 @@ using namespace std; using namespace cp_algo::math; -using fft::ftype; -using fft::point; -using fft::cvector; +using ftype = double; +using point = complex; +using cvector = fft::ftvec; void semicorr(auto &a, auto &b) { a.fft(); @@ -32,7 +32,7 @@ string matches(string const& A, string const& B, char wild = '*') { if(!init) { init = true; for(int i = 0; i < sigma; i++) { - project[0][i] = cp_algo::polar(1., (ftype)cp_algo::random::rng()); + project[0][i] = polar(1., (ftype)cp_algo::random::rng()); project[1][i] = conj(project[0][i]); } } @@ -44,13 +44,13 @@ string matches(string const& A, string const& B, char wild = '*') { char c = ST[i]->at(k); size_t idx = i ? N - k - 1 : k; point val = c == wild ? 0 : project[i][c - 'a']; - P[i].set(idx, val); + P[i][idx] = val; } } semicorr(P[0], P[1]); string ans(size(A) - size(B) + 1, '0'); for(size_t j = 0; j < size(ans); j++) { - ans[j] = '0' + is_integer(P[0].get(size(B) - 1 + j)); + ans[j] = '0' + is_integer(P[0][size(B) - 1 + j]); } return ans; }