diff --git a/array.h b/array.h index 01e440d..1a4b837 100644 --- a/array.h +++ b/array.h @@ -158,15 +158,16 @@ GenericArray GenericArray::randomfUnique( size_t retries = (size + 10) * log(size + 10) * 2; while (result.size() != size) { - if (--retries == 0) { - ensure(false, "There are not enough unique elements"); - } - T t = func(args...); if (!set.count(t)) { set.insert(t); result.push_back(t); } + + if (--retries == 0) { + ensure(false, "There are not enough unique elements"); + } + } return result; @@ -386,6 +387,7 @@ template using TArray = jngen::GenericArray; using Array = jngen::GenericArray; +using Array2d = jngen::GenericArray>; using Array64 = jngen::GenericArray; using Arrayf = jngen::GenericArray; using Arrayp = jngen::GenericArray>; diff --git a/jngen.h b/jngen.h index 91702ab..428b7f9 100644 --- a/jngen.h +++ b/jngen.h @@ -1500,15 +1500,16 @@ GenericArray GenericArray::randomfUnique( size_t retries = (size + 10) * log(size + 10) * 2; while (result.size() != size) { - if (--retries == 0) { - ensure(false, "There are not enough unique elements"); - } - T t = func(args...); if (!set.count(t)) { set.insert(t); result.push_back(t); } + + if (--retries == 0) { + ensure(false, "There are not enough unique elements"); + } + } return result; @@ -1728,6 +1729,7 @@ template using TArray = jngen::GenericArray; using Array = jngen::GenericArray; +using Array2d = jngen::GenericArray>; using Array64 = jngen::GenericArray; using Arrayf = jngen::GenericArray; using Arrayp = jngen::GenericArray>; @@ -1743,6 +1745,231 @@ jngen::GenericArray makeArray(const std::initializer_list& values) { } +#include +#include +#include +#include +#include +#include + +namespace jngen { + +namespace detail { + +int multiply(int x, int y, int mod) { + return static_cast(x) * y % mod; +} + +long long multiply(long long x, long long y, long long mod) { +#if defined(__SIZEOF_INT128__) + return static_cast<__int128>(x) * y % mod; +#else + long long res = 0; + while (y) { + if (y&1) { + res = (static_cast(res) + x) % mod; + } + x = (static_cast(x) + x) % mod; + y >>= 1; + } + return res; +#endif +} + +int power(int x, int k, int mod) { + int res = 1; + while (k) { + if (k&1) { + res = multiply(res, x, mod); + } + x = multiply(x, x, mod); + k >>= 1; + } + return res; +} + +long long power(long long x, long long k, long long mod) { + long long res = 1; + while (k) { + if (k&1) { + res = multiply(res, x, mod); + } + x = multiply(x, x, mod); + k >>= 1; + } + return res; +} + +template +bool millerRabinTest(I n, const std::vector& witnesses) { + static_assert( + std::is_same::value || std::is_same::value, + "millerRabinTest only is supported"); + + if (n == 1) { + return false; + } + + constexpr int LIMIT = 10000; + + if (n <= LIMIT) { + for (int i = 2; i*i <= n; ++i) { + if (n%i == 0) { + return false; + } + } + return true; + } + + int r = 0; + I d = n - 1; + while (d % 2 == 0) { + ++r; + d /= 2; + } + + for (I a: witnesses) { + I x = power(a, d, n); + if (x == 1 || x == n - 1) { + continue; + } + + bool composite = true; + for (int i = 0; i < r - 1; ++i) { + x = multiply(x, x, n); + if (x == 1) { + return false; + } + if (x == n - 1) { + i = r; + composite = false; + continue; + } + } + if (composite) { + return false; + } + } + return true; +} + +} // namespace detail + +bool isPrime(long long n) { + const static std::vector INT_WITNESSES{2, 7, 61}; + const static std::vector LONG_LONG_WITNESSES + {2, 3, 5, 7, 11, 13, 17, 19, 23}; + + if (n < std::numeric_limits::max()) { + return detail::millerRabinTest(n, INT_WITNESSES); + } else { + return detail::millerRabinTest(n, LONG_LONG_WITNESSES); + } +} + +class MathRandom { +public: + MathRandom() { + static bool created = false; + ensure(!created, "jngen::MathRandom should be created only once"); + created = true; + } + + static long long randomPrime(long long n) { + return randomPrime(2, n - 1); + } + + static long long randomPrime(long long l, long long r) { + ensure(l <= r); + int retries = std::log(l) * std::log(l); + while (true) { + long long x = rnd.next(l, r); + if (isPrime(x)) { + return x; + } + + if (--retries == 0) { + ensure( + false, + format( + "There are no primes between %lld and %lld", + l, r) + ); + } + } + } + + static Array partition(int n, size_t numParts) { + auto res = partition(static_cast(n), numParts); + return Array(res.begin(), res.end()); + } + + static Array64 partition(long long n, size_t numParts) { + auto res = partitionNonEmpty( + static_cast(n + numParts), numParts); + for (auto& x: res) { + --x; + } + return res; + } + + static Array partitionNonEmpty(int n, size_t numParts) { + auto res = partitionNonEmpty(static_cast(n), numParts); + return Array(res.begin(), res.end()); + } + + static Array64 partitionNonEmpty(long long n, size_t numParts) { + ensure(static_cast(numParts) <= n); + auto delimiters = Array64::randomUnique(numParts - 1, 1, n - 1).sorted(); + delimiters.insert(delimiters.begin(), 0); + delimiters.push_back(n); + Array64 res(numParts); + for (size_t i = 0; i < numParts; ++i) { + res[i] = delimiters[i + 1] - delimiters[i]; + } + return res; + } + + template + TArray> partition(TArray elements, size_t numParts) { + return partition( + std::move(elements), + partition(static_cast(elements.size()), numParts)); + } + + template + TArray> partitionNonEmpty(TArray elements, size_t numParts) { + return partition( + std::move(elements), + partitionNonEmpty(static_cast(elements.size()), numParts)); + } + + template + TArray> partition(TArray elements, const Array& sizes) { + elements.shuffle(); + TArray> res; + auto it = elements.begin(); + for (int size: sizes) { + res.emplace_back(); + std::copy(it, it + size, std::back_inserter(res.back())); + it += size; + } + + ensure(it == elements.end(), "sum(sizes) != elements.size()"); + + return res; + } +}; + +MathRandom rndm; + +} // namespace jngen + +using jngen::isPrime; + +using jngen::rndm; + + namespace jngen { class ArrayRandom { diff --git a/math.h b/math.h new file mode 100644 index 0000000..cb2e75f --- /dev/null +++ b/math.h @@ -0,0 +1,229 @@ +#pragma once + +#include "array.h" +#include "common.h" +#include "random.h" + +#include +#include +#include +#include +#include +#include + +namespace jngen { + +namespace detail { + +int multiply(int x, int y, int mod) { + return static_cast(x) * y % mod; +} + +long long multiply(long long x, long long y, long long mod) { +#if defined(__SIZEOF_INT128__) + return static_cast<__int128>(x) * y % mod; +#else + long long res = 0; + while (y) { + if (y&1) { + res = (static_cast(res) + x) % mod; + } + x = (static_cast(x) + x) % mod; + y >>= 1; + } + return res; +#endif +} + +int power(int x, int k, int mod) { + int res = 1; + while (k) { + if (k&1) { + res = multiply(res, x, mod); + } + x = multiply(x, x, mod); + k >>= 1; + } + return res; +} + +long long power(long long x, long long k, long long mod) { + long long res = 1; + while (k) { + if (k&1) { + res = multiply(res, x, mod); + } + x = multiply(x, x, mod); + k >>= 1; + } + return res; +} + +template +bool millerRabinTest(I n, const std::vector& witnesses) { + static_assert( + std::is_same::value || std::is_same::value, + "millerRabinTest only is supported"); + + if (n == 1) { + return false; + } + + constexpr int LIMIT = 10000; + + if (n <= LIMIT) { + for (int i = 2; i*i <= n; ++i) { + if (n%i == 0) { + return false; + } + } + return true; + } + + int r = 0; + I d = n - 1; + while (d % 2 == 0) { + ++r; + d /= 2; + } + + for (I a: witnesses) { + I x = power(a, d, n); + if (x == 1 || x == n - 1) { + continue; + } + + bool composite = true; + for (int i = 0; i < r - 1; ++i) { + x = multiply(x, x, n); + if (x == 1) { + return false; + } + if (x == n - 1) { + i = r; + composite = false; + continue; + } + } + if (composite) { + return false; + } + } + return true; +} + +} // namespace detail + +bool isPrime(long long n) { + const static std::vector INT_WITNESSES{2, 7, 61}; + const static std::vector LONG_LONG_WITNESSES + {2, 3, 5, 7, 11, 13, 17, 19, 23}; + + if (n < std::numeric_limits::max()) { + return detail::millerRabinTest(n, INT_WITNESSES); + } else { + return detail::millerRabinTest(n, LONG_LONG_WITNESSES); + } +} + +class MathRandom { +public: + MathRandom() { + static bool created = false; + ensure(!created, "jngen::MathRandom should be created only once"); + created = true; + } + + static long long randomPrime(long long n) { + return randomPrime(2, n - 1); + } + + static long long randomPrime(long long l, long long r) { + ensure(l <= r); + int retries = std::log(l) * std::log(l); + while (true) { + long long x = rnd.next(l, r); + if (isPrime(x)) { + return x; + } + + if (--retries == 0) { + ensure( + false, + format( + "There are no primes between %lld and %lld", + l, r) + ); + } + } + } + + static Array partition(int n, size_t numParts) { + auto res = partition(static_cast(n), numParts); + return Array(res.begin(), res.end()); + } + + static Array64 partition(long long n, size_t numParts) { + auto res = partitionNonEmpty( + static_cast(n + numParts), numParts); + for (auto& x: res) { + --x; + } + return res; + } + + static Array partitionNonEmpty(int n, size_t numParts) { + auto res = partitionNonEmpty(static_cast(n), numParts); + return Array(res.begin(), res.end()); + } + + static Array64 partitionNonEmpty(long long n, size_t numParts) { + ensure(static_cast(numParts) <= n); + auto delimiters = Array64::randomUnique(numParts - 1, 1, n - 1).sorted(); + delimiters.insert(delimiters.begin(), 0); + delimiters.push_back(n); + Array64 res(numParts); + for (size_t i = 0; i < numParts; ++i) { + res[i] = delimiters[i + 1] - delimiters[i]; + } + return res; + } + + template + TArray> partition(TArray elements, size_t numParts) { + return partition( + std::move(elements), + partition(static_cast(elements.size()), numParts)); + } + + template + TArray> partitionNonEmpty(TArray elements, size_t numParts) { + return partition( + std::move(elements), + partitionNonEmpty(static_cast(elements.size()), numParts)); + } + + template + TArray> partition(TArray elements, const Array& sizes) { + elements.shuffle(); + TArray> res; + auto it = elements.begin(); + for (int size: sizes) { + res.emplace_back(); + std::copy(it, it + size, std::back_inserter(res.back())); + it += size; + } + + ensure(it == elements.end(), "sum(sizes) != elements.size()"); + + return res; + } +}; + +MathRandom rndm; + +} // namespace jngen + +using jngen::isPrime; + +using jngen::rndm;