Skip to content

Commit

Permalink
feat: add a resize method to memmg::managed_array (PROOF-642) (#33)
Browse files Browse the repository at this point in the history
* refactor

* reformat
  • Loading branch information
rnburn authored Oct 18, 2023
1 parent 39d1754 commit 469a304
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 36 deletions.
2 changes: 1 addition & 1 deletion sxt/execution/device/device_viewable.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ event_future<basct::cspan<T>> make_active_device_viewable(memmg::managed_array<T
if (attrs.device == active_device || attrs.kind == basdv::pointer_kind_t::managed) {
return event_future<basct::cspan<T>>{std::move(data)};
}
data_p = memmg::managed_array<T>{data.size(), data_p.get_allocator()};
data_p.resize(data.size());
basdv::stream stream;
basdv::async_memcpy_to_device(data_p.data(), data.data(), sizeof(T) * data.size(), attrs, stream);
basdv::event event;
Expand Down
6 changes: 3 additions & 3 deletions sxt/execution/device/test_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ void add_for_testing(uint64_t* c, bast::raw_stream_t stream, const uint64_t* a,
memmg::managed_array<uint64_t> c_dev{&resource};
auto cp = c;
if (!basdv::is_active_device_pointer(a)) {
a_dev = memmg::managed_array<uint64_t>{static_cast<unsigned>(n), &resource};
a_dev.resize(static_cast<unsigned>(n));
basdv::async_memcpy_host_to_device(a_dev.data(), a, n * sizeof(uint64_t), stream);
a = a_dev.data();
}
if (!basdv::is_active_device_pointer(b)) {
b_dev = memmg::managed_array<uint64_t>{static_cast<unsigned>(n), &resource};
b_dev.resize(static_cast<unsigned>(n));
basdv::async_memcpy_host_to_device(b_dev.data(), b, n * sizeof(uint64_t), stream);
b = b_dev.data();
}
if (!basdv::is_active_device_pointer(c)) {
c_dev = memmg::managed_array<uint64_t>{static_cast<unsigned>(n), &resource};
c_dev.resize(static_cast<unsigned>(n));
cp = c_dev.data();
}
add_impl<<<basn::divide_up(n, 256), 256, 0, stream>>>(cp, a, b, n);
Expand Down
11 changes: 11 additions & 0 deletions sxt/memory/management/managed_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ namespace sxt::memmg {
//--------------------------------------------------------------------------------------------------
// managed_array
//--------------------------------------------------------------------------------------------------
/**
* managed_array is similar to std::pmr::vector except that it doesn't do initialization and
* doesn't have a capacity that's separate from size.
*
* It's intended to be used for device-allocated memory that the host doesn't necessarily have
* access to.
*/
// void
template <> class managed_array<void> {
public:
Expand Down Expand Up @@ -175,6 +182,10 @@ template <class T> class managed_array {

void reset() noexcept { data_.reset(); }

void resize(size_t size_p) noexcept {
*this = memmg::managed_array<T>{size_p, this->get_allocator()};
}

// operator[]
T& operator[](size_t index) noexcept {
SXT_DEBUG_ASSERT(index < this->size());
Expand Down
6 changes: 6 additions & 0 deletions sxt/memory/management/managed_array.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,10 @@ TEST_CASE("managed_array is an allocator-aware container manages an array of "
managed_array<int> expected{1, 2};
REQUIRE(arr == expected);
}

SECTION("we can resize an array") {
managed_array<int> arr{1, 2, 3};
arr.resize(4);
REQUIRE(arr.size() == 4);
}
}
2 changes: 1 addition & 1 deletion sxt/multiexp/curve/naive_multiproduct_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class naive_multiproduct_solver final : public multiproduct_solver<Element> {
if (num_inputs == generators.size()) {
inputs = generators;
} else {
inputs_data = memmg::managed_array<Element>(num_inputs);
inputs_data.resize(num_inputs);
mtxb::filter_generators<Element>(inputs_data, generators, masks);
inputs = inputs_data;
}
Expand Down
2 changes: 1 addition & 1 deletion sxt/multiexp/multiproduct_gpu/multiproduct.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ TEST_CASE("we can compute multiproducts using the GPU") {

SECTION("we can compute products with many terms") {
unsigned n = 1'000;
indexes = memmg::managed_array<unsigned>(n);
indexes.resize(n);
std::iota(indexes.begin(), indexes.end(), 0);
memmg::managed_array<unsigned> product_sizes = {n};
memmg::managed_array<uint64_t> res(product_sizes.size());
Expand Down
10 changes: 2 additions & 8 deletions sxt/multiexp/pippenger/multiproduct_decomposition_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ xena::future<> compute_multiproduct_decomposition(memmg::managed_array<unsigned>
// set up exponents
memmg::managed_array<uint8_t> exponents_data{&resource};
if (!basdv::is_active_device_pointer(exponents.data)) {
exponents_data = memmg::managed_array<uint8_t>{
n * element_num_bytes,
&resource,
};
exponents_data.resize(n * element_num_bytes);
basdv::async_memcpy_host_to_device(exponents_data.data(), exponents.data, n * element_num_bytes,
stream);
exponents.data = exponents_data.data();
Expand All @@ -88,10 +85,7 @@ xena::future<> compute_multiproduct_decomposition(memmg::managed_array<unsigned>
}

// rearrange indexes
indexes = memmg::managed_array<unsigned>{
num_one_bits,
indexes.get_allocator(),
};
indexes.resize(num_one_bits);
SXT_DEBUG_ASSERT(basdv::is_active_device_pointer(indexes.data()));
co_await decompose_exponent_bits(indexes, stream, block_counts, exponents);
}
Expand Down
5 changes: 1 addition & 4 deletions sxt/multiexp/pippenger/multiproduct_decomposition_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,7 @@ xena::future<> count_exponent_bits(memmg::managed_array<unsigned>& block_counts,
auto num_iterations = basn::divide_up(n, num_blocks);
num_blocks = basn::divide_up(n, num_iterations);

block_counts = memmg::managed_array<unsigned>{
num_blocks * element_num_bits,
block_counts.get_allocator(),
};
block_counts.resize(num_blocks * element_num_bits);
SXT_DEBUG_ASSERT(basdv::is_host_pointer(block_counts.data()));

// set up block_counts_dev
Expand Down
2 changes: 1 addition & 1 deletion sxt/multiexp/pippenger/test_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test_driver::compute_multiproduct(mtxi::index_table&& multiproduct_table,
if (num_inputs == generators.size()) {
inputs = generators_p;
} else {
inputs_data = memmg::managed_array<uint64_t>(num_inputs);
inputs_data.resize(num_inputs);
mtxb::filter_generators<uint64_t>(inputs_data, generators_p, masks);
inputs = inputs_data;
}
Expand Down
15 changes: 3 additions & 12 deletions sxt/proof/inner_product/gpu_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,15 @@ gpu_driver::make_workspace(const proof_descriptor& descriptor,
res->round_index = 0;

// a_vector
res->a_vector = memmg::managed_array<s25t::element>{
a_vector.size(),
alloc,
};
res->a_vector.resize(a_vector.size());
basdv::async_copy_host_to_device(res->a_vector, a_vector, stream);

// b_vector
res->b_vector = memmg::managed_array<s25t::element>{
descriptor.b_vector.size(),
alloc,
};
res->b_vector.resize(descriptor.b_vector.size());
basdv::async_copy_host_to_device(res->b_vector, descriptor.b_vector, stream);

// g_vector
res->g_vector = memmg::managed_array<c21t::element_p3>{
descriptor.g_vector.size(),
alloc,
};
res->g_vector.resize(descriptor.g_vector.size());
basdv::async_copy_host_to_device(res->g_vector, descriptor.g_vector, stream);

return xendv::await_and_own_stream(std::move(stream), std::unique_ptr<workspace>{std::move(res)});
Expand Down
2 changes: 1 addition & 1 deletion sxt/scalar25/operation/inner_product.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ xena::future<s25t::element> async_inner_product(basct::cspan<s25t::element> lhs,
auto is_device_rhs = basdv::is_active_device_pointer(rhs.data());
buffer_size = (static_cast<size_t>(!is_device_lhs) + static_cast<size_t>(!is_device_rhs)) * n;
if (buffer_size > 0) {
device_data = memmg::managed_array<s25t::element>{buffer_size, &resource};
device_data.resize(buffer_size);
}
auto data = device_data.data();
if (!is_device_lhs) {
Expand Down
8 changes: 4 additions & 4 deletions sxt/scalar25/operation/inner_product.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ static void make_dataset(memmg::managed_array<s25t::element>& a_host,
memmg::managed_array<s25t::element>& a_dev,
memmg::managed_array<s25t::element>& b_dev,
basn::fast_random_number_generator& rng, size_t n) noexcept {
a_host = memmg::managed_array<s25t::element>(n);
b_host = memmg::managed_array<s25t::element>(n);
a_dev = memmg::managed_array<s25t::element>(n, memr::get_device_resource());
b_dev = memmg::managed_array<s25t::element>(n, memr::get_device_resource());
a_host.resize(n);
b_host.resize(n);
a_dev.resize(n);
b_dev.resize(n);
s25rn::generate_random_elements(a_host, rng);
s25rn::generate_random_elements(b_host, rng);
basdv::memcpy_host_to_device(a_dev.data(), a_host.data(), n * sizeof(s25t::element));
Expand Down

0 comments on commit 469a304

Please sign in to comment.