From 39d1754f693bf7b1daea7584c630f951507b6422 Mon Sep 17 00:00:00 2001 From: Ryan Burn Date: Fri, 13 Oct 2023 13:17:06 -0700 Subject: [PATCH] feat: support min/max chunk sizes (PROOF-642) (#29) support min/max chunk sizes --- sxt/base/iterator/index_range.cc | 37 +++++++++++++++++++++- sxt/base/iterator/index_range.h | 12 +++++++ sxt/base/iterator/index_range_utility.cc | 2 ++ sxt/base/iterator/index_range_utility.t.cc | 16 ++++++++++ sxt/multiexp/curve/multiexponentiation.h | 28 +++++++++++++--- 5 files changed, 89 insertions(+), 6 deletions(-) diff --git a/sxt/base/iterator/index_range.cc b/sxt/base/iterator/index_range.cc index daa08a9e..1f0fff9e 100644 --- a/sxt/base/iterator/index_range.cc +++ b/sxt/base/iterator/index_range.cc @@ -22,5 +22,40 @@ namespace sxt::basit { //-------------------------------------------------------------------------------------------------- // constructor //-------------------------------------------------------------------------------------------------- -index_range::index_range(size_t a, size_t b) noexcept : a_{a}, b_{b} { SXT_DEBUG_ASSERT(a <= b); } +index_range::index_range(size_t a, size_t b) noexcept + : index_range{a, b, 1, std::numeric_limits::max()} {} + +index_range::index_range(size_t a, size_t b, size_t min_chunk_size, size_t max_chunk_size) noexcept + : a_{a}, b_{b}, min_chunk_size_{min_chunk_size}, max_chunk_size_{max_chunk_size} { + SXT_DEBUG_ASSERT( + // clang-format off + 0 <= a && a <= b && + 0 < min_chunk_size_ && min_chunk_size_ <= max_chunk_size_ + // clang-format on + ); +} + +//-------------------------------------------------------------------------------------------------- +// min_chunk_size +//-------------------------------------------------------------------------------------------------- +index_range index_range::min_chunk_size(size_t val) const noexcept { + return { + a_, + b_, + val, + max_chunk_size_, + }; +} + +//-------------------------------------------------------------------------------------------------- +// max_chunk_size +//-------------------------------------------------------------------------------------------------- +index_range index_range::max_chunk_size(size_t val) const noexcept { + return { + a_, + b_, + min_chunk_size_, + val, + }; +} } // namespace sxt::basit diff --git a/sxt/base/iterator/index_range.h b/sxt/base/iterator/index_range.h index de9cbc3e..7337fefb 100644 --- a/sxt/base/iterator/index_range.h +++ b/sxt/base/iterator/index_range.h @@ -17,6 +17,7 @@ #pragma once #include +#include namespace sxt::basit { //-------------------------------------------------------------------------------------------------- @@ -28,6 +29,8 @@ class index_range { index_range(size_t a, size_t b) noexcept; + index_range(size_t a, size_t b, size_t min_chunk_size, size_t max_chunk_size) noexcept; + size_t a() const noexcept { return a_; } size_t b() const noexcept { return b_; } @@ -35,8 +38,17 @@ class index_range { bool operator==(const index_range&) const noexcept = default; + size_t min_chunk_size() const noexcept { return min_chunk_size_; } + size_t max_chunk_size() const noexcept { return max_chunk_size_; } + + [[nodiscard]] index_range min_chunk_size(size_t val) const noexcept; + + [[nodiscard]] index_range max_chunk_size(size_t val) const noexcept; + private: size_t a_{0}; size_t b_{0}; + size_t min_chunk_size_{1}; + size_t max_chunk_size_{std::numeric_limits::max()}; }; } // namespace sxt::basit diff --git a/sxt/base/iterator/index_range_utility.cc b/sxt/base/iterator/index_range_utility.cc index c2378a32..ec7b4994 100644 --- a/sxt/base/iterator/index_range_utility.cc +++ b/sxt/base/iterator/index_range_utility.cc @@ -32,6 +32,8 @@ std::pair split(const index_range& r SXT_DEBUG_ASSERT(n > 0); auto delta = rng.b() - rng.a(); auto step = std::max(basn::divide_up(delta, n), size_t{1}); + step = std::max(step, rng.min_chunk_size()); + step = std::min(step, rng.max_chunk_size()); index_range_iterator first{index_range{rng.a(), rng.b()}, step}; index_range_iterator last{index_range{rng.b(), rng.b()}, step}; return {first, last}; diff --git a/sxt/base/iterator/index_range_utility.t.cc b/sxt/base/iterator/index_range_utility.t.cc index f45b40d4..1627e342 100644 --- a/sxt/base/iterator/index_range_utility.t.cc +++ b/sxt/base/iterator/index_range_utility.t.cc @@ -60,4 +60,20 @@ TEST_CASE("we can split an index_range") { REQUIRE(*iter++ == index_range{2, 3}); REQUIRE(iter == last); } + + SECTION("we respect the min chunk size") { + auto [iter, last] = split(index_range{0, 4}.min_chunk_size(2), 4); + REQUIRE(std::distance(iter, last) == 2); + REQUIRE(*iter++ == index_range{0, 2}); + REQUIRE(*iter++ == index_range{2, 4}); + REQUIRE(iter == last); + } + + SECTION("we respect the max chunk size") { + auto [iter, last] = split(index_range{0, 4}.max_chunk_size(2), 1); + REQUIRE(std::distance(iter, last) == 2); + REQUIRE(*iter++ == index_range{0, 2}); + REQUIRE(*iter++ == index_range{2, 4}); + REQUIRE(iter == last); + } } diff --git a/sxt/multiexp/curve/multiexponentiation.h b/sxt/multiexp/curve/multiexponentiation.h index 4256a570..17e46ff2 100644 --- a/sxt/multiexp/curve/multiexponentiation.h +++ b/sxt/multiexp/curve/multiexponentiation.h @@ -28,6 +28,7 @@ #include "sxt/base/device/memory_utility.h" #include "sxt/base/device/stream.h" #include "sxt/base/iterator/index_range.h" +#include "sxt/base/num/divide_up.h" #include "sxt/execution/async/coroutine.h" #include "sxt/execution/async/future.h" #include "sxt/execution/device/device_viewable.h" @@ -152,11 +153,28 @@ async_compute_multiexponentiation(basct::cspan generators, or_alls.emplace_back(1, exponent_sequence.element_nbytes); } std::vector> products(num_outputs); - co_await xendv::concurrent_for_each(basit::index_range{0, generators.size()}, - [&](const basit::index_range& rng) noexcept { - return async_compute_multiexponentiation_partial( - or_alls, products, generators, exponents, rng); - }); + + // Pick some reasonable values for min and max chunk size so that + // we don't run out of GPU memory or split computations that are + // too small. + // + // Note: These haven't been informed by much benchmarking. I'm + // sure there are better values. This is just putting in some + // ballpark estimates to get started. + size_t min_chunk_size = 1ull << 10u; + size_t max_chunk_size = 1ull << 20u; + if (num_outputs > 0) { + max_chunk_size = basn::divide_up(max_chunk_size, num_outputs); + min_chunk_size *= num_outputs; + min_chunk_size = std::min(max_chunk_size, min_chunk_size); + } + auto rng = basit::index_range{0, generators.size()} + .min_chunk_size(min_chunk_size) + .max_chunk_size(max_chunk_size); + co_await xendv::concurrent_for_each(rng, [&](const basit::index_range& rng) noexcept { + return async_compute_multiexponentiation_partial(or_alls, products, generators, + exponents, rng); + }); memmg::managed_array res(num_outputs); for (size_t i = 0; i < num_outputs; ++i) { combine_multiproducts({&res[i], 1}, or_alls[i], products[i]);