Skip to content

Commit

Permalink
fix: fix overflow in variable length multiexponentiation (PROOF-922) (#…
Browse files Browse the repository at this point in the history
…201)

* fix overflow

* add tests

* tweak

* fix overflow
  • Loading branch information
rnburn authored Nov 5, 2024
1 parent 559c83e commit 8e35da9
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 4 deletions.
4 changes: 2 additions & 2 deletions sxt/cbindings/backend/computational_backend_utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ basct::cspan<uint8_t> make_scalars_span(const uint8_t* data,
auto num_outputs = output_bit_table.size();
SXT_DEBUG_ASSERT(output_lengths.size() == num_outputs);

unsigned output_bit_sum = 0;
size_t output_bit_sum = 0;
unsigned n = 0;
unsigned prev_len = 0;
for (unsigned output_index = 0; output_index < num_outputs; ++output_index) {
Expand All @@ -45,7 +45,7 @@ basct::cspan<uint8_t> make_scalars_span(const uint8_t* data,
prev_len = len;
}

auto output_num_bytes = basn::divide_up(output_bit_sum, 8u);
auto output_num_bytes = basn::divide_up<size_t>(output_bit_sum, 8u);
return basct::cspan<uint8_t>{data, output_num_bytes * n};
}
} // namespace sxt::cbnbck
11 changes: 11 additions & 0 deletions sxt/cbindings/backend/computational_backend_utility.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,15 @@ TEST_CASE("we can make a span for the referenced scalars") {
REQUIRE(span.size() == 4);
REQUIRE(span.data() == data);
}

SECTION("we handle values that would overflow a 32-bit integer") {
output_bit_table = {1, 1};
output_lengths = {
4'294'967'295u,
4'294'967'295u,
};
auto span = make_scalars_span(data, output_bit_table, output_lengths);
REQUIRE(span.size() == 4'294'967'295ul);
REQUIRE(span.data() == data);
}
}
8 changes: 8 additions & 0 deletions sxt/multiexp/pippenger2/variable_length_computation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "sxt/multiexp/pippenger2/variable_length_computation.h"

#include <algorithm>
#include <numeric>

#include "sxt/base/error/assert.h"

Expand Down Expand Up @@ -59,4 +60,11 @@ void compute_product_length_table(basct::span<unsigned>& product_lengths,
}
product_lengths = product_lengths.subspan(0, product_index);
}

//--------------------------------------------------------------------------------------------------
// count_products
//--------------------------------------------------------------------------------------------------
size_t count_products(basct::cspan<unsigned> output_bit_table) noexcept {
return std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0ull);
}
} // namespace sxt::mtxpp2
5 changes: 5 additions & 0 deletions sxt/multiexp/pippenger2/variable_length_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@ void compute_product_length_table(basct::span<unsigned>& product_lengths,
basct::cspan<unsigned> bit_widths,
basct::cspan<unsigned> output_lengths, unsigned first,
unsigned length) noexcept;

//--------------------------------------------------------------------------------------------------
// count_products
//--------------------------------------------------------------------------------------------------
size_t count_products(basct::cspan<unsigned> output_bit_table) noexcept;
} // namespace sxt::mtxpp2
17 changes: 17 additions & 0 deletions sxt/multiexp/pippenger2/variable_length_computation.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,20 @@ TEST_CASE("we can fill in the table of product lengths") {
REQUIRE(product_lengths[1] == 5);
}
}

TEST_CASE("we can count the number of products") {
std::vector<unsigned> output_bit_table;

SECTION("we can count a single output") {
output_bit_table = {123};
REQUIRE(count_products(output_bit_table) == 123);
}

SECTION("we can count entries that would overflow a 32-bit integer") {
output_bit_table = {
4'294'967'295u,
4'294'967'295u,
};
REQUIRE(count_products(output_bit_table) == 8'589'934'590ul);
}
}
4 changes: 2 additions & 2 deletions sxt/multiexp/pippenger2/variable_length_multiexponentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ multiexponentiate_impl(basct::span<T> res, const partition_table_accessor<U>& ac
basct::cspan<unsigned> output_lengths, basct::cspan<uint8_t> scalars,
const multiexponentiate_options& options) noexcept {
auto num_outputs = res.size();
auto num_products = std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0u);
auto num_products = count_products(output_bit_table);
auto num_output_bytes = basn::divide_up<size_t>(num_products, 8);
if (num_outputs == 0) {
co_return;
Expand Down Expand Up @@ -224,7 +224,7 @@ void multiexponentiate(basct::span<T> res, const partition_table_accessor<U>& ac
basct::cspan<unsigned> output_lengths,
basct::cspan<uint8_t> scalars) noexcept {
auto num_outputs = res.size();
auto num_products = std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0u);
auto num_products = count_products(output_bit_table);
auto num_output_bytes = basn::divide_up<size_t>(num_products, 8);
if (num_outputs == 0) {
return;
Expand Down

0 comments on commit 8e35da9

Please sign in to comment.