From d91503fa5d7853eed70d6925dba35ec40a7ea96a Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Sun, 28 Jan 2018 18:55:53 +0300 Subject: [PATCH] interleave, nextByDistribution --- jngen.h | 51 +++++++++++++++++++++++++++++++++++++++++++++++++ random.h | 15 +++++++++++++++ sequence_ops.h | 36 ++++++++++++++++++++++++++++++++++ tests/array.cpp | 10 ++++++++++ 4 files changed, 112 insertions(+) diff --git a/jngen.h b/jngen.h index da18f4b..37b02da 100644 --- a/jngen.h +++ b/jngen.h @@ -1725,6 +1725,21 @@ class Random { return choice(container.begin(), container.end()); } + template + size_t nextByDistribution(const std::vector& distribution) { + ensure(!distribution.empty(), "Cannot sample by empty distribution"); + Numeric sum = std::accumulate( + distribution.begin(), distribution.end(), Numeric(0)); + auto x = next(sum); + for (size_t i = 0; i < distribution.size(); ++i) { + if (x < distribution[i]) { + return i; + } + x -= distribution[i]; + } + return distribution.size() - 1; + } + private: template T smallWnext(int w, Args... args) { @@ -3053,6 +3068,7 @@ using namespace jngen::namespace_for_fake_operator_ltlt; #include #include +#include namespace jngen { @@ -3084,10 +3100,45 @@ T choice(std::initializer_list ilist) { return choice(ilist.begin(), ilist.end()); } +namespace detail { + +template +typename Collection2D::value_type interleave(const Collection2D& collection) { + std::vector sizes; + for (const auto& c: collection) { + sizes.push_back(c.size()); + } + size_t size = std::accumulate(sizes.begin(), sizes.end(), 0u); + + typename Collection2D::value_type result; + while (size > 0) { + size_t id = rnd.nextByDistribution(sizes); + result.emplace_back(collection[id][collection[id].size() - sizes[id]]); + --sizes[id]; + + --size; + } + + return result; +} + +} // namespace detail + +template +typename Collection2D::value_type interleave(const Collection2D& collection) { + return detail::interleave(collection); +} + +template +Collection interleave(const std::initializer_list& ilist) { + return detail::interleave>(ilist); +} + } // namespace jngen using jngen::shuffle; using jngen::choice; +using jngen::interleave; #include diff --git a/random.h b/random.h index 71208ef..a712253 100644 --- a/random.h +++ b/random.h @@ -121,6 +121,21 @@ class Random { return choice(container.begin(), container.end()); } + template + size_t nextByDistribution(const std::vector& distribution) { + ensure(!distribution.empty(), "Cannot sample by empty distribution"); + Numeric sum = std::accumulate( + distribution.begin(), distribution.end(), Numeric(0)); + auto x = next(sum); + for (size_t i = 0; i < distribution.size(); ++i) { + if (x < distribution[i]) { + return i; + } + x -= distribution[i]; + } + return distribution.size() - 1; + } + private: template T smallWnext(int w, Args... args) { diff --git a/sequence_ops.h b/sequence_ops.h index e1c7218..8e4e874 100644 --- a/sequence_ops.h +++ b/sequence_ops.h @@ -5,6 +5,7 @@ #include #include +#include namespace jngen { @@ -36,7 +37,42 @@ T choice(std::initializer_list ilist) { return choice(ilist.begin(), ilist.end()); } +namespace detail { + +template +typename Collection2D::value_type interleave(const Collection2D& collection) { + std::vector sizes; + for (const auto& c: collection) { + sizes.push_back(c.size()); + } + size_t size = std::accumulate(sizes.begin(), sizes.end(), 0u); + + typename Collection2D::value_type result; + while (size > 0) { + size_t id = rnd.nextByDistribution(sizes); + result.emplace_back(collection[id][collection[id].size() - sizes[id]]); + --sizes[id]; + + --size; + } + + return result; +} + +} // namespace detail + +template +typename Collection2D::value_type interleave(const Collection2D& collection) { + return detail::interleave(collection); +} + +template +Collection interleave(const std::initializer_list& ilist) { + return detail::interleave>(ilist); +} + } // namespace jngen using jngen::shuffle; using jngen::choice; +using jngen::interleave; diff --git a/tests/array.cpp b/tests/array.cpp index 057b05b..0371aad 100644 --- a/tests/array.cpp +++ b/tests/array.cpp @@ -170,4 +170,14 @@ BOOST_AUTO_TEST_CASE(print_matrix) { BOOST_TEST(out.str() == "0 0\n"); } +BOOST_AUTO_TEST_CASE(interleave) { + auto a = Array::id(3, 1); + auto b = Array::id(3, 11); + auto expected = Array{11, 1, 2, 12, 13, 3}; + + rnd.seed(10); + + BOOST_TEST(jngen::interleave({a, b}) == expected); +} + BOOST_AUTO_TEST_SUITE_END()