Skip to content

Commit

Permalink
restructure lib files
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleksandr Kulkov committed Feb 10, 2024
1 parent 68c3965 commit 672ee8a
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 81 deletions.
File renamed without changes.
9 changes: 7 additions & 2 deletions src/algebra/common.cpp → cp-algo/algebra/common.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
namespace algebra { // common
#ifndef ALGEBRA_COMMON_HPP
#define ALGEBRA_COMMON_HPP
#include <chrono>
#include <random>
namespace algebra {
const int maxn = 1 << 20;
const int magic = 250; // threshold for sizes to run the naive algo
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());

auto bpow(auto x, int64_t n, auto ans) {
for(; n; n /= 2, x = x * x) {
Expand Down Expand Up @@ -57,3 +61,4 @@ namespace algebra { // common
return F[n];
}
}
#endif // ALGEBRA_COMMON_HPP
24 changes: 15 additions & 9 deletions src/algebra/fft.cpp → cp-algo/algebra/fft.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
namespace algebra { // fft
#ifndef ALGEBRA_FFT_HPP
#define ALGEBRA_FFT_HPP
#include "common.hpp"
#include "modular.hpp"
#include <vector>
namespace algebra {
namespace fft {
using ftype = double;
struct point {
Expand Down Expand Up @@ -66,7 +71,7 @@ namespace algebra { // fft
}
}

void mul_slow(vector<auto> &a, const vector<auto> &b) {
void mul_slow(std::vector<auto> &a, const std::vector<auto> &b) {
if(a.empty() || b.empty()) {
a.clear();
} else {
Expand All @@ -75,7 +80,7 @@ namespace algebra { // fft
a.resize(n + m - 1);
for(int k = n + m - 2; k >= 0; k--) {
a[k] *= b[0];
for(int j = max(k - n + 1, 1); j < min(k + 1, m); j++) {
for(int j = std::max(k - n + 1, 1); j < std::min(k + 1, m); j++) {
a[k] += a[k - j] * b[j];
}
}
Expand All @@ -85,9 +90,9 @@ namespace algebra { // fft
template<int m>
struct dft {
static constexpr int split = 1 << 15;
vector<point> A;
std::vector<point> A;

dft(vector<modular<m>> const& a, size_t n): A(n) {
dft(std::vector<modular<m>> const& a, size_t n): A(n) {
for(size_t i = 0; i < min(n, a.size()); i++) {
A[i] = point(
a[i].rem() % split,
Expand All @@ -103,9 +108,9 @@ namespace algebra { // fft
assert(A.size() == B.A.size());
size_t n = A.size();
if(!n) {
return vector<modular<m>>();
return std::vector<modular<m>>();
}
vector<point> C(n), D(n);
std::vector<point> C(n), D(n);
for(size_t i = 0; i < n; i++) {
C[i] = A[i] * (B[i] + B[(n - i) % n].conj());
D[i] = A[i] * (B[i] - B[(n - i) % n].conj());
Expand All @@ -115,7 +120,7 @@ namespace algebra { // fft
reverse(begin(C) + 1, end(C));
reverse(begin(D) + 1, end(D));
int t = 2 * n;
vector<modular<m>> res(n);
std::vector<modular<m>> res(n);
for(size_t i = 0; i < n; i++) {
modular<m> A0 = llround(C[i].real() / t);
modular<m> A1 = llround(C[i].imag() / t + D[i].imag() / t);
Expand All @@ -141,7 +146,7 @@ namespace algebra { // fft
}

template<int m>
void mul(vector<modular<m>> &a, vector<modular<m>> b) {
void mul(std::vector<modular<m>> &a, std::vector<modular<m>> b) {
if(min(a.size(), b.size()) < magic) {
mul_slow(a, b);
return;
Expand All @@ -156,3 +161,4 @@ namespace algebra { // fft
}
}
}
#endif // ALGEBRA_FFT_HPP
39 changes: 23 additions & 16 deletions src/algebra/matrix.cpp → cp-algo/algebra/matrix.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@

namespace algebra { // matrix
#ifndef ALGEBRA_MATRIX_HPP
#define ALGEBRA_MATRIX_HPP
#include "common.hpp"
#include "modular.hpp"
#include <valarray>
#include <iostream>
#include <optional>
namespace algebra {
template<int mod>
struct matrix {
using base = modular<mod>;
size_t n, m;
valarray<valarray<base>> a;
matrix(size_t n, size_t m): n(n), m(m), a(valarray<base>(m), n) {}
matrix(valarray<valarray<base>> a): n(size(a)), m(n ? size(a[0]) : 0), a(a) {}
std::valarray<std::valarray<base>> a;
matrix(size_t n, size_t m): n(n), m(m), a(std::valarray<base>(m), n) {}
matrix(std::valarray<std::valarray<base>> a): n(size(a)), m(n ? size(a[0]) : 0), a(a) {}

auto& operator[] (size_t i) {return a[i];}
auto const& operator[] (size_t i) const {return a[i];}
Expand All @@ -20,15 +26,15 @@ namespace algebra { // matrix
void read() {
for(size_t i = 0; i < n; i++) {
for(size_t j = 0; j < m; j++) {
cin >> (*this)[i][j];
std::cin >> (*this)[i][j];
}
}
}

void print() const {
for(size_t i = 0; i < n; i++) {
for(size_t j = 0; j < m; j++) {
cout << (*this)[i][j] << " \n"[j + 1 == m];
std::cout << (*this)[i][j] << " \n"[j + 1 == m];
}
}
}
Expand All @@ -46,15 +52,15 @@ namespace algebra { // matrix
assert(n == b.n);
matrix res(n, m+b.m);
for(size_t i = 0; i < n; i++) {
res[i][slice(0,m,1)] = a[i];
res[i][slice(m,b.m,1)] = b[i];
res[i][std::slice(0,m,1)] = a[i];
res[i][std::slice(m,b.m,1)] = b[i];
}
return res;
}
matrix submatrix(auto slicex, auto slicey) const {
valarray res = a[slicex];
std::valarray res = a[slicex];
for(auto &row: res) {
row = valarray(row[slicey]);
row = std::valarray(row[slicey]);
}
return res;
}
Expand Down Expand Up @@ -157,7 +163,7 @@ namespace algebra { // matrix
return res;
}

optional<matrix> inv() const {
std::optional<matrix> inv() const {
assert(n == m);
matrix b = *this | eye(n);
if(size(b.gauss<reverse>(n)[0]) < n) {
Expand All @@ -166,11 +172,11 @@ namespace algebra { // matrix
for(size_t i = 0; i < n; i++) {
b[i] *= b[i][i].inv();
}
return b.submatrix(slice(0, n, 1), slice(n, n, 1));
return b.submatrix(std::slice(0, n, 1), std::slice(n, n, 1));
}

// [solution, basis], transposed
optional<array<matrix, 2>> solve(matrix t) const {
std::optional<array<matrix, 2>> solve(matrix t) const {
assert(n == t.n);
matrix b = *this | t;
auto [pivots, free] = b.gauss<reverse>();
Expand All @@ -188,9 +194,10 @@ namespace algebra { // matrix
sols[i][free[i]] = -1;
}
return array{
sols.submatrix(slice(size(free) - t.m, t.m, 1), slice(0, m, 1)),
sols.submatrix(slice(0, size(free) - t.m, 1), slice(0, m, 1))
sols.submatrix(std::slice(size(free) - t.m, t.m, 1), std::slice(0, m, 1)),
sols.submatrix(std::slice(0, size(free) - t.m, 1), std::slice(0, m, 1))
};
}
};
}
#endif // ALGEBRA_MATRIX_HPP
14 changes: 10 additions & 4 deletions src/algebra/modular.cpp → cp-algo/algebra/modular.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
namespace algebra { // modular
#ifndef ALGEBRA_MODULAR_HPP
#define ALGEBRA_MODULAR_HPP
#include "common.hpp"
#include <iostream>
#include <optional>
namespace algebra {
template<int m>
struct modular {
// https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
// solves x^2 = y (mod m) assuming m is prime in O(log m).
// returns nullopt if no sol.
optional<modular> sqrt() const {
std::optional<modular> sqrt() const {
static modular y;
y = *this;
if(r == 0) {
Expand Down Expand Up @@ -61,12 +66,13 @@ namespace algebra { // modular
};

template<int m>
istream& operator >> (istream &in, modular<m> &x) {
std::istream& operator >> (std::istream &in, modular<m> &x) {
return in >> x.r;
}

template<int m>
ostream& operator << (ostream &out, modular<m> const& x) {
std::ostream& operator << (std::ostream &out, modular<m> const& x) {
return out << x.r % m;
}
}
#endif // ALGEBRA_MODULAR_HPP
Loading

0 comments on commit 672ee8a

Please sign in to comment.