Skip to content

Commit

Permalink
rng: use process-global generators
Browse files Browse the repository at this point in the history
  • Loading branch information
James committed Jul 10, 2024
1 parent 7a62436 commit 2a6269c
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions src/include/rng.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,51 @@
#include <random>
#include <chrono>
#include <functional>
#include <memory>
#include "mpi.h"

namespace rng
{

namespace detail
{
using Generator = std::default_random_engine;
using ProcessGeneratorPtr = std::shared_ptr<Generator>;
ProcessGeneratorPtr process_generator;

int get_process_seed()
{
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return 10000 * (rank + 1);
}

ProcessGeneratorPtr get_process_generator()
{
if (!process_generator) {
process_generator = std::make_shared<Generator>(get_process_seed());
}
return process_generator;
}

} // namespace detail

// ======================================================================
// Uniform

template <typename Real>
struct Uniform
{
Uniform(Real min, Real max)
: dist(min, max),
gen(std::chrono::system_clock::now().time_since_epoch().count())
: dist(min, max), gen(detail::get_process_generator())
{}

Uniform() : Uniform(0, 1) {}

Real get() { return dist(gen); }
Real get() { return dist(*gen); }

private:
std::default_random_engine gen;
detail::ProcessGeneratorPtr gen;
std::uniform_real_distribution<Real> dist;
};

Expand All @@ -37,21 +61,20 @@ template <typename Real>
struct Normal
{
Normal(Real mean, Real stdev)
: dist(mean, stdev),
gen(std::chrono::system_clock::now().time_since_epoch().count())
: dist(mean, stdev), gen(detail::get_process_generator())
{}
Normal() : Normal(0, 1) {}

Real get() { return dist(gen); }
Real get() { return dist(*gen); }
// FIXME remove me, or make standalone func
Real get(Real mean, Real stdev)
{
// should be possible to pass params to existing dist
return std::normal_distribution<Real>(mean, stdev)(gen);
return std::normal_distribution<Real>(mean, stdev)(*gen);
}

private:
std::default_random_engine gen;
detail::ProcessGeneratorPtr gen;
std::normal_distribution<Real> dist;
};

Expand Down

0 comments on commit 2a6269c

Please sign in to comment.