Skip to content

Commit

Permalink
feat: support min/max chunk sizes (PROOF-642) (#29)
Browse files Browse the repository at this point in the history
support min/max chunk sizes
  • Loading branch information
rnburn authored Oct 13, 2023
1 parent 29c294a commit 39d1754
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 6 deletions.
37 changes: 36 additions & 1 deletion sxt/base/iterator/index_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,40 @@ namespace sxt::basit {
//--------------------------------------------------------------------------------------------------
// constructor
//--------------------------------------------------------------------------------------------------
index_range::index_range(size_t a, size_t b) noexcept : a_{a}, b_{b} { SXT_DEBUG_ASSERT(a <= b); }
index_range::index_range(size_t a, size_t b) noexcept
: index_range{a, b, 1, std::numeric_limits<size_t>::max()} {}

index_range::index_range(size_t a, size_t b, size_t min_chunk_size, size_t max_chunk_size) noexcept
: a_{a}, b_{b}, min_chunk_size_{min_chunk_size}, max_chunk_size_{max_chunk_size} {
SXT_DEBUG_ASSERT(
// clang-format off
0 <= a && a <= b &&
0 < min_chunk_size_ && min_chunk_size_ <= max_chunk_size_
// clang-format on
);
}

//--------------------------------------------------------------------------------------------------
// min_chunk_size
//--------------------------------------------------------------------------------------------------
index_range index_range::min_chunk_size(size_t val) const noexcept {
return {
a_,
b_,
val,
max_chunk_size_,
};
}

//--------------------------------------------------------------------------------------------------
// max_chunk_size
//--------------------------------------------------------------------------------------------------
index_range index_range::max_chunk_size(size_t val) const noexcept {
return {
a_,
b_,
min_chunk_size_,
val,
};
}
} // namespace sxt::basit
12 changes: 12 additions & 0 deletions sxt/base/iterator/index_range.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cstddef>
#include <limits>

namespace sxt::basit {
//--------------------------------------------------------------------------------------------------
Expand All @@ -28,15 +29,26 @@ class index_range {

index_range(size_t a, size_t b) noexcept;

index_range(size_t a, size_t b, size_t min_chunk_size, size_t max_chunk_size) noexcept;

size_t a() const noexcept { return a_; }
size_t b() const noexcept { return b_; }

size_t size() const noexcept { return b_ - a_; }

bool operator==(const index_range&) const noexcept = default;

size_t min_chunk_size() const noexcept { return min_chunk_size_; }
size_t max_chunk_size() const noexcept { return max_chunk_size_; }

[[nodiscard]] index_range min_chunk_size(size_t val) const noexcept;

[[nodiscard]] index_range max_chunk_size(size_t val) const noexcept;

private:
size_t a_{0};
size_t b_{0};
size_t min_chunk_size_{1};
size_t max_chunk_size_{std::numeric_limits<size_t>::max()};
};
} // namespace sxt::basit
2 changes: 2 additions & 0 deletions sxt/base/iterator/index_range_utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ std::pair<index_range_iterator, index_range_iterator> split(const index_range& r
SXT_DEBUG_ASSERT(n > 0);
auto delta = rng.b() - rng.a();
auto step = std::max(basn::divide_up(delta, n), size_t{1});
step = std::max(step, rng.min_chunk_size());
step = std::min(step, rng.max_chunk_size());
index_range_iterator first{index_range{rng.a(), rng.b()}, step};
index_range_iterator last{index_range{rng.b(), rng.b()}, step};
return {first, last};
Expand Down
16 changes: 16 additions & 0 deletions sxt/base/iterator/index_range_utility.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,20 @@ TEST_CASE("we can split an index_range") {
REQUIRE(*iter++ == index_range{2, 3});
REQUIRE(iter == last);
}

SECTION("we respect the min chunk size") {
auto [iter, last] = split(index_range{0, 4}.min_chunk_size(2), 4);
REQUIRE(std::distance(iter, last) == 2);
REQUIRE(*iter++ == index_range{0, 2});
REQUIRE(*iter++ == index_range{2, 4});
REQUIRE(iter == last);
}

SECTION("we respect the max chunk size") {
auto [iter, last] = split(index_range{0, 4}.max_chunk_size(2), 1);
REQUIRE(std::distance(iter, last) == 2);
REQUIRE(*iter++ == index_range{0, 2});
REQUIRE(*iter++ == index_range{2, 4});
REQUIRE(iter == last);
}
}
28 changes: 23 additions & 5 deletions sxt/multiexp/curve/multiexponentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/iterator/index_range.h"
#include "sxt/base/num/divide_up.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/async/future.h"
#include "sxt/execution/device/device_viewable.h"
Expand Down Expand Up @@ -152,11 +153,28 @@ async_compute_multiexponentiation(basct::cspan<Element> generators,
or_alls.emplace_back(1, exponent_sequence.element_nbytes);
}
std::vector<memmg::managed_array<Element>> products(num_outputs);
co_await xendv::concurrent_for_each(basit::index_range{0, generators.size()},
[&](const basit::index_range& rng) noexcept {
return async_compute_multiexponentiation_partial<Element>(
or_alls, products, generators, exponents, rng);
});

// Pick some reasonable values for min and max chunk size so that
// we don't run out of GPU memory or split computations that are
// too small.
//
// Note: These haven't been informed by much benchmarking. I'm
// sure there are better values. This is just putting in some
// ballpark estimates to get started.
size_t min_chunk_size = 1ull << 10u;
size_t max_chunk_size = 1ull << 20u;
if (num_outputs > 0) {
max_chunk_size = basn::divide_up(max_chunk_size, num_outputs);
min_chunk_size *= num_outputs;
min_chunk_size = std::min(max_chunk_size, min_chunk_size);
}
auto rng = basit::index_range{0, generators.size()}
.min_chunk_size(min_chunk_size)
.max_chunk_size(max_chunk_size);
co_await xendv::concurrent_for_each(rng, [&](const basit::index_range& rng) noexcept {
return async_compute_multiexponentiation_partial<Element>(or_alls, products, generators,
exponents, rng);
});
memmg::managed_array<Element> res(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
combine_multiproducts<Element>({&res[i], 1}, or_alls[i], products[i]);
Expand Down

0 comments on commit 39d1754

Please sign in to comment.