diff --git a/cp-algo/algebra/modint.hpp b/cp-algo/algebra/modint.hpp index f9928bb..ce9195f 100644 --- a/cp-algo/algebra/modint.hpp +++ b/cp-algo/algebra/modint.hpp @@ -5,7 +5,7 @@ namespace cp_algo::algebra { template struct modint_base { - static int mod() { + static int64_t mod() { return modint::mod(); } modint_base(): r(0) {} @@ -20,7 +20,11 @@ namespace cp_algo::algebra { return to_modint() *= t.inv(); } modint& operator *= (const modint &t) { - r *= t.r; if(mod()) {r %= mod();} + if(mod() <= uint32_t(-1)) { + r = r * t.r % mod(); + } else { + r = __int128(r) * t.r % mod(); + } return to_modint(); } modint& operator += (const modint &t) { @@ -36,11 +40,10 @@ namespace cp_algo::algebra { modint operator * (const modint &t) const {return modint(to_modint()) *= t;} modint operator / (const modint &t) const {return modint(to_modint()) /= t;} auto operator <=> (const modint_base &t) const = default; - explicit operator int() const {return r;} int64_t rem() const {return 2 * r > (uint64_t)mod() ? r - mod() : r;} // Only use if you really know what you're doing! - uint64_t modmod() const {return 8LL * mod() * mod();}; + uint64_t modmod() const {return 8ULL * mod() * mod();}; void add_unsafe(uint64_t t) {r += t;} void pseudonormalize() {r = std::min(r, r - modmod());} modint const& normalize() { @@ -65,21 +68,30 @@ namespace cp_algo::algebra { return out << x.getr(); } - template + template struct modint: modint_base> { - static constexpr int mod() {return m;} + static constexpr int64_t mod() {return m;} using Base = modint_base>; using Base::Base; }; struct dynamic_modint: modint_base { - static int mod() {return m;} - static void switch_mod(int nm) {m = nm;} + static int64_t mod() {return m;} + static void switch_mod(int64_t nm) {m = nm;} using Base = modint_base; using Base::Base; + + // Wrapper for temp switching + auto static with_switched_mod(int64_t tmp, auto callback) { + auto prev = mod(); + switch_mod(tmp); + auto res = callback(); + switch_mod(prev); + return res; + } private: - static int m; + static int64_t m; }; - int dynamic_modint::m = 0; + int64_t dynamic_modint::m = 0; } #endif // CP_ALGO_ALGEBRA_MODINT_HPP diff --git a/cp-algo/algebra/number_theory.hpp b/cp-algo/algebra/number_theory.hpp index e5f27b3..3e5ce0f 100644 --- a/cp-algo/algebra/number_theory.hpp +++ b/cp-algo/algebra/number_theory.hpp @@ -27,5 +27,44 @@ namespace cp_algo::algebra { } } } + + template + requires(std::is_base_of_v, base>) + bool is_prime_mod() { + auto m = base::mod(); + if(m == 1 || m % 2 == 0) { + return m == 2; + } + auto m1 = m - 1; + int d = 0; + while(m1 % 2 == 0) { + m1 /= 2; + d++; + } + auto test = [&](auto x) { + x = bpow(x, m1); + if(x == 0 || x == 1 || x == -1) { + return true; + } + for(int i = 0; i <= d; i++) { + if(x == -1) { + return true; + } + x *= x; + } + return false; + }; + for(base b: {2, 325, 9375, 28178, 450775, 9780504, 1795265022}) { + if(!test(b)) { + return false; + } + } + return true; + } + bool is_prime(int64_t m) { + return dynamic_modint::with_switched_mod(m, [](){ + return is_prime_mod(); + }); + } } #endif // CP_ALGO_ALGEBRA_NUMBER_THEORY_HPP diff --git a/verify/number_theory/primality.test.cpp b/verify/number_theory/primality.test.cpp new file mode 100644 index 0000000..3a76db6 --- /dev/null +++ b/verify/number_theory/primality.test.cpp @@ -0,0 +1,26 @@ +// @brief Primality Test +#define PROBLEM "https://judge.yosupo.jp/problem/primality_test" +#pragma GCC optimize("Ofast,unroll-loops") +#pragma GCC target("avx2,tune=native") +#include "cp-algo/algebra/number_theory.hpp" +#include + +using namespace std; +using namespace cp_algo::algebra; + +void solve() { + int64_t m; + cin >> m; + cout << (is_prime(m) ? "Yes" : "No") << "\n"; +} + +signed main() { + //freopen("input.txt", "r", stdin); + ios::sync_with_stdio(0); + cin.tie(0); + int t = 1; + cin >> t; + while(t--) { + solve(); + } +}