Skip to content

Commit

Permalink
add polynomial mappers
Browse files Browse the repository at this point in the history
  • Loading branch information
rnburn committed Nov 1, 2024
1 parent f32d350 commit 85d3376
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 0 deletions.
36 changes: 36 additions & 0 deletions sxt/proof/sumcheck/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,42 @@ sxt_cc_component(
],
)

sxt_cc_component(
name = "partial_polynomial_mapper",
test_deps = [
"//sxt/algorithm/base:mapper",
"//sxt/base/test:unit_test",
"//sxt/scalar25/operation:overload",
],
deps = [
":polynomial_utility",
"//sxt/base/macro:cuda_callable",
"//sxt/scalar25/operation:add",
"//sxt/scalar25/operation:mul",
"//sxt/scalar25/operation:muladd",
"//sxt/scalar25/type:element",
"//sxt/scalar25/type:literal",
],
)

sxt_cc_component(
name = "polynomial_mapper",
test_deps = [
"//sxt/algorithm/base:mapper",
"//sxt/base/test:unit_test",
"//sxt/scalar25/operation:overload",
],
deps = [
":polynomial_utility",
"//sxt/base/macro:cuda_callable",
"//sxt/scalar25/operation:add",
"//sxt/scalar25/operation:mul",
"//sxt/scalar25/operation:muladd",
"//sxt/scalar25/type:element",
"//sxt/scalar25/type:literal",
],
)

sxt_cc_component(
name = "polynomial_utility",
impl_deps = [
Expand Down
17 changes: 17 additions & 0 deletions sxt/proof/sumcheck/partial_polynomial_mapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/partial_polynomial_mapper.h"
76 changes: 76 additions & 0 deletions sxt/proof/sumcheck/partial_polynomial_mapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <array>
#include <cassert>
#include <utility>

#include "sxt/base/macro/cuda_callable.h"
#include "sxt/proof/sumcheck/polynomial_utility.h"
#include "sxt/scalar25/operation/add.h"
#include "sxt/scalar25/operation/mul.h"
#include "sxt/scalar25/operation/muladd.h"
#include "sxt/scalar25/type/element.h"

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// partial_polynomial_mapper
//--------------------------------------------------------------------------------------------------
template <unsigned MaxDegree> struct partial_polynomial_mapper {
using value_type = std::array<s25t::element, MaxDegree + 1u>;

CUDA_CALLABLE
value_type map_index(unsigned index) const noexcept {
value_type res;
this->map_index(res, index);
return res;
}

CUDA_CALLABLE
void map_index(value_type& p, unsigned index) const noexcept {
auto mle_data = mles + index;
auto terms_data = product_terms;
s25t::element prod[MaxDegree + 1u];

// first iteration
assert(num_products > 0);
auto [mult, num_terms] = product_table[0];
partial_expand_products({prod, num_terms + 1u}, mle_data, n, {terms_data, num_terms});
terms_data += num_terms;
for (unsigned i = 0; i < num_terms + 1; ++i) {
s25o::mul(p[i], mult, prod[i]);
}

// remaining iterations
for (unsigned product_index = 1; product_index < num_products; ++product_index) {
auto [mult, num_terms] = product_table[product_index];
partial_expand_products({prod, num_terms + 1u}, mle_data, n, {terms_data, num_terms});
terms_data += num_terms;
for (unsigned i = 0; i < num_terms + 1; ++i) {
s25o::muladd(p[i], mult, prod[i], p[i]);
}
}
}

const s25t::element* __restrict__ mles;
const std::pair<s25t::element, unsigned>* __restrict__ product_table;
const unsigned* __restrict__ product_terms;
unsigned num_products;
unsigned n;
};
} // namespace sxt::prfsk
51 changes: 51 additions & 0 deletions sxt/proof/sumcheck/partial_polynomial_mapper.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/partial_polynomial_mapper.h"

#include <vector>

#include "sxt/algorithm/base/mapper.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/scalar25/operation/overload.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using namespace sxt::prfsk;
using s25t::operator""_s25;

TEST_CASE("we can map an index to expanded MLE products") {
REQUIRE(algb::mapper<partial_polynomial_mapper<2>>);

std::vector<s25t::element> mles = {0x123_s25, 0x456_s25};
std::vector<std::pair<s25t::element, unsigned>> product_table = {
{0x1_s25, 1},
};
std::vector<unsigned> product_terms = {0};

SECTION("we can map an index to an expanded polynomial") {
partial_polynomial_mapper<1> mapper{
.mles = mles.data(),
.product_table = product_table.data(),
.product_terms = product_terms.data(),
.num_products = 1,
.n = 2,
};
auto p = mapper.map_index(0);
REQUIRE(p[0] == mles[0]);
REQUIRE(p[1] == -mles[0]);
}
}
17 changes: 17 additions & 0 deletions sxt/proof/sumcheck/polynomial_mapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/polynomial_mapper.h"
77 changes: 77 additions & 0 deletions sxt/proof/sumcheck/polynomial_mapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <array>
#include <cassert>
#include <utility>

#include "sxt/base/macro/cuda_callable.h"
#include "sxt/proof/sumcheck/polynomial_utility.h"
#include "sxt/scalar25/operation/add.h"
#include "sxt/scalar25/operation/mul.h"
#include "sxt/scalar25/operation/muladd.h"
#include "sxt/scalar25/type/element.h"

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// polynomial_mapper
//--------------------------------------------------------------------------------------------------
template <unsigned MaxDegree> struct polynomial_mapper {
using value_type = std::array<s25t::element, MaxDegree + 1u>;

CUDA_CALLABLE
value_type map_index(unsigned index) const noexcept {
value_type res;
this->map_index(res, index);
return res;
}

CUDA_CALLABLE
void map_index(value_type& p, unsigned index) const noexcept {
auto mle_data = mles + index;
auto terms_data = product_terms;
s25t::element prod[MaxDegree + 1u];

// first iteration
assert(num_products > 0);
auto [mult, num_terms] = product_table[0];
expand_products({prod, num_terms + 1u}, mle_data, n, mid, {terms_data, num_terms});
terms_data += num_terms;
for (unsigned i = 0; i < num_terms + 1; ++i) {
s25o::mul(p[i], mult, prod[i]);
}

// remaining iterations
for (unsigned product_index = 1; product_index < num_products; ++product_index) {
auto [mult, num_terms] = product_table[product_index];
expand_products({prod, num_terms + 1u}, mle_data, n, mid, {terms_data, num_terms});
terms_data += num_terms;
for (unsigned i = 0; i < num_terms + 1; ++i) {
s25o::muladd(p[i], mult, prod[i], p[i]);
}
}
}

const s25t::element* __restrict__ mles;
const std::pair<s25t::element, unsigned>* __restrict__ product_table;
const unsigned* __restrict__ product_terms;
unsigned num_products;
unsigned mid;
unsigned n;
};
} // namespace sxt::prfsk
52 changes: 52 additions & 0 deletions sxt/proof/sumcheck/polynomial_mapper.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/polynomial_mapper.h"

#include <vector>

#include "sxt/algorithm/base/mapper.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/scalar25/operation/overload.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using namespace sxt::prfsk;
using s25t::operator""_s25;

TEST_CASE("we can map an index to expanded MLE products") {
REQUIRE(algb::mapper<polynomial_mapper<2>>);

std::vector<s25t::element> mles = {0x123_s25, 0x456_s25};
std::vector<std::pair<s25t::element, unsigned>> product_table = {
{0x1_s25, 1},
};
std::vector<unsigned> product_terms = {0};

SECTION("we can map an index to an expanded polynomial") {
polynomial_mapper<1> mapper{
.mles = mles.data(),
.product_table = product_table.data(),
.product_terms = product_terms.data(),
.num_products = 1,
.mid = 1,
.n = 2,
};
auto p = mapper.map_index(0);
REQUIRE(p[0] == mles[0]);
REQUIRE(p[1] == mles[1] - mles[0]);
}
}

0 comments on commit 85d3376

Please sign in to comment.