diff --git a/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.hpp b/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.hpp index 8444d654b..d6d77c15d 100644 --- a/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.hpp +++ b/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.hpp @@ -183,10 +183,16 @@ class flystel_prime_field_gadget : public gadget> void generate_r1cs_witness(); }; -// get the MDS matrix from the number of columns 2,3 or 4 -template -std::array, NumStateColumns_L> -anemoi_permutation_mds(const FieldT g); +// get the MDS matrix for each allowed dimension: 2,3 or 4 +template +std::array, 2>, 2> anemoi_permutation_mds_2x2( + const libff::Fr g); +template +std::array, 3>, 3> anemoi_permutation_mds_3x3( + const libff::Fr g); +template +std::array, 4>, 4> anemoi_permutation_mds_4x4( + const libff::Fr g); } // namespace libsnark diff --git a/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.tcc b/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.tcc index f2ee272ad..0ec31f8cd 100644 --- a/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.tcc +++ b/libsnark/gadgetlib1/gadgets/hashes/anemoi/anemoi_components.tcc @@ -320,33 +320,38 @@ void flystel_prime_field_gadget::generate_r1cs_witness() this->pb.lc_val(output_y1) = input_x1_value - this->pb.val(a1); } -template -std::array, NumStateColumns_L> -anemoi_permutation_mds(const FieldT g) +template +std::array, 2>, 2> anemoi_permutation_mds_2x2( + const libff::Fr g) +{ + using FieldT = libff::Fr; + const FieldT g2 = g * g; + std::array, 2> M = {{{1, g}, {g, g2 + 1}}}; + return M; +} + +template +std::array, 3>, 3> anemoi_permutation_mds_3x3( + const libff::Fr g) { - static_assert( - (NumStateColumns_L == 2) || (NumStateColumns_L == 3) || - (NumStateColumns_L == 4), - "NumStateColumns_L must be 2,3 or 4"); + using FieldT = libff::Fr; + std::array, 3> M = { + {{g + 1, 1, g + 1}, {1, 1, g}, {g, 1, 1}}}; + return M; +} - std::array, NumStateColumns_L> M; +template +std::array, 4>, 4> anemoi_permutation_mds_4x4( + const libff::Fr g) +{ + using FieldT = libff::Fr; const FieldT g2 = g * g; - if (NumStateColumns_L == 2) { - M = {{1, g}, {g, g2 + 1}}; - return M; - } - if (NumStateColumns_L == 3) { - M = {{g + 1, 1, g + 1}, {1, 1, g}, {g, 1, 1}}; - return M; - } - if (NumStateColumns_L == 4) { - M = { - {1, 1 + g, g, g}, - {g2, g + g2, 1 + g, 1 + 2 * g}, - {g2, g2, 1, 1 + g}, - {1 + g, 1 + 2 * g, g, 1 + g}}; - return M; - } + std::array, 4> M = { + {{1, g + 1, g, g}, + {g2, g + g2, g + 1, g + g + 1}, + {g2, g2, 1, g + 1}, + {g + 1, g + g + 1, g, g + 1}}}; + return M; } } // namespace libsnark diff --git a/libsnark/gadgetlib1/gadgets/hashes/anemoi/tests/test_anemoi_gadget.cpp b/libsnark/gadgetlib1/gadgets/hashes/anemoi/tests/test_anemoi_gadget.cpp index 749e95f2a..1a87a890f 100644 --- a/libsnark/gadgetlib1/gadgets/hashes/anemoi/tests/test_anemoi_gadget.cpp +++ b/libsnark/gadgetlib1/gadgets/hashes/anemoi/tests/test_anemoi_gadget.cpp @@ -6,6 +6,7 @@ * @copyright MIT license (see LICENSE file) *****************************************************************************/ +#include #include #include #include @@ -224,6 +225,34 @@ void test_flystel_prime_field_gadget() libff::print_time("flystel_prime_field_gadget tests successful"); } +template>> +void test_anemoi_permutation_mds() +{ + using FieldT = libff::Fr; + const FieldT g = anemoi_parameters::multiplicative_generator_g; + { + std::array, 2> M_expect = {{{1, 7}, {7, 50}}}; + std::array, 2> M = + anemoi_permutation_mds_2x2(g); + ASSERT_EQ(M, M_expect); + } + { + std::array, 3> M_expect = { + {{8, 1, 8}, {1, 1, 7}, {7, 1, 1}}}; + std::array, 3> M = + anemoi_permutation_mds_3x3(g); + ASSERT_EQ(M, M_expect); + } + { + std::array, 4> M_expect = { + {{1, 8, 7, 7}, {49, 56, 8, 15}, {49, 49, 1, 8}, {8, 15, 7, 8}}}; + std::array, 4> M = + anemoi_permutation_mds_4x4(g); + ASSERT_EQ(M, M_expect); + } + libff::print_time("anemoi_permutation_mds tests successful"); +} + template void test_for_curve() { // Execute all tests for the given curve. @@ -236,6 +265,7 @@ template void test_for_curve() test_flystel_E_power_five_gadget(); test_flystel_E_root_five_gadget(); test_flystel_prime_field_gadget(); + test_anemoi_permutation_mds(); } TEST(TestAnemoiGadget, BLS12_381) { test_for_curve(); }