From 71508ea249552a93bf44040a564c7fb034f90b0d Mon Sep 17 00:00:00 2001 From: MITSUNARI Shigeo Date: Fri, 27 Sep 2024 17:37:41 +0900 Subject: [PATCH 1/2] improve performance of mul/wasm a little by reducing conversion --- Makefile | 2 +- include/mcl/bint.hpp | 18 ++++++++--------- src/bint_impl.hpp | 48 ++++++++++++++++---------------------------- src/low_func.hpp | 15 +++++++++++--- 4 files changed, 39 insertions(+), 44 deletions(-) diff --git a/Makefile b/Makefile index 4053d466..feaeb0fe 100644 --- a/Makefile +++ b/Makefile @@ -443,7 +443,7 @@ endif # test bin/emu: - $(CXX) -g -o $@ src/fp.cpp src/bn_c384_256.cpp test/bn_c384_256_test.cpp -DMCL_DONT_USE_XBYAK -DMCL_SIZEOF_UNIT=$(MCL_SIZEOF_UNIT) -DMCL_MAX_BIT_SIZE=384 -I./include -DMCL_BINT_ASM=0 -DMCL_MSM=0 + $(CXX) -g -o $@ src/fp.cpp src/bn_c384_256.cpp test/bn_c384_256_test.cpp -DMCL_DONT_USE_XBYAK -DMCL_SIZEOF_UNIT=$(MCL_SIZEOF_UNIT) -DMCL_MAX_BIT_SIZE=384 -I./include -DMCL_BINT_ASM=0 -DMCL_MSM=0 $(CFLAGS_USER) bin/pairing_c_min.exe: sample/pairing_c.c include/mcl/vint.hpp src/fp.cpp include/mcl/bn.hpp $(CXX) -std=c++03 -O3 -g -fno-threadsafe-statics -fno-exceptions -fno-rtti -o $@ sample/pairing_c.c src/fp.cpp src/bn_c384_256.cpp -I./include -DXBYAK_NO_EXCEPTION -DMCL_SIZEOF_UNIT=$(MCL_SIZEOF_UNIT) -DMCL_MAX_BIT_SIZE=384 -DCYBOZU_DONT_USE_STRING -DCYBOZU_DONT_USE_EXCEPTION -DNDEBUG -DMCL_BINT_ASM=0 -DMCL_MSM=0 # -DMCL_DONT_USE_CSPRNG bin/ecdsa-emu: diff --git a/include/mcl/bint.hpp b/include/mcl/bint.hpp index 7a601923..a3ba627d 100644 --- a/include/mcl/bint.hpp +++ b/include/mcl/bint.hpp @@ -143,15 +143,15 @@ inline uint64_t divUnit1(uint64_t *pr, uint64_t H, uint64_t L, uint64_t y) // z[N] = x[N] + y[N] and return CF(0 or 1) templateUnit addT(Unit *z, const Unit *x, const Unit *y); // z[N] = x[N] - y[N] and return CF(0 or 1) -templateUnit subT(Unit *z, const Unit *x, const Unit *y); +templateUnit subT(Unit *z, const T *x, const Unit *y); // z[N] = x[N] + y[N]. assume x, y are Not Full bit templatevoid addNFT(Unit *z, const Unit *x, const Unit *y); // z[N] = x[N] - y[N] and return CF(0 or 1). assume x, y are Not Full bit templateUnit subNFT(Unit *z, const Unit *x, const Unit *y); // [ret:z[N]] = x[N] * y -templateUnit mulUnitT(Unit *z, const Unit *x, Unit y); +templateUnit mulUnitT(T *z, const Unit *x, Unit y); // [ret:z[N]] = z[N] + x[N] * y -templateUnit mulUnitAddT(Unit *z, const Unit *x, Unit y); +templateUnit mulUnitAddT(T *z, const Unit *x, Unit y); // z[2N] = x[N] * y[N] templatevoid mulT(Unit *pz, const Unit *px, const Unit *py); // y[2N] = x[N] * x[N] @@ -173,17 +173,17 @@ MCL_DLL_API void mulNM(Unit *z, const Unit *x, size_t xn, const Unit *y, size_t // explicit specialization of template functions and external asm functions #include "bint_proto.hpp" -template -void copyT(T *y, const T *x) +template +void copyT(T *y, const U *x) { - for (size_t i = 0; i < N; i++) y[i] = x[i]; + for (size_t i = 0; i < N; i++) y[i] = T(x[i]); } // y[n] = x[n] -template -void copyN(T *y, const T *x, size_t n) +template +void copyN(T *y, const U *x, size_t n) { - for (size_t i = 0; i < n; i++) y[i] = x[i]; + for (size_t i = 0; i < n; i++) y[i] = T(x[i]); } template diff --git a/src/bint_impl.hpp b/src/bint_impl.hpp index 1145893e..10c07100 100644 --- a/src/bint_impl.hpp +++ b/src/bint_impl.hpp @@ -104,8 +104,8 @@ Unit addT(Unit *z, const Unit *x, const Unit *y) #endif } -template -Unit subT(Unit *z, const Unit *x, const Unit *y) +template +Unit subT(Unit *z, const T *x, const Unit *y) { #if defined(MCL_WASM32) && MCL_SIZEOF_UNIT == 4 // wasm32 supports 64-bit sub @@ -164,30 +164,19 @@ Unit subNFT(Unit *z, const Unit *x, const Unit *y) } -template -Unit mulUnitT(Unit *z, const Unit *x, Unit y) +template +Unit mulUnitT(T *z, const Unit *x, Unit y) { #if MCL_SIZEOF_UNIT == 4 -#if 1 - uint64_t H = 0; +// use T as uint64_t to reduce conversion uint64_t y_ = y; - for (size_t i = 0; i < N; i++) { - uint64_t v = x[i] * y_; - v += H; - z[i] = uint32_t(v); - H = v >> 32; - } - return uint32_t(H); -#else - uint64_t H = 0; - for (size_t i = 0; i < N; i++) { - uint64_t v = x[i] * uint64_t(y); - v += H; + uint64_t v = x[0] * y_; + z[0] = uint32_t(v); + for (size_t i = 1; i < N; i++) { + v = x[i] * y_ + (v >> 32); z[i] = uint32_t(v); - H = v >> 32; } - return uint32_t(H); -#endif + return uint32_t(v >> 32); #elif defined(MCL_DEFINED_UINT128_T) uint64_t H = 0; for (size_t i = 0; i < N; i++) { @@ -211,21 +200,18 @@ Unit mulUnitT(Unit *z, const Unit *x, Unit y) #endif } -template -Unit mulUnitAddT(Unit *z, const Unit *x, Unit y) +template +Unit mulUnitAddT(T *z, const Unit *x, Unit y) { #if defined(MCL_WASM32) && MCL_SIZEOF_UNIT == 4 - // reduce cast operation - uint64_t H = 0; uint64_t y_ = y; - for (size_t i = 0; i < N; i++) { - uint64_t v = x[i] * y_; - v += H; - v += z[i]; + uint64_t v = z[0] + x[0] * y_; + z[0] = uint32_t(v); + for (size_t i = 1; i < N; i++) { + v = z[i] + x[i] * y_ + (v >> 32); z[i] = uint32_t(v); - H = v >> 32; } - return H; + return uint32_t(v >> 32); #else Unit xy[N], ret; ret = mulUnitT(xy, x, y); diff --git a/src/low_func.hpp b/src/low_func.hpp index 6f4ffda9..6cd44081 100644 --- a/src/low_func.hpp +++ b/src/low_func.hpp @@ -126,11 +126,11 @@ static void fpDblSubModT(Unit *z, const Unit *x, const Unit *y, const Unit *p) } // [return:z[N+1]] = z[N+1] + x[N] * y + (CF << (N * UnitBitSize)) -template -Unit mulUnitAddFullWithCF(Unit z[N + 1], const Unit x[N], Unit y, Unit CF) +template +Unit mulUnitAddFullWithCF(T z[N + 1], const Unit x[N], Unit y, Unit CF) { Unit H = bint::mulUnitAddT(z, x, y); - Unit v = z[N]; + T v = z[N]; v += H; Unit CF2 = v < H; v += CF; @@ -147,7 +147,11 @@ template static void modRedT(Unit *z, const Unit *xy, const Unit *p) { const Unit rp = p[-1]; +#if defined(MCL_WASM32) && MCL_SIZEOF_UNIT == 4 + uint64_t buf[N * 2]; +#else Unit buf[N * 2]; +#endif bint::copyT(buf, xy); Unit CF = 0; for (size_t i = 0; i < N; i++) { @@ -243,7 +247,12 @@ static void mulMontNFT(Unit *z, const Unit *x, const Unit *y, const Unit *p) t >> 64 <= (F - 2)(R - 1)/R = (F - 2) - (F - 2)/R t + (t >> 64) = (F - 2)R - (F - 2)/R < FR */ +#if defined(MCL_WASM32) && MCL_SIZEOF_UNIT == 4 + // use uint64_t if Unit = uint32_t to reduce conversion + uint64_t buf[N * 2]; +#else Unit buf[N * 2]; +#endif buf[N] = bint::mulUnitT(buf, x, y[0]); Unit q = buf[0] * rp; buf[N] += bint::mulUnitAddT(buf, p, q); From 4aa76f8ec5771a9cf04c1d40d89de25ee4f1abc7 Mon Sep 17 00:00:00 2001 From: MITSUNARI Shigeo Date: Mon, 30 Sep 2024 16:33:55 +0900 Subject: [PATCH 2/2] reduce # of inv in LagrangeInterpolation --- include/mcl/lagrange.hpp | 32 ++++++++++++++++++++++++++++++++ include/mcl/operator.hpp | 2 +- test/bench.hpp | 24 ++++++++++++++++++++++++ test/bls12_test.cpp | 1 + 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/include/mcl/lagrange.hpp b/include/mcl/lagrange.hpp index 18e0597e..e163d27e 100644 --- a/include/mcl/lagrange.hpp +++ b/include/mcl/lagrange.hpp @@ -39,6 +39,37 @@ void LagrangeInterpolation(bool *pb, G& out, const F *S, const G *vec, size_t k) /* f(0) = sum_i f(S[i]) delta_{i,S}(0) */ +#if 1 + // reduce # of inv + // d[i] = S[i] prod_{j!=i}(S[j] - S[i]) + F *d = (F*)CYBOZU_ALLOCA(sizeof(F) * k); + for (size_t i = 0; i < k; i++) { + d[i] = S[i]; + } + for (size_t i = 0; i < k; i++) { + for (size_t j = 0; j < k; j++) { + if (j != i) { + F v; + F::sub(v, S[j], S[i]); + if (v.isZero()) { + *pb = false; + return; + } + d[i] *= v; + } + } + } + mcl::invVec(d, d, k); + G r; + d[0] *= a; + G::mul(r, vec[0], d[0]); + for (size_t i = 1; i < k; i++) { + d[i] *= a; + G t; + G::mul(t, vec[i], d[i]); + r += t; + } +#else G r; r.clear(); for (size_t i = 0; i < k; i++) { @@ -57,6 +88,7 @@ void LagrangeInterpolation(bool *pb, G& out, const F *S, const G *vec, size_t k) G::mul(t, vec[i], a / b); r += t; } +#endif out = r; *pb = true; } diff --git a/include/mcl/operator.hpp b/include/mcl/operator.hpp index 0a85f879..b1c168e2 100644 --- a/include/mcl/operator.hpp +++ b/include/mcl/operator.hpp @@ -86,7 +86,7 @@ size_t invVecWork(Tout& y, Tin& x, size_t n, T *t) x[i] returns i-th const T& */ template -size_t invVecT(Tout& y, Tin& x, size_t n, size_t N = 256) +size_t invVecT(Tout& y, Tin& x, size_t n, size_t N = 1024) { T *t = (T*)CYBOZU_ALLOCA(sizeof(T) * N); size_t retNum = 0; diff --git a/test/bench.hpp b/test/bench.hpp index 558e5f49..6b3193a4 100644 --- a/test/bench.hpp +++ b/test/bench.hpp @@ -285,4 +285,28 @@ void testLagrange() } } } + { + const int n = 50; + cybozu::XorShift rg; + const int k = 40; + Fr c[k]; + Fr x[n], y[n]; + for (size_t i = 0; i < k; i++) { + c[i].setByCSPRNG(rg); + } + for (size_t i = 0; i < n; i++) { + x[i].setByCSPRNG(rg); + mcl::evaluatePolynomial(y[i], c, k, x[i]); + } + Fr s; + bool b; + mcl::LagrangeInterpolation(&b, s, x, y, k); + CYBOZU_TEST_ASSERT(b); + CYBOZU_TEST_EQUAL(s, c[0]); +#ifndef NDEBUG + puts("lagrange bench skip in debug"); + return; +#endif + CYBOZU_BENCH_C("LagrangeInterpolation", 100, mcl::LagrangeInterpolation, &b, s, x, y, k); + } } diff --git a/test/bls12_test.cpp b/test/bls12_test.cpp index ded566c4..be6e78c7 100644 --- a/test/bls12_test.cpp +++ b/test/bls12_test.cpp @@ -401,6 +401,7 @@ CYBOZU_TEST_AUTO(naive) clk.put(); return; #endif + testLagrange(); testMulVec(); testSerialize(P, Q); testParam(ts);