Skip to content

Commit

Permalink
Merge pull request #6 from mitschabaude/bls12-377
Browse files Browse the repository at this point in the history
Bls12 377
  • Loading branch information
mitschabaude authored Dec 3, 2023
2 parents ad0d061 + a27d186 commit ae33382
Show file tree
Hide file tree
Showing 14 changed files with 581 additions and 64 deletions.
54 changes: 54 additions & 0 deletions scripts/test-bls12377-msm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import "../src/extra/fix-webcrypto.js";
import { tic, toc } from "../src/extra/tictoc.js";
import {
msm,
msmUtil,
Field,
CurveAffine,
Random,
Scalar,
Bigint,
} from "../src/concrete/bls12-377.js";
import { bigintScalarsToMemory } from "../src/msm.js";
import { checkOnCurve, msmDumbAffine } from "../src/extra/dumb-curve-affine.js";
import assert from "node:assert/strict";

let n = Number(process.argv[2] ?? 8);
let N = 1 << n;
console.log(`running msm with 2^${n} = ${2 ** n} inputs`);

tic("random points");
let points = Field.getZeroPointers(N, CurveAffine.sizeAffine);
let scratch = Field.getPointers(40);
CurveAffine.randomPoints(scratch, points);

let scalars = Random.randomScalars(N);
let scalarPtr = bigintScalarsToMemory(Scalar, scalars);
toc();

tic("convert points to bigint & check");
let pointsBigint = points.map((gPtr) => {
let g = CurveAffine.toBigint(gPtr);
assert(checkOnCurve(g, Field.p, CurveAffine.b), "point on curve");
return g;
});
toc();

tic("msm (core)");
let s0 = msm(scalarPtr, points[0], N);
let s = msmUtil.toAffineOutputBigint(scratch, s0);
toc();

tic("msm (bigint)");
let s1 = Bigint.msm(scalars, pointsBigint);
toc();

assert.deepEqual(s, s1, "consistent with bigint version");

if (n < 12) {
tic("msm (simple, slow bigint impl)");
let sBigint = msmDumbAffine(scalars, pointsBigint, Scalar, Field);
toc();
assert.deepEqual(s, sBigint, "consistent results");
console.log("results are consistent!");
}
212 changes: 212 additions & 0 deletions scripts/test-bls12377.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// run with ts-node-esm
import "../src/extra/fix-webcrypto.js";
import { BLS12_377 } from "../src/index.js";
import {
assert,
bigintToBits,
extractBitSlice as extractBitSliceJS,
} from "../src/util.js";
import { mod, modExp, modInverse } from "../src/field-util.js";
import { G, q } from "../src/concrete/bls12-377.params.js";

let { Field, Scalar, CurveAffine, Random } = BLS12_377;
const { p } = Field;

function toWasm(x0: bigint, x: number) {
Field.writeBigint(x, x0);
Field.toMontgomery(x);
}
function ofWasm([tmp]: number[], x: number) {
Field.multiply(tmp, x, Field.constants.one);
Field.reduce(tmp);
return mod(Field.readBigint(tmp), p);
}

let [x, y, z, z_hi, ...scratch] = Field.getPointers(40);
let scratchScalar = Scalar.getPointers(10);

let R = mod(1n << BigInt(Field.w * Field.n), p);
let Rinv = modInverse(R, p);

function test() {
let x0 = Random.randomFieldx2();
let y0 = Random.randomFieldx2();
toWasm(x0, x);
toWasm(y0, y);

// multiply
let z0 = mod(x0 * y0, p);
Field.multiply(z, x, y);
let z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("multiply");
z0 = 0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffffn; // test overflow resistance
toWasm(z0, z);
z0 = mod(z0 * z0, p);
Field.multiply(z, z, z);
z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("multiply");

// square
z0 = mod(x0 * x0, p);
Field.square(z, x);
z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("square");

// leftShift
let k = 97;
z0 = 1n << BigInt(k);
// computes R^2 * 2^k / R = 2^k R, which is 2^k in montgomery form
Field.leftShift(z, Field.constants.R2, k);
z1 = ofWasm(scratch, z);
if (z1 !== z0) throw Error("leftShift");

// add
z0 = mod(x0 + y0, p);
Field.add(z, x, y);
z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("add");

// subtract
z0 = mod(x0 - y0, p);
Field.subtract(z, x, y);
z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("subtract");

// subtract plus 2p
z0 = mod(x0 - y0, p);
Field.subtractPositive(z, x, y);
z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("subtract");

// reduceInPlace
z0 = x0 >= p ? x0 - p : x0;
Field.copy(z, x);
Field.reduce(z);
z1 = ofWasm(scratch, z);
if (z0 !== z1) throw Error("reduceInPlace");

// isEqual
if (Field.isEqual(x, x) !== 1) throw Error("isEqual");
if (Field.isEqual(x, y) !== 0) throw Error("isEqual");

// inverse
Field.inverse(scratch[0], z, x);
Field.multiply(z, z, x);
z1 = ofWasm(scratch, z);
if (z1 !== 1n) throw Error("inverse");
z0 = modInverse(x0, p);
if (mod(z0 * x0, p) !== 1n) throw Error("inverse");

// sqrt
let exists0 = modExp(x0, (p - 1n) >> 1n, { p }) === 1n;
let exists1 = Field.sqrt(scratch, z, x);
if (exists0 !== exists1) throw Error("isSquare");
if (exists0) {
let zsqrt = ofWasm(scratch, z);
Field.square(z, z);
z0 = ofWasm(scratch, z);
if (mod(zsqrt * zsqrt - x0, p) !== 0n) throw Error("sqrt");
if (mod(z0 - x0, p) !== 0n) throw Error("sqrt");
}

// roots
let minus1 = Field.roots.at(-1)!;
if (Field.toBigint(minus1) !== p - 1n) throw Error("roots");

// makeOdd
Field.writeBigint(x, 5n << 120n);
Field.writeBigint(z, 3n);
Field.makeOdd(x, z);
x0 = Field.readBigint(x);
z0 = Field.readBigint(z);
if (!(x0 === 5n && z0 === 3n << 120n)) throw Error("makeOdd");

// extractBitSlice
let arr = new Uint8Array([0b0010_0110, 0b1101_0101, 0b1111_1111]);
let e = Error("extractBitSlice");
if (extractBitSliceJS(arr, 2, 4) !== 0b10_01) throw e;
if (extractBitSliceJS(arr, 0, 2) !== 0b10) throw e;
if (extractBitSliceJS(arr, 0, 8) !== 0b0010_0110) throw e;
if (extractBitSliceJS(arr, 3, 9) !== 0b0101_0010_0) throw e;
if (extractBitSliceJS(arr, 8, 8) !== 0b1101_0101) throw e;
if (extractBitSliceJS(arr, 5, 3 + 8 + 2) !== 0b11_1101_0101_001) throw e;
if (extractBitSliceJS(arr, 16, 10) !== 0b1111_1111) throw e;

// extractBitSlice (wasm)
let s = Scalar.getPointer();
Scalar.writeBytes(scratchScalar, s, arr);
const { extractBitSlice } = Scalar;
e = Error("extractBitSlice (wasm)");
if (extractBitSlice(s, 2, 4) !== 0b10_01) throw e;
if (extractBitSlice(s, 0, 2) !== 0b10) throw e;
if (extractBitSlice(s, 0, 8) !== 0b0010_0110) throw e;
if (extractBitSlice(s, 3, 9) !== 0b0101_0010_0) throw e;
if (extractBitSlice(s, 8, 8) !== 0b1101_0101) throw e;
if (extractBitSlice(s, 5, 3 + 8 + 2) !== 0b11_1101_0101_001) throw e;
if (extractBitSlice(s, 16, 10) !== 0b1111_1111) throw e;
}

for (let i = 0; i < 100; i++) {
test();
testCurve();
}
for (let i = 0; i < 100; i++) {
let ok = Scalar.testDecomposeScalar(Random.randomScalar());
if (!ok) throw Error("scalar decomposition");
}

testBatchMontgomery();

function testBatchMontgomery() {
let n = 1000;
let X = Field.getPointers(n);
let invX = Field.getPointers(n);
let scratch = Field.getPointers(10);
for (let i = 0; i < n; i++) {
let x0 = Random.randomFieldx2();
Field.writeBigint(X[i], x0);
// compute inverses normally
Field.inverse(scratch[0], invX[i], X[i]);
}
// compute inverses as batch
let invX1 = Field.getPointers(n);
Field.batchInverse(scratch[0], invX1[0], X[0], n);

// check that all inverses are equal
for (let i = 0; i < n; i++) {
let z0 = Field.readBigint(invX[i]);
let z1 = Field.readBigint(invX1[i]);
if (mod(z1 - z0, p) !== 0n) throw Error("batch inverse");

Field.reduce(invX1[i]);
Field.reduce(invX[i]);
if (!Field.isEqual(invX1[i], invX[i])) {
console.log({
i,
z0,
z1,
invX0: Field.readBigint(invX[i]),
invX1: Field.readBigint(invX1[i]),
});
throw Error("batch inverse / reduce");
}
}
}

function testCurve() {
// prepare inputs
let [g, qG] = Field.getPointers(2, CurveAffine.sizeAffine);
CurveAffine.writeBigint(g, G);
let qBits = bigintToBits(q);

// scale and check
CurveAffine.scale(scratch, qG, g, qBits);
assert(CurveAffine.isZeroAffine(qG), "order*G = 0");

// create random point and check if it is in the subgroup
let [r, qR] = Field.getPointers(2, CurveAffine.sizeAffine);
CurveAffine.randomPoints(scratch, [r]);
assert(!CurveAffine.isZeroAffine(r), "random point R is not zero");
CurveAffine.scale(scratch, qR, r, qBits);
assert(CurveAffine.isZeroAffine(qR), "order*h*R = 0");
}
54 changes: 54 additions & 0 deletions src/concrete/bls12-377.params.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { scale } from "../extra/dumb-curve-affine.js";
import { mod, modExp, modInverse } from "../field-util.js";
import { assert, bigintToBits } from "../util.js";

export { p, q, h, b, lambda, beta, nBits, nBytes, G };

const p =
0x01ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001n;
const q = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001n;

// curve equation is y^2 = x^3 + 1
const b = 1n;

const nBits = 377;
const nBytes = 48;

// cofactor
const h = 0x170b5d44300000000000000000000000n;

// generator
const G = {
x: 0x008848defe740a67c8fc6225bf87ff5485951e2caa9d41bb188282c8bd37cb5cd5481512ffcd394eeab9b16eb21be9efn,
y: 0x01914a69c5102eff1f674f5d30afeec4bd7fb348ca3e52d96d182ad44fb82305c2fe3d3634a9591afd82de55559c8ea6n,
isInfinity: false,
};

const lambda =
0x12ab655e9a2ca55660b44d1e5c37b00114885f32400000000000000000000000n;
const beta =
0x1ae3a4617c510eabc8756ba8f8c524eb8882a75cc9bc8e359064ee822fb5bffd1e945779fffffffffffffffffffffffn;

const debug = false;

if (debug) {
// compute cube root in Fq (endo scalar) as lambda = x^(q - 1)/3 for some small x
const lambda_ = modExp(11n, (q - 1n) / 3n, { p: q });

assert(lambda === lambda_, "lambda is correct");
const lambda2 = mod(lambda * lambda, q);
assert(mod(lambda * lambda2, q) === 1n, "lambda is a cube root");

// compute beta such that lambda * (x, y) = (beta * x, y) (endo base)
let lambdaBits = bigintToBits(lambda, 256);
let lambdaG = scale(lambdaBits, G, p);
assert(lambdaG.y === G.y, "multiplication by lambda is a cheap endomorphism");

const beta_ = mod(lambdaG.x * modInverse(G.x, p), p);
assert(beta === beta_, "beta is correct");
assert(modExp(beta, 3n, { p }) === 1n, "beta is a cube root");

// note: since both phi1: p -> lambda*p and phi2: p -> (beta*p.x, p.y) are homomorphisms (easy to check),
// and they agree on a single point, they must agree on all points in the same subgroup:
// (phi1 - phi2)(s*G) = s*(phi1 - pgi2)(G) = 0
}
42 changes: 42 additions & 0 deletions src/concrete/bls12-377.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import type * as W from "wasmati";
import { randomGenerators } from "../field-util.js";
import { p, q, b, beta, lambda, h } from "./bls12-377.params.js";
import { createMsmField } from "../field-msm.js";
import { createGeneralGlvScalar } from "../scalar-glv.js";
import { createCurveAffine } from "../curve-affine.js";
import { createCurveProjective } from "../curve-projective.js";
import { createMsm } from "../msm.js";
import { createBigintApi } from "../bigint.js";

export { Bigint, Field, Scalar, CurveAffine, CurveProjective, Random };
export { msm, msmUnsafe, msmUtil };

const Field = await createMsmField(p, beta, 29);
const Scalar = await createGeneralGlvScalar(q, lambda, 29);
const CurveProjective = createCurveProjective(Field, h);
const CurveAffine = createCurveAffine(Field, CurveProjective, b);

const { msm, msmUnsafe, msmBigint, ...msmUtil } = createMsm({
Field,
Scalar,
CurveAffine,
CurveProjective,
});

const { randomField: randomScalar, randomFields: randomScalars } =
randomGenerators(q);

const Random = { ...randomGenerators(p), randomScalar, randomScalars };

const Bigint_ = createBigintApi({
Field,
Scalar,
CurveAffine,
CurveProjective,
});
const Bigint = {
...Bigint_,
msm: msmBigint,
randomFields: Random.randomFields,
randomScalars,
};
2 changes: 1 addition & 1 deletion src/concrete/bls12-381.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ export { msm, msmUnsafe, msmUtil };

const Field = await createMsmField(p, beta, 30);
const Scalar = await createGeneralGlvScalar(q, lambda, 29);
const CurveAffine = createCurveAffine(Field, 4n);
const CurveProjective = createCurveProjective(Field);
const CurveAffine = createCurveAffine(Field, CurveProjective, 4n);

const SpecialScalar = await createGlvScalar(q, lambda, 29);

Expand Down
2 changes: 1 addition & 1 deletion src/concrete/pasta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ export { msm, msmUnsafe, msmUtil };

const Field = await createMsmField(p, beta, 29);
const Scalar = await createGeneralGlvScalar(q, lambda, 29);
const CurveAffine = createCurveAffine(Field, b);
const CurveProjective = createCurveProjective(Field);
const CurveAffine = createCurveAffine(Field, CurveProjective, b);

const { msm, msmUnsafe, msmBigint, ...msmUtil } = createMsm({
Field,
Expand Down
Loading

0 comments on commit ae33382

Please sign in to comment.