Skip to content

Commit

Permalink
Increase FFT precision, get rid of manual vectorization, use partial fft
Browse files Browse the repository at this point in the history
  • Loading branch information
adamant-pwn committed Dec 3, 2024
1 parent ac2a9a0 commit 2868c3a
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 136 deletions.
197 changes: 96 additions & 101 deletions cp-algo/math/cvector.hpp
Original file line number Diff line number Diff line change
@@ -1,133 +1,128 @@
#ifndef CP_ALGO_MATH_CVECTOR_HPP
#define CP_ALGO_MATH_CVECTOR_HPP
#include "../util/complex.hpp"
#include <experimental/simd>
#include <algorithm>
#include <complex>
#include <vector>
#include <ranges>
namespace cp_algo::math::fft {
using ftype = double;
using point = complex<ftype>;
using vftype = std::experimental::native_simd<ftype>;
using vpoint = complex<vftype>;
static constexpr size_t flen = vftype::size();
using point = std::complex<ftype>;

struct cvector {
static constexpr size_t pre_roots = 1 << 18;
std::vector<vftype> x, y;
cvector(size_t n) {
n = std::max(flen, std::bit_ceil(n));
x.resize(n / flen);
y.resize(n / flen);
struct ftvec: std::vector<point> {
static constexpr size_t pre_roots = 1 << 16;
static constexpr size_t threshold = 32;
ftvec(size_t n) {
this->resize(std::max(threshold, std::bit_ceil(n)));
}
template<class pt = point>
void set(size_t k, pt t) {
if constexpr(std::is_same_v<pt, point>) {
x[k / flen][k % flen] = real(t);
y[k / flen][k % flen] = imag(t);
} else {
x[k / flen] = real(t);
y[k / flen] = imag(t);
static auto dot_block(size_t k, ftvec const& A, ftvec const& B) {
static std::array<point, 2 * threshold> r;
std::ranges::fill(r, point(0));
for(size_t i = 0; i < threshold; i++) {
for(size_t j = 0; j < threshold; j++) {
r[i + j] += A[k + i] * B[k + j];
}
}
}
template<class pt = point>
pt get(size_t k) const {
if constexpr(std::is_same_v<pt, point>) {
return {x[k / flen][k % flen], y[k / flen][k % flen]};
} else {
return {x[k / flen], y[k / flen]};
auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2);
static std::array<point, threshold> res;
for(size_t i = 0; i < threshold; i++) {
res[i] = r[i] + r[i + threshold] * rt;
}
}
vpoint vget(size_t k) const {
return get<vpoint>(k);
return res;
}

size_t size() const {
return flen * std::size(x);
// Replaces the contents of this vector, one block of `threshold` entries at a
// time, with dot_block(offset, *this, t) for each block offset.
void dot(ftvec const& t) {
    const size_t total = this->size();
    size_t offset = 0;
    while(offset < total) {
        // dot_block hands back the whole block by value, so every read of the
        // old contents is finished before the block is overwritten below.
        const auto block = dot_block(offset, *this, t);
        std::copy(block.begin(), block.end(), this->begin() + offset);
        offset += threshold;
    }
}
void dot(cvector const& t) {
size_t n = size();
for(size_t k = 0; k < n; k += flen) {
set(k, get<vpoint>(k) * t.get<vpoint>(k));
static std::array<point, pre_roots> roots, evalp;
static std::array<size_t, pre_roots> eval_args;
// Returns exp(i * pi * k / n), i.e. std::polar(1, pi * k / n), memoised in
// the static `roots` table.
//
// The table slot used for (n, k) is n + k. That slot identifies the pair
// uniquely only when n is a power of two and k < n (such pairs occupy
// [n, 2n) without overlap). Any other pair could alias a different root,
// e.g. (2, 2) and (4, 0) both map to slot 4 although the values are -1 and 1,
// so those requests bypass the table and are always computed directly.
// A default-constructed point (zero) marks an empty slot; a stored root has
// modulus 1, so it can never be mistaken for an empty slot.
static point root(size_t n, size_t k) {
    const bool cacheable = std::has_single_bit(n) && k < n && n + k < pre_roots;
    if(cacheable && roots[n + k] != point{}) {
        return roots[n + k];
    }
    auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n));
    if(cacheable) {
        roots[n + k] = res;
    }
    return res;
}
static const cvector roots;
template< bool precalc = false, class ft = point>
static auto root(size_t n, size_t k, ft &&arg) {
if(n < pre_roots && !precalc) {
return roots.get<complex<ft>>(n + k);
} else {
return complex<ft>::polar(1., arg);
// Returns n with its std::bit_width(n) significant bits written in reverse
// order (0 for n == 0), e.g. 6 = 0b110 -> 0b011 = 3.
// Results for n < pre_roots are memoised in `eval_args`; a stored 0 means
// "not computed yet", which is unambiguous because the result is non-zero
// for every n >= 1 (the top bit of n always lands in bit 0).
static size_t eval_arg(size_t n) {
    const bool in_table = n < pre_roots;
    if(in_table && eval_args[n] != 0) {
        return eval_args[n];
    }
    if(n == 0) {
        return 0;
    }
    // Reverse the upper bits recursively, then move the lowest bit of n to
    // the most significant position of the result.
    const size_t lowest = n & 1;
    const auto top = std::bit_width(n) - 1;
    const size_t reversed = eval_arg(n >> 1) | (lowest << top);
    if(in_table) {
        eval_args[n] = reversed;
    }
    return reversed;
}
// Returns the n-th evaluation point: 1 for n == 0, otherwise
// root(2 * std::bit_floor(n), eval_arg(n)).
// Values for n < pre_roots are memoised in `evalp`, where a default-constructed
// point marks an entry that has not been computed yet.
static point eval_point(size_t n) {
    const bool in_table = n < pre_roots;
    if(in_table && evalp[n] != point{}) {
        return evalp[n];
    }
    if(n == 0) {
        return point(1);
    }
    const size_t order = 2 * std::bit_floor(n);
    const auto value = root(order, eval_arg(n));
    if(in_table) {
        evalp[n] = value;
    }
    return value;
}
template<class pt = point, bool precalc = false>
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
ftype arg = std::numbers::pi / (ftype)n;
size_t step = sizeof(pt) / sizeof(point);
using ft = pt::value_type;
auto k = [&]() {
if constexpr(std::is_same_v<pt, point>) {
return ft{};
} else {
return ft{[](auto i) {return i;}};
auto step = root(n, 1);
auto rt = point(1);
for(size_t i = 0; i < m; i++) {
if(i % threshold == 0) {
rt = root(n / threshold, i / threshold);
}
}();
for(size_t i = 0; i < m; i += step, k += (ftype)step) {
callback(i, root<precalc>(n, i, arg * k));
callback(i, rt);
rt *= step;
}
}
// Invokes callback(i, eval_point(i)) for i = 0, 1, ..., n - 1, in increasing
// order of i. Does nothing when n == 0.
static void exec_on_evals(size_t n, auto &&callback) {
    size_t idx = 0;
    while(idx < n) {
        callback(idx, eval_point(idx));
        ++idx;
    }
}

void ifft() {
size_t n = size();
for(size_t i = 1; i < n; i *= 2) {
for(size_t j = 0; j < n; j += 2 * i) {
auto butterfly = [&]<class pt>(size_t k, pt rt) {
k += j;
auto t = get<pt>(k + i) * conj(rt);
set(k + i, get<pt>(k) - t);
set(k, get<pt>(k) + t);
};
if(i < flen) {
exec_on_roots<point>(i, i, butterfly);
} else {
exec_on_roots<vpoint>(i, i, butterfly);
size_t n = this->size();
for(size_t half = threshold; half <= n / 2; half *= 2) {
exec_on_evals(n / (2 * half), [&](size_t k, point rt) {
k *= 2 * half;
for(size_t j = k; j < k + half; j++) {
auto A = this->at(j) + this->at(j + half);
auto B = this->at(j) - this->at(j + half);
this->at(j) = A;
this->at(j + half) = B * conj(rt);
}
}
});
}
for(size_t k = 0; k < n; k += flen) {
set(k, get<vpoint>(k) /= (ftype)n);
point ni = point(int(threshold)) / point(int(n));
for(auto &it: *this) {
it *= ni;
}
}
void fft() {
size_t n = size();
for(size_t i = n / 2; i >= 1; i /= 2) {
for(size_t j = 0; j < n; j += 2 * i) {
auto butterfly = [&]<class pt>(size_t k, pt rt) {
k += j;
auto A = get<pt>(k) + get<pt>(k + i);
auto B = get<pt>(k) - get<pt>(k + i);
set(k, A);
set(k + i, B * rt);
};
if(i < flen) {
exec_on_roots<point>(i, i, butterfly);
} else {
exec_on_roots<vpoint>(i, i, butterfly);
size_t n = this->size();
for(size_t half = n / 2; half >= threshold; half /= 2) {
exec_on_evals(n / (2 * half), [&](size_t k, point rt) {
k *= 2 * half;
for(size_t j = k; j < k + half; j++) {
auto t = this->at(j + half) * rt;
this->at(j + half) = this->at(j) - t;
this->at(j) += t;
}
}
});
}
}
};
const cvector cvector::roots = []() {
cvector res(pre_roots);
for(size_t n = 1; n < res.size(); n *= 2) {
auto propagate = [&](size_t k, auto rt) {
res.set(n + k, rt);
};
if(n < flen) {
res.exec_on_roots<point, true>(n, n, propagate);
} else {
res.exec_on_roots<vpoint, true>(n, n, propagate);
}
}
return res;
}();
// Out-of-class definitions of ftvec's static memoisation tables. `= {}`
// value-initialises every entry; root(), eval_point() and eval_arg() treat a
// zero entry as "not computed yet" and fill the tables lazily.
// NOTE(review): these are non-inline definitions inside a header, so including
// it from more than one translation unit would presumably produce duplicate
// definitions; confirm single-TU usage or make the members `inline static`.
std::array<point, ftvec::pre_roots> ftvec::roots = {};
std::array<point, ftvec::pre_roots> ftvec::evalp = {};
std::array<size_t, ftvec::pre_roots> ftvec::eval_args = {};
}
#endif // CP_ALGO_MATH_CVECTOR_HPP
58 changes: 29 additions & 29 deletions cp-algo/math/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
#define CP_ALGO_MATH_FFT_HPP
#include "../number_theory/modint.hpp"
#include "cvector.hpp"
#include <ranges>
namespace cp_algo::math::fft {
template<typename base>
struct dft {
cvector A;
ftvec A;

dft(std::vector<base> const& a, size_t n): A(n) {
for(size_t i = 0; i < std::min(n, a.size()); i++) {
A.set(i, a[i]);
A[i] = a[i];
}
if(n) {
A.fft();
Expand All @@ -27,30 +26,30 @@ namespace cp_algo::math::fft {
A.ifft();
std::vector<base> res(n);
for(size_t k = 0; k < n; k++) {
res[k] = A.get(k);
res[k] = A[k];
}
return res;
}

auto operator * (dft const& B) const {
return dft(*this) *= B;
}

point operator [](int i) const {return A.get(i);}
};

template<modint_type base>
struct dft<base> {
int split;
cvector A, B;
ftvec A, B;

dft(auto const& a, size_t n): A(n), B(n) {
split = int(std::sqrt(base::mod()));
cvector::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
n = size(A);
split = int(std::sqrt(base::mod())) + 1;
ftvec::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
size_t ti = std::min(i, i - n);
A.set(ti, A.get(ti) + ftype(a[i].rem() % split) * rt);
B.set(ti, B.get(ti) + ftype(a[i].rem() / split) * rt);

auto rem = std::remainder(a[i].rem(), split);
auto quo = (a[i].rem() - rem) / split;
A[ti] += rem * rt;
B[ti] += quo * rt;
});
if(n) {
A.fft();
Expand All @@ -65,21 +64,26 @@ namespace cp_algo::math::fft {
res = {};
return;
}
for(size_t i = 0; i < n; i += flen) {
auto tmp = A.vget(i) * D.vget(i) + B.vget(i) * C.vget(i);
A.set(i, A.vget(i) * C.vget(i));
B.set(i, B.vget(i) * D.vget(i));
C.set(i, tmp);
for(size_t i = 0; i < n; i += ftvec::threshold) {
auto AC = ftvec::dot_block(i, A, C);
auto AD = ftvec::dot_block(i, A, D);
auto BC = ftvec::dot_block(i, B, C);
auto BD = ftvec::dot_block(i, B, D);
for(size_t j = 0; j < ftvec::threshold; j++) {
A[i + j] = AC[j];
C[i + j] = AD[j] + BC[j];
B[i + j] = BD[j];
}
}
A.ifft();
B.ifft();
C.ifft();
auto splitsplit = (base(split) * split).rem();
cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
ftvec::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
rt = conj(rt);
auto Ai = A.get(i) * rt;
auto Bi = B.get(i) * rt;
auto Ci = C.get(i) * rt;
auto Ai = A[i] * rt;
auto Bi = B[i] * rt;
auto Ci = C[i] * rt;
int64_t A0 = llround(real(Ai));
int64_t A1 = llround(real(Ci));
int64_t A2 = llround(real(Bi));
Expand All @@ -97,7 +101,7 @@ namespace cp_algo::math::fft {
mul(B.A, B.B, res, k);
}
void mul(auto const& B, auto& res, size_t k) {
mul(cvector(B.A), B.B, res, k);
mul(ftvec(B.A), B.B, res, k);
}
std::vector<base> operator *= (dft &B) {
std::vector<base> res(2 * A.size());
Expand All @@ -112,8 +116,6 @@ namespace cp_algo::math::fft {
auto operator * (dft const& B) const {
return dft(*this) *= B;
}

point operator [](int i) const {return A.get(i);}
};

void mul_slow(auto &a, auto const& b, size_t k) {
Expand All @@ -135,17 +137,15 @@ namespace cp_algo::math::fft {
if(!as || !bs) {
return 0;
}
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
return std::bit_ceil(as + bs - 1) / 2;
}
void mul_truncate(auto &a, auto const& b, size_t k) {
using base = std::decay_t<decltype(a[0])>;
if(std::min({k, size(a), size(b)}) < magic) {
if(std::min({k, size(a), size(b)}) < 1) {
mul_slow(a, b, k);
return;
}
auto n = std::max(flen, std::bit_ceil(
std::min(k, size(a)) + std::min(k, size(b)) - 1
) / 2);
auto n = std::bit_ceil(std::min(k, size(a)) + std::min(k, size(b)) - 1) / 2;
a.resize(k);
auto A = dft<base>(a, n);
if(&a == &b) {
Expand Down
Loading

0 comments on commit 2868c3a

Please sign in to comment.