Skip to content

Commit

Permalink
fill in gpu driver
Browse files Browse the repository at this point in the history
  • Loading branch information
rnburn committed Nov 1, 2024
1 parent 0e3e4b4 commit 0f1bb23
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 18 deletions.
3 changes: 3 additions & 0 deletions sxt/algorithm/iteration/for_each.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void launch_for_each_kernel(bast::raw_stream_t stream, F f, unsigned n) noexcept
//--------------------------------------------------------------------------------------------------
template <algb::index_functor F>
xena::future<> for_each(basdv::stream&& stream, F f, unsigned n) noexcept {
if (n == 0) {
return xena::make_ready_future();
}
launch_for_each_kernel(stream, f, n);
return xendv::await_and_own_stream(std::move(stream));
}
Expand Down
1 change: 1 addition & 0 deletions sxt/proof/sumcheck/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ sxt_cc_component(
sxt_cc_component(
name = "gpu_driver",
impl_deps = [
":partial_polynomial_mapper",
":polynomial_mapper",
":polynomial_reducer",
":polynomial_utility",
Expand Down
73 changes: 57 additions & 16 deletions sxt/proof/sumcheck/gpu_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/device_resource.h"
#include "sxt/proof/sumcheck/partial_polynomial_mapper.h"
#include "sxt/proof/sumcheck/polynomial_mapper.h"
#include "sxt/proof/sumcheck/polynomial_reducer.h"
#include "sxt/proof/sumcheck/polynomial_utility.h"
Expand All @@ -49,8 +50,6 @@ namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
namespace {
struct gpu_workspace final : public workspace {
/* basdv::stream stream; */
/* memr::async_device_resource resource; */
memmg::managed_array<s25t::element> mles;
memmg::managed_array<std::pair<s25t::element, unsigned>> product_table;
memmg::managed_array<unsigned> product_terms;
Expand Down Expand Up @@ -136,9 +135,19 @@ xena::future<> gpu_driver::sum(basct::span<s25t::element> polynomial,
polynomial.size() - 1u <= max_degree_v
// clang-format on
);
for (auto& pi : polynomial) {
pi = {};
}

xena::future<> res;
auto n1 = n - mid;

// sum full terms
auto f = [&]<unsigned MaxDegree>(std::integral_constant<unsigned, MaxDegree>) noexcept {
if (n1 == 0) {
res = xena::make_ready_future();
return;
}
polynomial_mapper<MaxDegree> mapper{
.mles = work.mles.data(),
.product_table = work.product_table.data(),
Expand All @@ -147,20 +156,38 @@ xena::future<> gpu_driver::sum(basct::span<s25t::element> polynomial,
.mid = mid,
.n = n,
};
auto fut = algr::reduce<polynomial_reducer<MaxDegree>>(basdv::stream{}, mapper, mid);
auto fut = algr::reduce<polynomial_reducer<MaxDegree>>(basdv::stream{}, mapper, n1);
res = fut.then([&](std::array<s25t::element, MaxDegree + 1u> p) noexcept {
for (unsigned i = 0; i < p.size(); ++i) {
polynomial[i] = p[i];
s25o::add(polynomial[i], polynomial[i], p[i]);
}
});
};
std::println("summing {}", n);
auto t1 = std::chrono::steady_clock::now();
basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, f);
co_await std::move(res);
auto t2 = std::chrono::steady_clock::now();
auto elapse = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1);
std::println("done summing {}: {}", n, elapse.count() / 1.0e6);

// sum partial terms
auto fp = [&]<unsigned MaxDegree>(std::integral_constant<unsigned, MaxDegree>) noexcept {
if (n1 == mid) {
res = xena::make_ready_future();
return;
}
partial_polynomial_mapper<MaxDegree> mapper{
.mles = work.mles.data() + n1,
.product_table = work.product_table.data(),
.product_terms = work.product_terms.data(),
.num_products = static_cast<unsigned>(work.product_table.size()),
.n = n,
};
auto fut = algr::reduce<polynomial_reducer<MaxDegree>>(basdv::stream{}, mapper, mid - n1);
res = fut.then([&](std::array<s25t::element, MaxDegree + 1u> p) noexcept {
for (unsigned i = 0; i < p.size(); ++i) {
s25o::add(polynomial[i], polynomial[i], p[i]);
}
});
};
basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, fp);
co_await std::move(res);
}

//--------------------------------------------------------------------------------------------------
Expand All @@ -174,7 +201,7 @@ xena::future<> gpu_driver::fold(workspace& ws, const s25t::element& r) const noe
auto num_mles = work.mles.size() / n;
SXT_RELEASE_ASSERT(
// clang-format off
work.n > mid && work.mles.size() % n == 0
work.n >= mid && work.mles.size() % n == 0
// clang-format on
);

Expand Down Expand Up @@ -202,14 +229,28 @@ xena::future<> gpu_driver::fold(workspace& ws, const s25t::element& r) const noe
data[i] = val;
}
};
std::println("folding {}", n);
auto t1 = std::chrono::steady_clock::now();
co_await algi::for_each(f1, n1);
auto t2 = std::chrono::steady_clock::now();
auto elapse = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1);
std::println("done folding {}: {}", n, elapse.count() / 1.0e6);

SXT_RELEASE_ASSERT(n1 == mid, "not implemented yet");
// f2
auto f2 = [
// clang-format off
mles = work.mles.data() + n1,
n = n,
num_mles = num_mles,
one_m_r = one_m_r
// clang-format on
] __device__
__host__(unsigned /*n1*/, unsigned i) noexcept {
for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) {
auto data = mles + n * mle_index;
auto val = data[i];
s25o::mul(val, val, one_m_r);
data[i] = val;
}
};
if (n1 != mid) {
co_await algi::for_each(f2, mid - n1);
}

work.n = mid;
--work.num_variables;
Expand Down
4 changes: 2 additions & 2 deletions sxt/proof/sumcheck/proof_computation.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ using s25t::operator""_s25;

TEST_CASE("we can create a sumcheck proof") {
prft::transcript transcript{"abc"};
cpu_driver drv;
/* gpu_driver drv; */
/* cpu_driver drv; */
gpu_driver drv;
std::vector<s25t::element> polynomials(2);
std::vector<s25t::element> evaluation_point(1);
std::vector<s25t::element> mles = {
Expand Down

0 comments on commit 0f1bb23

Please sign in to comment.