From 0f1bb23bb1c89babcfe5aa37e80d99d0983301b9 Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 31 Oct 2024 19:56:47 -0700 Subject: [PATCH] fill in gpu driver --- sxt/algorithm/iteration/for_each.h | 3 + sxt/proof/sumcheck/BUILD | 1 + sxt/proof/sumcheck/gpu_driver.cc | 73 ++++++++++++++++++----- sxt/proof/sumcheck/proof_computation.t.cc | 4 +- 4 files changed, 63 insertions(+), 18 deletions(-) diff --git a/sxt/algorithm/iteration/for_each.h b/sxt/algorithm/iteration/for_each.h index 375a7043..0520f0fd 100644 --- a/sxt/algorithm/iteration/for_each.h +++ b/sxt/algorithm/iteration/for_each.h @@ -58,6 +58,9 @@ void launch_for_each_kernel(bast::raw_stream_t stream, F f, unsigned n) noexcept //-------------------------------------------------------------------------------------------------- template xena::future<> for_each(basdv::stream&& stream, F f, unsigned n) noexcept { + if (n == 0) { + return xena::make_ready_future(); + } launch_for_each_kernel(stream, f, n); return xendv::await_and_own_stream(std::move(stream)); } diff --git a/sxt/proof/sumcheck/BUILD b/sxt/proof/sumcheck/BUILD index eb29b94c..632bdb55 100644 --- a/sxt/proof/sumcheck/BUILD +++ b/sxt/proof/sumcheck/BUILD @@ -63,6 +63,7 @@ sxt_cc_component( sxt_cc_component( name = "gpu_driver", impl_deps = [ + ":partial_polynomial_mapper", ":polynomial_mapper", ":polynomial_reducer", ":polynomial_utility", diff --git a/sxt/proof/sumcheck/gpu_driver.cc b/sxt/proof/sumcheck/gpu_driver.cc index aeb79d52..ba574020 100644 --- a/sxt/proof/sumcheck/gpu_driver.cc +++ b/sxt/proof/sumcheck/gpu_driver.cc @@ -34,6 +34,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" +#include "sxt/proof/sumcheck/partial_polynomial_mapper.h" #include "sxt/proof/sumcheck/polynomial_mapper.h" #include "sxt/proof/sumcheck/polynomial_reducer.h" #include "sxt/proof/sumcheck/polynomial_utility.h" @@ -49,8 +50,6 @@ namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- namespace { struct gpu_workspace final : public workspace { - /* basdv::stream stream; */ - /* memr::async_device_resource resource; */ memmg::managed_array mles; memmg::managed_array> product_table; memmg::managed_array product_terms; @@ -136,9 +135,19 @@ xena::future<> gpu_driver::sum(basct::span polynomial, polynomial.size() - 1u <= max_degree_v // clang-format on ); + for (auto& pi : polynomial) { + pi = {}; + } xena::future<> res; + auto n1 = n - mid; + + // sum full terms auto f = [&](std::integral_constant) noexcept { + if (n1 == 0) { + res = xena::make_ready_future(); + return; + } polynomial_mapper mapper{ .mles = work.mles.data(), .product_table = work.product_table.data(), @@ -147,20 +156,38 @@ xena::future<> gpu_driver::sum(basct::span polynomial, .mid = mid, .n = n, }; - auto fut = algr::reduce>(basdv::stream{}, mapper, mid); + auto fut = algr::reduce>(basdv::stream{}, mapper, n1); res = fut.then([&](std::array p) noexcept { for (unsigned i = 0; i < p.size(); ++i) { - polynomial[i] = p[i]; + s25o::add(polynomial[i], polynomial[i], p[i]); } }); }; - std::println("summing {}", n); - auto t1 = std::chrono::steady_clock::now(); basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, f); co_await std::move(res); - auto t2 = std::chrono::steady_clock::now(); - auto elapse = std::chrono::duration_cast(t2 - t1); - std::println("done summing {}: {}", n, elapse.count() / 1.0e6); + + // sum partial terms + auto fp = [&](std::integral_constant) noexcept { + if (n1 == mid) { + res = xena::make_ready_future(); + return; + } + partial_polynomial_mapper mapper{ + .mles = work.mles.data() + n1, + .product_table = work.product_table.data(), + .product_terms = work.product_terms.data(), + .num_products = static_cast(work.product_table.size()), + .n = n, + }; + auto fut = algr::reduce>(basdv::stream{}, mapper, mid - n1); + res = fut.then([&](std::array p) noexcept { + for (unsigned i = 0; i < p.size(); ++i) { + s25o::add(polynomial[i], polynomial[i], p[i]); + } + }); + }; + basn::constexpr_switch<1u, max_degree_v + 1u>(polynomial.size() - 1u, fp); + co_await std::move(res); } //-------------------------------------------------------------------------------------------------- @@ -174,7 +201,7 @@ xena::future<> gpu_driver::fold(workspace& ws, const s25t::element& r) const noe auto num_mles = work.mles.size() / n; SXT_RELEASE_ASSERT( // clang-format off - work.n > mid && work.mles.size() % n == 0 + work.n >= mid && work.mles.size() % n == 0 // clang-format on ); @@ -202,14 +229,28 @@ xena::future<> gpu_driver::fold(workspace& ws, const s25t::element& r) const noe data[i] = val; } }; - std::println("folding {}", n); - auto t1 = std::chrono::steady_clock::now(); co_await algi::for_each(f1, n1); - auto t2 = std::chrono::steady_clock::now(); - auto elapse = std::chrono::duration_cast(t2 - t1); - std::println("done folding {}: {}", n, elapse.count() / 1.0e6); - SXT_RELEASE_ASSERT(n1 == mid, "not implemented yet"); + // f2 + auto f2 = [ + // clang-format off + mles = work.mles.data() + n1, + n = n, + num_mles = num_mles, + one_m_r = one_m_r + // clang-format on + ] __device__ + __host__(unsigned /*n1*/, unsigned i) noexcept { + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + auto data = mles + n * mle_index; + auto val = data[i]; + s25o::mul(val, val, one_m_r); + data[i] = val; + } + }; + if (n1 != mid) { + co_await algi::for_each(f2, mid - n1); + } work.n = mid; --work.num_variables; diff --git a/sxt/proof/sumcheck/proof_computation.t.cc b/sxt/proof/sumcheck/proof_computation.t.cc index 863939e6..633c2f9d 100644 --- a/sxt/proof/sumcheck/proof_computation.t.cc +++ b/sxt/proof/sumcheck/proof_computation.t.cc @@ -35,8 +35,8 @@ using s25t::operator""_s25; TEST_CASE("we can create a sumcheck proof") { prft::transcript transcript{"abc"}; - cpu_driver drv; - /* gpu_driver drv; */ + /* cpu_driver drv; */ + gpu_driver drv; std::vector polynomials(2); std::vector evaluation_point(1); std::vector mles = {