From cc56c0e473a5611975bcb9fb7f5f57fa012551a6 Mon Sep 17 00:00:00 2001 From: Marius van der Wijden Date: Tue, 9 Jul 2019 14:08:54 +0200 Subject: [PATCH] fixed substract, added mnt code, renamed variables --- cuda/device_field.h | 345 ++++++++++++++++++++++++++-- cuda/device_field_operator_test.cpp | 154 +++++++------ cuda/device_field_operators.h | 177 ++++++-------- 3 files changed, 475 insertions(+), 201 deletions(-) diff --git a/cuda/device_field.h b/cuda/device_field.h index a2b8258..d6bf5af 100644 --- a/cuda/device_field.h +++ b/cuda/device_field.h @@ -32,7 +32,7 @@ #endif -#define SIZE (256 / 32) +#define SIZE (768 / 32) namespace fields{ @@ -41,60 +41,371 @@ using size_t = decltype(sizeof 1ll); #ifndef DEBUG __constant__ #endif -uint32_t _mod [SIZE]; +//decimal representation of mod + +uint32_t _mod [SIZE] = {115910, 764593169, 270700578, 4007841197, 3086587728, + 1536143341, 1589111033, 1821890675, 134068517, 3902860685, + 2580620505, 2707093405, 2971133814, 4061660573, 3087994277, + 3411246648, 1750781161, 1987204260, 1669861489, 2596546032, + 3818738770, 752685471, 1586521054, 610172929}; + +struct Scalar { + + cu_fun void add(Scalar & fld1, const Scalar & fld2) const; + cu_fun void mul(Scalar & fld1, const Scalar & fld2) const; + cu_fun void subtract(Scalar & fld1, const Scalar & fld2) const; + cu_fun static void pow(Scalar & fld1, const size_t pow); -struct Field { //Intermediate representation uint32_t im_rep [SIZE] = {0}; //Returns zero element - cu_fun static Field zero() + cu_fun static Scalar zero() { - Field res; + Scalar res; for(size_t i = 0; i < SIZE; i++) res.im_rep[i] = 0; return res; } //Returns one element - cu_fun static Field one() + cu_fun static Scalar one() { - Field res; - res.im_rep[SIZE - 1] = 1; + Scalar res; + res.im_rep[SIZE - 1] = 1; return res; } //Default constructor - Field() = default; + Scalar() = default; //Construct from value - cu_fun Field(const uint32_t value) + cu_fun Scalar(const uint32_t value) { im_rep[SIZE - 1] = value; } - cu_fun Field(const uint32_t* value) + cu_fun Scalar(const uint32_t* value) { for(size_t i = 0; i < SIZE; i++) im_rep[i] = value[i]; } -}; + //Returns true iff this element is zero + cu_fun bool is_zero() const + { + for(size_t i = 0; i < SIZE; i++) + if(this->im_rep[i] != 0) + return false; + return true; + } + + cu_fun Scalar operator*(const Scalar& rhs) const + { + Scalar s; + for(size_t i = 0; i < SIZE; i++) + s.im_rep[i] = this->im_rep[i]; + mul(s, rhs); + return s; + } + + cu_fun Scalar operator+(const Scalar& rhs) const + { + Scalar s; + for(size_t i = 0; i < SIZE; i++) + s.im_rep[i] = this->im_rep[i]; + add(s, rhs); + return s; + } + + cu_fun Scalar operator-(const Scalar& rhs) const + { + Scalar s; + for(size_t i = 0; i < SIZE; i++) + s.im_rep[i] = this->im_rep[i]; + subtract(s, rhs); + return s; + } + + cu_fun Scalar operator-() const + { + Scalar s; + for(size_t i = 0; i < SIZE; i++) + s.im_rep[i] = this->im_rep[i]; + subtract(s, *this); + return s; + } + + cu_fun Scalar square() const + { + Scalar s; + for(size_t i = 0; i < SIZE; i++) + s.im_rep[i] = this->im_rep[i]; + mul(s, *this); + return s; + } + + cu_fun static Scalar shuffle_down(unsigned mask, Scalar val, unsigned offset) + { + Scalar result; + for(size_t i = 0; i < SIZE; i++) +#if defined(__CUDA_ARCH__) + result.im_rep[i] = __shfl_down_sync(mask, val.im_rep[i], offset); +#else + result.im_rep[i] = val.im_rep[i]; +#endif + return result; + } #ifdef DEBUG - void printField(fields::Field f) + static void printScalar(Scalar f) { for(size_t i = 0; i < SIZE; i++) printf("%u, ", f.im_rep[i]); printf("\n"); } - void testEquality(fields::Field f1, fields::Field f2) + static void testEquality(Scalar f1, Scalar f2) { for(size_t i = 0; i < SIZE; i++) if(f1.im_rep[i] != f2.im_rep[i]) { - printField(f1); - printField(f2); + printScalar(f1); + printScalar(f2); assert(!"missmatch"); } } #endif +}; + +cu_fun long idxOfLNZ(const Scalar& fld); +cu_fun bool hasBitAt(const Scalar& fld, long index); + +struct fp2 { + Scalar x; + Scalar y; + const Scalar non_residue = Scalar(13); //13 for mnt4753 and 11 for mnt6753 + + fp2 () = default; + + cu_fun static fp2 zero() + { + fp2 res; + res.x = Scalar::zero(); + res.y = Scalar::zero(); + return res; + } + + cu_fun fp2(Scalar _x, Scalar _y) + { + x = _x; + y = _y; + } + + cu_fun fp2 operator*(const Scalar& rhs) const + { + return fp2(this->x * rhs, this->y * rhs); + } + + cu_fun fp2 operator*(const fp2& rhs) const + { + const Scalar &A = rhs.x; + const Scalar &B = rhs.y; + const Scalar &a = this->x; + const Scalar &b = this->y; + const Scalar aA = a * A; + const Scalar bB = b * B; + return fp2(aA + non_residue * bB, ((a+b) * (A+B) - aA) - bB); + } + + cu_fun fp2 operator-(const fp2& rhs) const + { + return fp2(this->x - rhs.x, this->y - rhs.y); + } + + cu_fun fp2 operator-() const + { + return fp2(-this->x, -this->y); + } + + cu_fun fp2 operator+(const fp2& rhs) const + { + return fp2(this->x + rhs.x, this->y + rhs.y); + } + + cu_fun void operator=(const fp2& rhs) + { + this->x = rhs.x; + this->y = rhs.y; + } + + cu_fun static fp2 shuffle_down(unsigned mask, fp2 val, unsigned offset) + { + fp2 result; + result.x = Scalar::shuffle_down(mask, val.x, offset); + result.y = Scalar::shuffle_down(mask, val.y, offset); + return result; + } +}; + +struct mnt4753_G1 { + Scalar x; + Scalar y; + Scalar z; + const Scalar coeff_a = Scalar(2); //2 for mnt4753 11 for mnt6753 + + cu_fun mnt4753_G1() { + x = Scalar::zero(); + y = Scalar::zero(); + z = Scalar::zero(); + } + + cu_fun mnt4753_G1(Scalar _x, Scalar _y, Scalar _z) + { + x = _x; + y = _y; + z = _z; + } + + cu_fun static bool is_zero(const mnt4753_G1& g1) { + return g1.x.is_zero() && g1.y.is_zero() && g1.z.is_zero(); + } + + cu_fun static mnt4753_G1 zero() + { + return mnt4753_G1(Scalar::zero(), Scalar::zero(), Scalar::zero()); + } + + cu_fun mnt4753_G1 operator+(const mnt4753_G1& other) const + { + const Scalar X1Z2 = this->x * other.z; + const Scalar Y1Z2 = this->y * other.z; + const Scalar Z1Z2 = this->z * other.z; + const Scalar u = other.y * this->z - Y1Z2; + const Scalar uu = u * u; + const Scalar v = other.x * this->z - X1Z2; + const Scalar vv = v * v; + const Scalar vvv = vv * v; + const Scalar R = vv * X1Z2; + const Scalar A = uu * Z1Z2 - (vvv + R + R); + const Scalar X3 = v * A; + const Scalar Y3 = u * (R-A) - vvv * Y1Z2; + const Scalar Z3 = vvv * Z1Z2; + return mnt4753_G1(X3, Y3, Z3); + } + + cu_fun mnt4753_G1 dbl() const + { + if (is_zero(*this)) + { + return (*this); + } + + const Scalar XX = this->x * this->x; // XX = X1^2 + const Scalar ZZ = this->z * this->z; // ZZ = Z1^2 + const Scalar w = mnt4753_G1::coeff_a * ZZ + (XX + XX + XX); // w = a*ZZ + 3*XX + const Scalar Y1Z1 = this->y * this->z; + const Scalar s = Y1Z1 + Y1Z1; // s = 2*Y1*Z1 + const Scalar ss = s * s; // ss = s^2 + const Scalar sss = s * ss; // sss = s*ss + const Scalar R = this->y * s; // R = Y1*s + const Scalar RR = R * R; // RR = R^2 + const Scalar T = this->x + R; + const Scalar TT = T * T; + const Scalar B = TT - XX - RR; // B = (X1+R)^2 - XX - RR + const Scalar h = (w * w) - (B+B); // h = w^2 - 2*B + const Scalar X3 = h * s; // X3 = h*s + const Scalar Y3 = w * (B-h)-(RR+RR); // Y3 = w*(B-h) - 2*RR + const Scalar Z3 = sss; // Z3 = sss + + return mnt4753_G1(X3, Y3, Z3); + } + + cu_fun void operator=(const mnt4753_G1& other) + { + this->x = other.x; + this->y = other.y; + this->z = other.z; + } + + cu_fun void operator+=(const mnt4753_G1& other) + { + *this = *this + other; + } + + cu_fun mnt4753_G1 operator-() const + { + return mnt4753_G1(this->x, -(this->y), this->z); + } + + cu_fun mnt4753_G1 operator-(const mnt4753_G1 &other) const + { + return (*this) + (-other); + } + + cu_fun mnt4753_G1 operator*(const Scalar &other) const + { + mnt4753_G1 result = zero(); + + bool one = false; + for (long i = idxOfLNZ(other) - 1; i >= 0; --i) + { + if (one) + result = result.dbl(); + if (hasBitAt(other,i)) + { + one = true; + result = result + *this; + } + } + return result; + } + + cu_fun static mnt4753_G1 shuffle_down(unsigned mask, mnt4753_G1 val, unsigned offset) + { + mnt4753_G1 result; + result.x = Scalar::shuffle_down(mask, val.x, offset); + result.y = Scalar::shuffle_down(mask, val.y, offset); + result.z = Scalar::shuffle_down(mask, val.z, offset); + return result; + } +}; + +} + +//Modular representation + +//mnt4753 mod: +//41898490967918953402344214791240637128170709919953949071783502921025352812571106773058893763790338921418070971888253786114353726529584385201591605722013126468931404347949840543007986327743462853720628051692141265303114721689601 +//mnt6753 +//41898490967918953402344214791240637128170709919953949071783502921025352812571106773058893763790338921418070971888458477323173057491593855069696241854796396165721416325350064441470418137846398469611935719059908164220784476160001 +//lg2(prime) = 752.8 -> 90 bytes to store -> 24 * 32bit = 768 / 32 + +//Binary representation +//00000000000000011100010011000110 +//00101101100100101100010000010001 +//00010000001000101001000000100010 +//11101110111000101100110110101101 +//10110111111110011001011101010000 +//01011011100011111010111111101101 +//01011110101101111110100011111001 +//01101100100101111101100001110011 +//00000111111111011011100100100101 +//11101000101000001110110110001101 +//10011001110100010010010011011001 +//10100001010110101111011110011101 +//10110001000101111110011101110110 +//11110010000110000000010110011101 +//10111000000011110000110110100101 +//11001011010100110111111000111000 +//01101000010110101100110011101001 +//01110110011100100101010010100100 +//01100011100010000001000001110001 +//10011010110001000010010111110000 +//11100011100111010101010001010010 +//00101100110111010001000110011111 +//01011110100100000110001111011110 +//00100100010111101000000000000001 -} \ No newline at end of file +//decimal representation +//= {115910, 764593169, 270700578, 4007841197, 3086587728, +// 1536143341, 1589111033, 1821890675, 134068517, 3902860685, +// 2580620505, 2707093405, 2971133814, 4061660573, 3087994277, +// 3411246648, 1750781161, 1987204260, 1669861489, 2596546032, +// 3818738770, 752685471, 1586521054, 610172929}; diff --git a/cuda/device_field_operator_test.cpp b/cuda/device_field_operator_test.cpp index 2e7ef3b..e9e6e1b 100644 --- a/cuda/device_field_operator_test.cpp +++ b/cuda/device_field_operator_test.cpp @@ -29,72 +29,74 @@ namespace fields{ + enum operand {add, substract, mul, pow}; + void testAdd() { printf("Addition test: "); - fields::Field f1(1234); - fields::Field f2(1234); - fields::Field result(2468); - add(f1,f2); - testEquality(f1, result); + fields::Scalar f1(1234); + fields::Scalar f2(1234); + fields::Scalar result(2468); + f1 = f1 + f2; + Scalar::testEquality(f1, result); printf("successful\n"); } - void testsubtract() + void test_subtract() { - printf("subtraction test: "); - fields::Field f1(1234); - fields::Field f2(1234); - subtract(f1,f2); - testEquality(f1, fields::Field::zero()); - fields::Field f3(1235); - subtract(f3, f2); - testEquality(f3,fields::Field::one()); + printf("_subtraction test: "); + fields::Scalar f1(1234); + fields::Scalar f2(1234); + f1 = f1 - f2; + Scalar::testEquality(f1, fields::Scalar::zero()); + fields::Scalar f3(1235); + f3 = f3 - f2; + Scalar::testEquality(f3,fields::Scalar::one()); printf("successful\n"); } void testMultiply() { printf("Multiply test: "); - fields::Field f1(1234); - fields::Field f2(1234); - mul(f1, f2); - testEquality(f1, fields::Field(1522756)); - mul(f1, f2); - testEquality(f1, fields::Field(1879080904)); - mul(f1, f2); - testEquality(f1, fields::Field(3798462992)); - fields::Field f3(1234); - square(f3); - testEquality(f3, fields::Field(1522756)); + fields::Scalar f1(1234); + fields::Scalar f2(1234); + f1 = f1 * f2; + Scalar::testEquality(f1, fields::Scalar(1522756)); + f1 = f1 * f2; + Scalar::testEquality(f1, fields::Scalar(1879080904)); + f1 = f1 * f2; + Scalar::testEquality(f1, fields::Scalar(3798462992)); + fields::Scalar f3(1234); + f3 = f3 * f3; + Scalar::testEquality(f3, fields::Scalar(1522756)); printf("successful\n"); } void testModulo() { printf("Modulo test: "); - fields::Field f1(uint32_t(0)); - fields::Field f2(1234); + fields::Scalar f1(uint32_t(0)); + fields::Scalar f2(1234); - fields::Field f3(); + fields::Scalar f3(); printf("successful\n"); } void testPow() { - printf("POW test: "); - fields::Field f1(2); - pow(f1, 0); - testEquality(f1, fields::Field::one()); - fields::Field f2(2); - pow(f2, 2); - testEquality(f2, fields::Field(4)); - pow(f2, 10); - testEquality(f2, fields::Field(1048576)); - fields::Field f3(2); - fields::Field f4(1048576); - pow(f3, 20); - testEquality(f3, f4); + printf("Scalar::pow test: "); + fields::Scalar f1(2); + Scalar::pow(f1, 0); + Scalar::testEquality(f1, fields::Scalar::one()); + fields::Scalar f2(2); + Scalar::pow(f2, 2); + Scalar::testEquality(f2, fields::Scalar(4)); + Scalar::pow(f2, 10); + Scalar::testEquality(f2, fields::Scalar(1048576)); + fields::Scalar f3(2); + fields::Scalar f4(1048576); + Scalar::pow(f3, 20); + Scalar::testEquality(f3, f4); printf("successful\n"); } @@ -102,20 +104,23 @@ namespace fields{ void testConstructor() { printf("Constructor test: "); - fields::Field f3(1); - testEquality(f3, fields::Field::one()); - fields::Field f4; - testEquality(f4, fields::Field::zero()); - fields::Field f5(uint32_t(0)); - testEquality(f5, fields::Field::zero()); + fields::Scalar f3(1); + Scalar::testEquality(f3, fields::Scalar::one()); + fields::Scalar f4; + Scalar::testEquality(f4, fields::Scalar::zero()); + fields::Scalar f5(uint32_t(0)); + Scalar::testEquality(f5, fields::Scalar::zero()); - fields::Field f1; - fields::Field f2(1234); - add(f1, fields::Field(1234)); - testEquality(f1, f2); - uint32_t tmp [SIZE] ={0,0,0,0,0,0,0,1234}; - fields::Field f6(tmp); - testEquality(f6, f2); + fields::Scalar f1; + fields::Scalar f2(1234); + f1 = f1 + fields::Scalar(1234); + Scalar::testEquality(f1, f2); + uint32_t tmp [SIZE]; + for(int i = 0; i < SIZE; i++) + tmp[i] = 0; + tmp[SIZE -1 ] = 1234; + fields::Scalar f6(tmp); + Scalar::testEquality(f6, f2); printf("successful\n"); } @@ -123,28 +128,25 @@ namespace fields{ void setMod() { - assert(SIZE == 8); - _mod[0] = 0; - _mod[1] = 0; - _mod[2] = 0; - _mod[3] = 0; - _mod[4] = 0; - _mod[5] = 0; - _mod[6] = 1; - _mod[7] = 0; + assert(SIZE == 24); + for(int i = 0; i < SIZE; i ++) + { + _mod[i] = 0; + } + _mod[SIZE - 2] = 1; } - void operate(fields::Field & f1, fields::Field const f2, int const op) + void operate(fields::Scalar & f1, fields::Scalar const f2, int const op) { switch(op){ case 0: - add(f1,f2); break; + f1 = f1 + f2; break; case 1: - subtract(f1,f2); break; + f1 = f1 - f2; break; case 2: - mul(f1,f2); break; + f1 = f1 * f2; break; case 3: - pow(f1, (f2.im_rep[SIZE - 1] & 65535)); + Scalar::pow(f1, (f2.im_rep[SIZE - 1] & 65535)); break; default: break; } @@ -177,13 +179,13 @@ namespace fields{ } } - void toMPZ(mpz_t ret, fields::Field f) + void toMPZ(mpz_t ret, fields::Scalar f) { mpz_init(ret); mpz_import(ret, SIZE, 1, sizeof(uint32_t), 0, 0, f.im_rep); } - void compare(fields::Field f1, fields::Field f2, mpz_t mpz1, mpz_t mpz2, mpz_t mod, int op) + void compare(fields::Scalar f1, fields::Scalar f2, mpz_t mpz1, mpz_t mpz2, mpz_t mod, int op) { mpz_t tmp1; mpz_init_set(tmp1, mpz1); @@ -192,8 +194,10 @@ namespace fields{ mpz_t tmp; toMPZ(tmp, f1); if(mpz_cmp(tmp, mpz1) != 0){ - gmp_printf ("t: %d [%Zd] : [%Zd] : %d\n",omp_get_thread_num(), tmp1, mpz2, op); - gmp_printf ("t: %d [%Zd] : [%Zd] \n",omp_get_thread_num() , mpz1, tmp); + printf("Missmatch: "); + gmp_printf ("t: %d [%Zd] %d [%Zd] \n",omp_get_thread_num(), tmp1, op, mpz2); + gmp_printf ("t: %d CPU: [%Zd] GPU: [%Zd] \n",omp_get_thread_num() , mpz1, tmp); + Scalar::printScalar(f1); assert(!"error"); } mpz_clear(tmp1); @@ -241,13 +245,13 @@ namespace fields{ mpz_init(mod); mpz_set_ui(mod, 4294967296); mpz_set_ui(b, i); - fields::Field f2(i); + fields::Scalar f2(i); for(size_t k = 0; k < 4294967295; k = k + k_step) { for(size_t z = 0; z <= 3; z++ ) { mpz_set_ui(a, k); - fields::Field f1(k); + fields::Scalar f1(k); compare(f1,f2,a,b,mod,z); } } @@ -266,7 +270,7 @@ int main(int argc, char** argv) fields::setMod(); fields::testConstructor(); fields::testAdd(); - fields::testsubtract(); + fields::test_subtract(); fields::testMultiply(); fields::testPow(); fields::fuzzTest(); diff --git a/cuda/device_field_operators.h b/cuda/device_field_operators.h index 44ce462..aa540e7 100644 --- a/cuda/device_field_operators.h +++ b/cuda/device_field_operators.h @@ -21,7 +21,13 @@ #include "device_field.h" -#define SIZE (256 / 32) +#define SIZE (768 / 32) + +#if defined(__CUDA_ARCH__) +#define _clz __clz +#else +#define _clz __builtin_clz +#endif #ifndef DEBUG #define cu_fun __host__ __device__ @@ -32,11 +38,13 @@ #include #endif +#define CHECK_BIT(var,pos) ((var) & (1<<(pos))) + namespace fields{ using size_t = decltype(sizeof 1ll); -cu_fun bool operator==(const Field& lhs, const Field& rhs) +cu_fun bool operator==(const Scalar& lhs, const Scalar& rhs) { for(size_t i = 0; i < SIZE; i++) if(lhs.im_rep[i] != rhs.im_rep[i]) @@ -44,42 +52,59 @@ cu_fun bool operator==(const Field& lhs, const Field& rhs) return true; } -//Returns true iff this element is zero -cu_fun bool is_zero(const Field & fld) +cu_fun uint32_t clz(const uint32_t* element, const size_t e_size) { - for(size_t i = 0; i < SIZE; i++) - if(fld.im_rep[i] != 0) - return false; - return true; + uint32_t lz = 0; + uint32_t tmp; + for(size_t i = 0; i < e_size; i++) + { + if(element[i] == 0) + tmp = 32; + else + tmp = _clz(element[i]); + lz += tmp; + if(tmp < 32) + break; + } + return lz; } -cu_fun void set_mod(const Field& f) +cu_fun long idxOfLNZ(const Scalar& fld) +{ + return SIZE - clz(fld.im_rep, SIZE); +} + +cu_fun bool hasBitAt(const Scalar& fld, long index) +{ + long idx1 = index % SIZE; + long idx2 = index / SIZE; + return CHECK_BIT(fld.im_rep[idx2], idx1) != 0; +} + +#ifdef DEBUG +cu_fun void set_mod(const Scalar& f) { for(size_t i = 0; i < SIZE; i++) _mod[i] = f.im_rep[i]; } +#endif //Returns true if the first element is less than the second element cu_fun bool less(const uint32_t* element1, const size_t e1_size, const uint32_t* element2, const size_t e2_size) { - assert(e1_size >= e2_size); - size_t diff = e1_size - e2_size; - for(size_t i = 0; i < diff; i++) - if(element1[i] > 0) - return false; - for(size_t i = 0; i < e2_size; i++) - if(element1[i + diff] > element2[i]) + assert(e1_size == e2_size); + for(size_t i = 0; i < SIZE; i++) + if(element1[i] > element2[i]) return false; - else if(element1[i + diff] < element2[i]) + else if(element1[i] < element2[i]) return true; return false; } // Returns the carry, true if there was a carry, false otherwise -// Takes a sign, true if negative -cu_fun bool add(bool sign, uint32_t* element1, const size_t e1_size, const uint32_t* element2, const size_t e2_size) +cu_fun bool _add(uint32_t* element1, const size_t e1_size, const uint32_t* element2, const size_t e2_size) { - assert(e1_size >= e2_size); + assert(e1_size == e2_size); bool carry = false; for(size_t i = 1; i <= e1_size; i++) { @@ -88,29 +113,28 @@ cu_fun bool add(bool sign, uint32_t* element1, const size_t e1_size, const uint3 element1[e1_size - i] = tmp + (uint64_t)element2[e1_size - i]; carry = (tmp >> 32) > 0; } - return sign? 0: carry; + return carry; } -// Returns the carry, true if the resulting number is negative -cu_fun bool subtract(uint32_t* element1, const size_t e1_size, bool carry, const uint32_t* element2, const size_t e2_size) +// Fails if the second number is bigger than the first +cu_fun void _subtract(uint32_t* element1, const size_t e1_size, bool carry, const uint32_t* element2, const size_t e2_size) { - assert(e1_size >= e2_size); + assert(e1_size == e2_size); + bool borrow = false; for(size_t i = 1; i <= e1_size; i++) { uint64_t tmp = (uint64_t)element1[e1_size - i]; - bool underflow = (tmp == 0); - if(carry) tmp--; - carry = (e2_size >= i) ? (tmp < element2[e2_size - i]) : underflow; - if(carry) tmp += ((uint64_t)1 << 33); + if(borrow) tmp--; + borrow = (tmp < element2[e2_size - i]); + if(borrow) tmp += ((uint64_t)1 << 33); element1[e1_size - i] = tmp - element2[e2_size - i]; } - return carry; + assert(carry == borrow); } cu_fun void ciosMontgomeryMultiply(uint32_t * result, const uint32_t* a, const size_t a_size, -const uint32_t* b, const size_t b_size, -const uint32_t* n, const size_t n_size, +const uint32_t* b, const uint32_t* n, const uint64_t m_prime) { uint64_t temp; @@ -169,106 +193,40 @@ const uint64_t m_prime) memcpy(result, t, a_size); } -cu_fun void modulo(uint32_t* element, const size_t e_size, const uint32_t* mod, const size_t mod_size, bool carry) -{ - if(less(element, e_size, mod, mod_size)) - return; - printf("tick"); - - uint32_t tmp[SIZE * 2]; - memset(tmp, 0, (SIZE * 2) * sizeof(uint32_t)); - - ciosMontgomeryMultiply(tmp + 1, Field::one().im_rep, SIZE, element, SIZE, _mod, SIZE, 4294967296L); - for(size_t i = 0; i < SIZE; i++) - element[i] = tmp[i]; -} - -cu_fun bool multiply(uint32_t * result, const uint32_t* element1, const size_t e1_size, const uint32_t* element2, const size_t e2_size) -{ - bool carry = false; - uint64_t temp; - for(size_t i = 0; i < e2_size; i++) - { - uint32_t carry = 0; - for(size_t j = 0; j < e1_size; j++) - { - temp = result[i+j + 1]; - temp += (uint64_t)element1[j] * (uint64_t)element2[i]; - temp += carry; - result[i+j + 1] = (uint32_t)temp; - carry = temp >> 32; - } - result[i + e1_size + 1] = carry; - } - return carry; -} - -//Squares this element -cu_fun void square(Field & fld) -{ - //TODO since squaring produces equal intermediate results, this can be sped up - uint32_t tmp[SIZE * 2]; - memset(tmp, 0, SIZE * 2 * sizeof(uint32_t)); - bool carry = multiply(tmp, fld.im_rep, SIZE, fld.im_rep, SIZE); - //size of tmp is 2*size - modulo(tmp, 2*SIZE, _mod, SIZE, carry); - //Last size words are the result - for(size_t i = 0; i < SIZE; i++) - fld.im_rep[i] = tmp[SIZE + i]; -} - -/* -//Doubles this element -void double(Field & fld) -{ - uint32_t temp[] = {2}; - uint32_t tmp[] = multiply(fld.im_rep, size, temp, 1); - //size of tmp is 2*size - modulo(tmp, 2*size, mod, size); - //Last size words are the result - for(size_t i = 0; i < size; i++) - fld.im_rep[i] = tmp[size + i]; -}*/ - -//Negates this element -cu_fun void negate(Field & fld) -{ - //TODO implement -} - //Adds two elements -cu_fun void add(Field & fld1, const Field & fld2) +cu_fun void Scalar::add(Scalar & fld1, const Scalar & fld2) const { - bool carry = add(false, fld1.im_rep, SIZE, fld2.im_rep, SIZE); + bool carry = _add(fld1.im_rep, SIZE, fld2.im_rep, SIZE); if(carry || less(_mod, SIZE, fld1.im_rep, SIZE)) - subtract(fld1.im_rep, SIZE, false, _mod, SIZE); + _subtract(fld1.im_rep, SIZE, false, _mod, SIZE); } //Subtract element two from element one -cu_fun void subtract(Field & fld1, const Field & fld2) +cu_fun void Scalar::subtract(Scalar & fld1, const Scalar & fld2) const { + bool carry = false; if(less(fld1.im_rep, SIZE, fld2.im_rep, SIZE)) - add(true, fld1.im_rep, SIZE, _mod, SIZE); - subtract(fld1.im_rep, SIZE, false, fld2.im_rep, SIZE); + carry = _add(fld1.im_rep, SIZE, _mod, SIZE); + _subtract(fld1.im_rep, SIZE, carry, fld2.im_rep, SIZE); } //Multiply two elements -cu_fun void mul(Field & fld1, const Field & fld2) +cu_fun void Scalar::mul(Scalar & fld1, const Scalar & fld2) const { uint32_t tmp[SIZE * 2]; memset(tmp, 0, (SIZE * 2) * sizeof(uint32_t)); - ciosMontgomeryMultiply(tmp + 1, fld1.im_rep, SIZE, fld2.im_rep, SIZE, _mod, SIZE, 4294967296L); + ciosMontgomeryMultiply(tmp + 1, fld1.im_rep, SIZE, fld2.im_rep, _mod, 4294967296L); for(size_t i = 0; i < SIZE; i++) fld1.im_rep[i] = tmp[i]; } //Exponentiates this element -cu_fun void pow(Field & fld1, const size_t pow) +cu_fun void Scalar::pow(Scalar & fld1, const size_t pow) { if(pow == 0) { - fld1 = Field::one(); + fld1 = Scalar::one(); return; } @@ -287,9 +245,10 @@ cu_fun void pow(Field & fld1, const size_t pow) { memset(tmp, 0, (SIZE * 2) * sizeof(uint32_t)); - ciosMontgomeryMultiply(tmp + 1, fld1.im_rep, SIZE, temp, SIZE, _mod, SIZE, 4294967296L); + ciosMontgomeryMultiply(tmp + 1, fld1.im_rep, SIZE, temp, _mod, 4294967296L); for(size_t i = 0; i < SIZE; i++) fld1.im_rep[i] = tmp[i]; } } + }