diff --git a/src/include/rng.hxx b/src/include/rng.hxx index 13e24b22e..c8381fa72 100644 --- a/src/include/rng.hxx +++ b/src/include/rng.hxx @@ -6,10 +6,35 @@ #include #include #include +#include +#include "mpi.h" namespace rng { +namespace detail +{ +using Generator = std::default_random_engine; +using ProcessGeneratorPtr = std::shared_ptr; +ProcessGeneratorPtr process_generator; + +int get_process_seed() +{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + return 10000 * (rank + 1); +} + +ProcessGeneratorPtr get_process_generator() +{ + if (!process_generator) { + process_generator = std::make_shared(get_process_seed()); + } + return process_generator; +} + +} // namespace detail + // ====================================================================== // Uniform @@ -17,16 +42,15 @@ template struct Uniform { Uniform(Real min, Real max) - : dist(min, max), - gen(std::chrono::system_clock::now().time_since_epoch().count()) + : dist(min, max), gen(detail::get_process_generator()) {} Uniform() : Uniform(0, 1) {} - Real get() { return dist(gen); } + Real get() { return dist(*gen); } private: - std::default_random_engine gen; + detail::ProcessGeneratorPtr gen; std::uniform_real_distribution dist; }; @@ -37,21 +61,20 @@ template struct Normal { Normal(Real mean, Real stdev) - : dist(mean, stdev), - gen(std::chrono::system_clock::now().time_since_epoch().count()) + : dist(mean, stdev), gen(detail::get_process_generator()) {} Normal() : Normal(0, 1) {} - Real get() { return dist(gen); } + Real get() { return dist(*gen); } // FIXME remove me, or make standalone func Real get(Real mean, Real stdev) { // should be possible to pass params to existing dist - return std::normal_distribution(mean, stdev)(gen); + return std::normal_distribution(mean, stdev)(*gen); } private: - std::default_random_engine gen; + detail::ProcessGeneratorPtr gen; std::normal_distribution dist; };