From 2be9ebf5844c9e2d63b2bb56c8b5bee30c04ab1e Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 29 Oct 2024 19:07:55 -0700 Subject: [PATCH] add sumcheck operations for cpu --- sxt/proof/sumcheck/BUILD | 42 ++++++ sxt/proof/sumcheck/cpu_driver.cc | 159 +++++++++++++++++++++ sxt/proof/sumcheck/cpu_driver.h | 37 +++++ sxt/proof/sumcheck/cpu_driver.t.cc | 116 +++++++++++++++ sxt/proof/sumcheck/driver.cc | 17 +++ sxt/proof/sumcheck/driver.h | 47 ++++++ sxt/proof/sumcheck/polynomial_utility.cc | 38 +++++ sxt/proof/sumcheck/polynomial_utility.h | 7 + sxt/proof/sumcheck/polynomial_utility.t.cc | 11 ++ sxt/proof/sumcheck/workspace.cc | 17 +++ sxt/proof/sumcheck/workspace.h | 27 ++++ 11 files changed, 518 insertions(+) create mode 100644 sxt/proof/sumcheck/cpu_driver.cc create mode 100644 sxt/proof/sumcheck/cpu_driver.h create mode 100644 sxt/proof/sumcheck/cpu_driver.t.cc create mode 100644 sxt/proof/sumcheck/driver.cc create mode 100644 sxt/proof/sumcheck/driver.h create mode 100644 sxt/proof/sumcheck/workspace.cc create mode 100644 sxt/proof/sumcheck/workspace.h diff --git a/sxt/proof/sumcheck/BUILD b/sxt/proof/sumcheck/BUILD index ea5691e0..4617cf4c 100644 --- a/sxt/proof/sumcheck/BUILD +++ b/sxt/proof/sumcheck/BUILD @@ -3,6 +3,48 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "workspace", + with_test = False, +) + +sxt_cc_component( + name = "driver", + with_test = False, + deps = [ + ":workspace", + "//sxt/base/container:span", + "//sxt/execution/async:future_fwd", + ], +) + +sxt_cc_component( + name = "cpu_driver", + impl_deps = [ + ":polynomial_utility", + "//sxt/base/container:stack_array", + "//sxt/base/error:panic", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:future", + "//sxt/memory/management:managed_array", + "//sxt/scalar25/operation:mul", + "//sxt/scalar25/operation:sub", + "//sxt/scalar25/operation:muladd", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", + ], + test_deps = [ + "//sxt/execution/async:future", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", + ], + deps = [ + ":driver", + ":workspace", + ], +) + sxt_cc_component( name = "transcript_utility", impl_deps = [ diff --git a/sxt/proof/sumcheck/cpu_driver.cc b/sxt/proof/sumcheck/cpu_driver.cc new file mode 100644 index 00000000..54a03a22 --- /dev/null +++ b/sxt/proof/sumcheck/cpu_driver.cc @@ -0,0 +1,159 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/proof/sumcheck/cpu_driver.h" + +#include +#include + +#include "sxt/base/container/stack_array.h" +#include "sxt/base/error/panic.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/future.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" +#include "sxt/scalar25/operation/mul.h" +#include "sxt/scalar25/operation/muladd.h" +#include "sxt/scalar25/operation/sub.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// cpu_workspace +//-------------------------------------------------------------------------------------------------- +namespace { +struct cpu_workspace final : public workspace { + memmg::managed_array mles; + basct::cspan> product_table; + basct::cspan product_terms; + unsigned n; + unsigned num_variables; +}; +} // namespace + +//-------------------------------------------------------------------------------------------------- +// make_workspace +//-------------------------------------------------------------------------------------------------- +xena::future> +cpu_driver::make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept { + auto res = std::make_unique(); + res->mles = memmg::managed_array{mles.begin(), mles.end()}; + res->product_table = product_table; + res->product_terms = product_terms; + res->n = n; + res->num_variables = std::max(basn::ceil_log2(n), 1); + return xena::make_ready_future>(std::move(res)); +} + +//-------------------------------------------------------------------------------------------------- +// sum +//-------------------------------------------------------------------------------------------------- +xena::future<> cpu_driver::sum(basct::span polynomial, + workspace& ws) const noexcept { + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + SXT_RELEASE_ASSERT(work.n >= mid); + + auto mles = work.mles.data(); + auto product_table = work.product_table; + auto product_terms = work.product_terms; + + for (auto& val : polynomial) { + val = {}; + } + + // expand paired terms + auto n1 = work.n - mid; + for (unsigned i = 0; i < n1; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + SXT_RELEASE_ASSERT(num_terms < polynomial.size()); + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element); + expand_products(p, mles + i, n, mid, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; + } + } + + // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) + for (unsigned i = n1; i < mid; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element); + partial_expand_products(p, mles + i, n, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; + } + } + + return xena::make_ready_future(); +} + +//-------------------------------------------------------------------------------------------------- +// fold +//-------------------------------------------------------------------------------------------------- +xena::future<> cpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { + using s25t::operator""_s25; + + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + auto mles = work.mles.data(); + s25t::element one_m_r = 0x1_s25; + s25o::sub(one_m_r, one_m_r, r); + auto n1 = work.n - mid; + for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { + auto data = mles + n * mle_index; + + // fold paired terms + for (unsigned i = 0; i < n1; ++i) { + auto val = data[i]; + s25o::mul(val, val, one_m_r); + s25o::muladd(val, r, data[mid + i], val); + data[i] = val; + } + + // fold terms paired with zero + for (unsigned i = n1; i < mid; ++i) { + auto val = data[i]; + s25o::mul(val, val, one_m_r); + data[i] = val; + } + } + + work.n = mid; + --work.num_variables; + work.mles.shrink(num_mles * mid); + return xena::make_ready_future(); +} +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.h b/sxt/proof/sumcheck/cpu_driver.h new file mode 100644 index 00000000..b017396d --- /dev/null +++ b/sxt/proof/sumcheck/cpu_driver.h @@ -0,0 +1,37 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "sxt/proof/sumcheck/driver.h" + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// cpu_driver +//-------------------------------------------------------------------------------------------------- +class cpu_driver final : public driver { +public: + // driver + xena::future> + make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override; + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; + + xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; +}; +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.t.cc b/sxt/proof/sumcheck/cpu_driver.t.cc new file mode 100644 index 00000000..fbb3ce85 --- /dev/null +++ b/sxt/proof/sumcheck/cpu_driver.t.cc @@ -0,0 +1,116 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/proof/sumcheck/cpu_driver.h" + +#include + +#include "sxt/base/test/unit_test.h" +#include "sxt/execution/async/future.h" +#include "sxt/proof/sumcheck/workspace.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk; +using s25t::operator""_s25; + +TEST_CASE("we can perform the primitive operations for sumcheck proofs") { + std::vector mles; + std::vector> product_table{ + {0x1_s25, 1}, + }; + std::vector product_terms = {0}; + + std::vector p(2); + cpu_driver drv; + + SECTION("we can sum a polynomial with n = 1") { + std::vector mles = {0x123_s25}; + auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value(); + auto fut = drv.sum(p, *ws); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == -mles[0]); + } + + SECTION("we can sum a polynomial with a non-unity multiplier") { + std::vector mles = {0x123_s25}; + product_table[0].first = 0x2_s25; + auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value(); + auto fut = drv.sum(p, *ws); + REQUIRE(fut.ready()); + REQUIRE(p[0] == 0x2_s25 * mles[0]); + REQUIRE(p[1] == -0x2_s25 * mles[0]); + } + + SECTION("we can sum a polynomial with n = 2") { + std::vector mles = {0x123_s25, 0x456_s25}; + auto ws = drv.make_workspace(mles, product_table, product_terms, 2).value(); + auto fut = drv.sum(p, *ws); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } + + SECTION("we can sum a polynomial with two MLEs added together") { + std::vector mles = {0x123_s25, 0x456_s25}; + std::vector> product_table{ + {0x1_s25, 1}, + {0x1_s25, 1}, + }; + std::vector product_terms = {0, 1}; + + auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value(); + auto fut = drv.sum(p, *ws); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] + mles[1]); + REQUIRE(p[1] == -mles[0] - mles[1]); + } + + SECTION("we can sum a polynomial with two MLEs multiplied together") { + std::vector mles = {0x123_s25, 0x456_s25}; + std::vector> product_table{ + {0x1_s25, 2}, + }; + std::vector product_terms = {0, 1}; + p.resize(3); + + auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value(); + auto fut = drv.sum(p, *ws); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] * mles[1]); + REQUIRE(p[1] == -mles[0] * mles[1] - mles[1] * mles[0]); + REQUIRE(p[2] == mles[0] * mles[1]); + } + + SECTION("we can fold mles") { + std::vector mles = {0x123_s25, 0x456_s25, 0x789_s25}; + auto ws = drv.make_workspace(mles, product_table, product_terms, 3).value(); + auto r = 0xabc123_s25; + auto fut = drv.fold(*ws, r); + REQUIRE(fut.ready()); + fut = drv.sum(p, *ws); + REQUIRE(fut.ready()); + + mles[0] = (0x1_s25 - r) * mles[0] + r * mles[2]; + mles[1] = (0x1_s25 - r) * mles[1]; + + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } +} diff --git a/sxt/proof/sumcheck/driver.cc b/sxt/proof/sumcheck/driver.cc new file mode 100644 index 00000000..6e46927f --- /dev/null +++ b/sxt/proof/sumcheck/driver.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/proof/sumcheck/driver.h" diff --git a/sxt/proof/sumcheck/driver.h b/sxt/proof/sumcheck/driver.h new file mode 100644 index 00000000..422bb88a --- /dev/null +++ b/sxt/proof/sumcheck/driver.h @@ -0,0 +1,47 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "sxt/base/container/span.h" +#include "sxt/execution/async/future_fwd.h" +#include "sxt/proof/sumcheck/workspace.h" + +namespace sxt::s25t { +class element; +} + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// driver +//-------------------------------------------------------------------------------------------------- +class driver { +public: + virtual ~driver() noexcept = default; + + virtual xena::future> + make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept = 0; + + virtual xena::future<> sum(basct::span polynomial, + workspace& ws) const noexcept = 0; + + virtual xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept = 0; +}; +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.cc b/sxt/proof/sumcheck/polynomial_utility.cc index 8c5ffb2b..4ba3d916 100644 --- a/sxt/proof/sumcheck/polynomial_utility.cc +++ b/sxt/proof/sumcheck/polynomial_utility.cc @@ -21,6 +21,7 @@ #include "sxt/scalar25/operation/add.h" #include "sxt/scalar25/operation/mul.h" #include "sxt/scalar25/operation/muladd.h" +#include "sxt/scalar25/operation/neg.h" #include "sxt/scalar25/operation/sub.h" #include "sxt/scalar25/type/element.h" @@ -96,4 +97,41 @@ void expand_products(basct::span p, const s25t::element* mles, un s25o::mul(p[i + 1u], c_prev, b); } } + +//-------------------------------------------------------------------------------------------------- +// partial_expand_products +//-------------------------------------------------------------------------------------------------- +CUDA_CALLABLE +void partial_expand_products(basct::span p, const s25t::element* mles, unsigned n, + basct::cspan terms) noexcept { + auto num_terms = terms.size(); + assert( + // clang-format off + num_terms > 0 && + p.size() == num_terms + 1u + // clang-format on + ); + s25t::element a, b; + auto mle_index = terms[0]; + a = *(mles + mle_index * n); + s25o::neg(b, a); + p[0] = a; + p[1] = b; + + for (unsigned i = 1; i < num_terms; ++i) { + auto mle_index = terms[i]; + a = *(mles + mle_index * n); + s25o::neg(b, a); + + auto c_prev = p[0]; + s25o::mul(p[0], c_prev, a); + for (unsigned pow = 1u; pow < i + 1u; ++pow) { + auto c = p[pow]; + s25o::mul(p[pow], c, a); + s25o::muladd(p[pow], c_prev, b, p[pow]); + c_prev = c; + } + s25o::mul(p[i + 1u], c_prev, b); + } +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.h b/sxt/proof/sumcheck/polynomial_utility.h index c6c98613..6754c4d7 100644 --- a/sxt/proof/sumcheck/polynomial_utility.h +++ b/sxt/proof/sumcheck/polynomial_utility.h @@ -45,4 +45,11 @@ void evaluate_polynomial(s25t::element& e, basct::cspan polynomia CUDA_CALLABLE void expand_products(basct::span p, const s25t::element* mles, unsigned n, unsigned step, basct::cspan terms) noexcept; + +//-------------------------------------------------------------------------------------------------- +// partial_expand_products +//-------------------------------------------------------------------------------------------------- +CUDA_CALLABLE +void partial_expand_products(basct::span p, const s25t::element* mles, unsigned n, + basct::cspan terms) noexcept; } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.t.cc b/sxt/proof/sumcheck/polynomial_utility.t.cc index b1ee18b9..281e65e3 100644 --- a/sxt/proof/sumcheck/polynomial_utility.t.cc +++ b/sxt/proof/sumcheck/polynomial_utility.t.cc @@ -87,6 +87,17 @@ TEST_CASE("we can expand a product of MLEs") { REQUIRE(p[1] == mles[1] - mles[0]); } + SECTION("we can partially expand MLEs (where some terms are assumed to be zero)") { + mles = {0x123_s25, 0x0_s25}; + p.resize(2); + terms = {0}; + partial_expand_products(p, mles.data(), 1, terms); + + std::vector expected(2); + expand_products(expected, mles.data(), 2, 1, terms); + REQUIRE(p == expected); + } + SECTION("we can expand two MLEs") { p.resize(3); mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25}; diff --git a/sxt/proof/sumcheck/workspace.cc b/sxt/proof/sumcheck/workspace.cc new file mode 100644 index 00000000..d356b4af --- /dev/null +++ b/sxt/proof/sumcheck/workspace.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/proof/sumcheck/workspace.h" diff --git a/sxt/proof/sumcheck/workspace.h b/sxt/proof/sumcheck/workspace.h new file mode 100644 index 00000000..edacbd2b --- /dev/null +++ b/sxt/proof/sumcheck/workspace.h @@ -0,0 +1,27 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// workspace +//-------------------------------------------------------------------------------------------------- +class workspace { +public: + virtual ~workspace() noexcept = default; +}; +} // namespace sxt::prfsk