Skip to content

Commit

Permalink
refactor gpu driver
Browse files Browse the repository at this point in the history
  • Loading branch information
rnburn committed Nov 1, 2024
1 parent ba76004 commit f8a9b66
Showing 1 changed file with 14 additions and 25 deletions.
39 changes: 14 additions & 25 deletions sxt/proof/sumcheck/gpu_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,6 @@ struct gpu_workspace final : public workspace {
unsigned n;
unsigned num_variables;

#if 0
gpu_workspace(basct::cspan<s25t::element> mles_p,
basct::cspan<std::pair<s25t::element, unsigned>> product_table_p,
basct::cspan<unsigned> product_terms_p, unsigned np) noexcept
: resource{stream}, mles{mles_p.size(), &resource},
product_table{product_table_p.size(), &resource},
product_terms{product_terms_p.size(), &resource}, n{np},
num_variables{static_cast<unsigned>(basn::ceil_log2(np))} {
basdv::async_copy_host_to_device(mles, mles_p, stream);
basdv::async_copy_host_to_device(product_table, product_table_p, stream);
basdv::async_copy_host_to_device(product_terms, product_terms_p, stream);
}
#endif

gpu_workspace() noexcept
: mles{memr::get_device_resource()}, product_table{memr::get_device_resource()},
product_terms{memr::get_device_resource()} {}
Expand Down Expand Up @@ -139,13 +125,13 @@ xena::future<> gpu_driver::sum(basct::span<s25t::element> polynomial,
pi = {};
}

xena::future<> res;
auto n1 = n - mid;

// sum full terms
auto f = [&]<unsigned MaxDegree>(std::integral_constant<unsigned, MaxDegree>) noexcept {
xena::future<> fut1;
auto f1 = [&]<unsigned MaxDegree>(std::integral_constant<unsigned, MaxDegree>) noexcept {
if (n1 == 0) {
res = xena::make_ready_future();
fut1 = xena::make_ready_future();
return;
}
polynomial_mapper<MaxDegree> mapper{
Expand All @@ -157,19 +143,19 @@ xena::future<> gpu_driver::sum(basct::span<s25t::element> polynomial,
.n = n,
};
auto fut = algr::reduce<polynomial_reducer<MaxDegree>>(basdv::stream{}, mapper, n1);
res = fut.then([&](std::array<s25t::element, MaxDegree + 1u> p) noexcept {
fut1 = fut.then([&](std::array<s25t::element, MaxDegree + 1u> p) noexcept {
for (unsigned i = 0; i < p.size(); ++i) {
s25o::add(polynomial[i], polynomial[i], p[i]);
}
});
};
basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, f);
co_await std::move(res);
basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, f1);

// sum partial terms
auto fp = [&]<unsigned MaxDegree>(std::integral_constant<unsigned, MaxDegree>) noexcept {
xena::future<> fut2;
auto f2 = [&]<unsigned MaxDegree>(std::integral_constant<unsigned, MaxDegree>) noexcept {
if (n1 == mid) {
res = xena::make_ready_future();
fut2 = xena::make_ready_future();
return;
}
partial_polynomial_mapper<MaxDegree> mapper{
Expand All @@ -180,14 +166,17 @@ xena::future<> gpu_driver::sum(basct::span<s25t::element> polynomial,
.n = n,
};
auto fut = algr::reduce<polynomial_reducer<MaxDegree>>(basdv::stream{}, mapper, mid - n1);
res = fut.then([&](std::array<s25t::element, MaxDegree + 1u> p) noexcept {
fut2 = fut.then([&](std::array<s25t::element, MaxDegree + 1u> p) noexcept {
for (unsigned i = 0; i < p.size(); ++i) {
s25o::add(polynomial[i], polynomial[i], p[i]);
}
});
};
basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, fp);
co_await std::move(res);
basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, f2);

// await results
co_await std::move(fut1);
co_await std::move(fut2);
}

//--------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit f8a9b66

Please sign in to comment.