fix: don't split inner products that are already on device memory (PROOF-923) #206

Merged · 3 commits · Dec 13, 2024
6 changes: 6 additions & 0 deletions sxt/scalar25/operation/inner_product.cc
@@ -74,8 +74,14 @@ xena::future<s25t::element> async_inner_product_impl(basct::cspan<s25t::element>
                                                       basct::cspan<s25t::element> rhs,
                                                       size_t split_factor, size_t min_chunk_size,
                                                       size_t max_chunk_size) noexcept {
+  SXT_DEBUG_ASSERT(
+      (basdv::is_host_pointer(lhs.data()) && basdv::is_host_pointer(rhs.data())) ||
+      (basdv::is_active_device_pointer(lhs.data()) && basdv::is_active_device_pointer(rhs.data())));
   auto n = std::min(lhs.size(), rhs.size());
   SXT_DEBUG_ASSERT(n > 0);
+  if (basdv::is_active_device_pointer(lhs.data())) {
+    co_return co_await async_inner_product_partial(lhs.subspan(0, n), rhs.subspan(0, n));
+  }
   s25t::element res = s25t::element::identity();
 
   basit::split_options split_options{
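For context, a minimal sketch of how a caller whose spans already live in device memory would hit the new early-return path. This is not part of the PR: the wrapper name device_inner_product is hypothetical, it assumes async_inner_product_impl is visible to the caller (as it is in the test file below), and it only uses identifiers that appear in this diff.

// Sketch only: both spans point at active device memory, so the branch added
// above forwards to async_inner_product_partial over the common prefix and the
// split_factor/chunk-size arguments never come into play.
xena::future<s25t::element>
device_inner_product(basct::cspan<s25t::element> lhs_dev,
                     basct::cspan<s25t::element> rhs_dev) noexcept {
  return async_inner_product_impl(lhs_dev, rhs_dev, /*split_factor=*/4,
                                  /*min_chunk_size=*/1, /*max_chunk_size=*/10);
}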
14 changes: 5 additions & 9 deletions sxt/scalar25/operation/inner_product.t.cc
@@ -121,24 +121,20 @@ TEST_CASE("we can compute inner products asynchronously on the GPU") {
     REQUIRE(res.value() == expected_res);
   }
 
-  SECTION("async inner product works with both device and host points") {
+  SECTION("we can split a GPU inner product into smaller chunks") {
     size_t n = 100;
     make_dataset(a_host, b_host, a_dev, b_dev, rng, n);
-    auto res1 = async_inner_product(a_dev, b_host);
-    auto res2 = async_inner_product(a_host, b_dev);
-    auto res3 = async_inner_product(a_host, b_host);
+    auto res = async_inner_product_impl(a_host, b_host, 4, 1, 10);
     s25t::element expected_res;
     inner_product(expected_res, a_host, b_host);
     xens::get_scheduler().run();
-    REQUIRE(res1.value() == expected_res);
-    REQUIRE(res2.value() == expected_res);
-    REQUIRE(res3.value() == expected_res);
+    REQUIRE(res.value() == expected_res);
   }
 
-  SECTION("we can split a GPU inner product into smaller chunks") {
+  SECTION("we don't split when inputs are already in device memory") {
     size_t n = 100;
     make_dataset(a_host, b_host, a_dev, b_dev, rng, n);
-    auto res = async_inner_product_impl(a_dev, b_host, 4, 1, 10);
+    auto res = async_inner_product_impl(a_dev, b_dev, 4, 1, 10);
     s25t::element expected_res;
     inner_product(expected_res, a_host, b_host);
     xens::get_scheduler().run();