Skip to content

Commit

Permalink
add parallel g1/g2 msm gnark-crypto impl (#217)
Browse files Browse the repository at this point in the history
* add parallel g1/g2 msm gnark-crypto impl
* add a configurable NbTasks for degree-of-parallelism for msm

Signed-off-by: garyschulte <[email protected]>
  • Loading branch information
garyschulte authored Oct 7, 2024
1 parent e788032 commit 7581d1e
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 14 deletions.
125 changes: 115 additions & 10 deletions gnark/gnark-jni/gnark-eip-2537.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"math/big"
"reflect"
"unsafe"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

const (
Expand Down Expand Up @@ -167,6 +169,54 @@ func eip2537blsG1MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn
return nonMontgomeryMarshalG1(result, javaOutputBuf, errorBuf)
}

//export eip2537blsG1MultiExpParallel
func eip2537blsG1MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int, nbTasks C.int) C.int {
inputLen := int(cInputLen)
errorLen := int(cOutputLen)

// Convert error C pointers to Go slices
errorBuf := castBuffer(javaErrorBuf, errorLen)

if inputLen == 0 {
copy(errorBuf, "invalid input parameters, invalid number of pairs\x00")
return 1
}

if inputLen % (EIP2537PreallocateForG1 + EIP2537PreallocateForScalar) != 0 {
copy(errorBuf, "invalid input parameters, invalid input length for G1 multiplication\x00")
return 1
}

// Convert input C pointers to Go slice
input := castBufferToSlice(unsafe.Pointer(javaInputBuf), inputLen)

var exprCount = inputLen / (EIP2537PreallocateForG1 + EIP2537PreallocateForScalar)

g1Points := make([]bls12381.G1Affine, exprCount)
scalars := make([]fr.Element, exprCount)

for i := 0 ; i < exprCount ; i++ {
_, err := g1AffineDecodeInSubGroupVal(&g1Points[i], input[i*160 : (i*160)+128])
if err != nil {
copy(errorBuf, err.Error())
return 1
}

scalars[i].SetBytes(input[(i*160)+128 : (i+1)*160])
}

var affineResult bls12381.G1Affine
// leave nbTasks unset, allow golang to use available cpu cores as the parallelism limit
_, err := affineResult.MultiExp(g1Points, scalars, ecc.MultiExpConfig{NbTasks: int(nbTasks)})
if err != nil {
copy(errorBuf, err.Error())
return 1
}

// marshal the resulting point and encode directly to the output buffer
return nonMontgomeryMarshalG1(&affineResult, javaOutputBuf, errorBuf)
}

//export eip2537blsG2Add
func eip2537blsG2Add(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int {
inputLen := int(cInputLen)
Expand Down Expand Up @@ -289,6 +339,58 @@ func eip2537blsG2MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn
return nonMontgomeryMarshalG2(result, javaOutputBuf, errorBuf)
}

//export eip2537blsG2MultiExpParallel
func eip2537blsG2MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int, nbTasks C.int) C.int {
inputLen := int(cInputLen)
errorLen := int(cOutputLen)

// Convert error C pointers to Go slices
errorBuf := castBuffer(javaErrorBuf, errorLen)

if inputLen == 0 {
copy(errorBuf, "invalid input parameters, invalid number of pairs\x00")
return 1
}

if inputLen % (EIP2537PreallocateForG2 + EIP2537PreallocateForScalar) != 0 {
copy(errorBuf, "invalid input parameters, invalid input length for G2 multiplication\x00")
return 1
}

// Convert input C pointers to Go slice
input := castBufferToSlice(unsafe.Pointer(javaInputBuf), inputLen)

var exprCount = inputLen / (EIP2537PreallocateForG2 + EIP2537PreallocateForScalar)

g2Points := make([]bls12381.G2Affine, exprCount)
scalars := make([]fr.Element, exprCount)

for i := 0 ; i < exprCount ; i++ {
_, err := g2AffineDecodeInSubGroupVal(&g2Points[i], input[i*288 : (i*288)+256])
if err != nil {
copy(errorBuf, err.Error())
return 1
}

scalars[i].SetBytes(input[(i*288)+256 : (i+1)*288])
}

var affineResult bls12381.G2Affine
// leave nbTasks unset, allow golang to use available cpu cores as the parallelism limit
_, err := affineResult.MultiExp(g2Points, scalars, ecc.MultiExpConfig{NbTasks: int(nbTasks)})
if err != nil {
copy(errorBuf, err.Error())
return 1
}

// marshal the resulting point and encode directly to the output buffer
return nonMontgomeryMarshalG2(&affineResult, javaOutputBuf, errorBuf)
}





//export eip2537blsPairing
func eip2537blsPairing(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int {
inputLen := int(cInputLen)
Expand Down Expand Up @@ -439,25 +541,24 @@ func hasWrongG1Padding(input []byte) bool {
func hasWrongG2Padding(input []byte) bool {
return !isZero(input[:16]) || !isZero(input[64:80] )|| !isZero(input[128:144]) || !isZero(input[192:208])
}


func g1AffineDecodeInSubGroup(input []byte) (*bls12381.G1Affine, error) {
var g1 bls12381.G1Affine
return g1AffineDecodeInSubGroupVal(&g1, input)
}

func g1AffineDecodeInSubGroupVal(g1 *bls12381.G1Affine, input []byte) (*bls12381.G1Affine, error) {
if hasWrongG1Padding(input) {
return nil, ErrMalformedPointPadding
}
var g1x, g1y fp.Element
err := g1x.SetBytesCanonical(input[16:64])
err := g1.X.SetBytesCanonical(input[16:64])
if err != nil {
return nil, err
}
err = g1y.SetBytesCanonical(input[80:128])
err = g1.Y.SetBytesCanonical(input[80:128])
if err != nil {
return nil, err
}

// construct g1affine directly rather than unmarshalling
g1 := &bls12381.G1Affine{X: g1x, Y: g1y}

// do explicit subgroup check
if (!g1.IsInSubGroup()) {
if (!g1.IsOnCurve()) {
Expand Down Expand Up @@ -493,11 +594,15 @@ func g1AffineDecodeOnCurve(input []byte) (*bls12381.G1Affine, error) {
}

func g2AffineDecodeInSubGroup(input []byte) (*bls12381.G2Affine, error) {
var g2 bls12381.G2Affine
return g2AffineDecodeInSubGroupVal(&g2, input)
}

func g2AffineDecodeInSubGroupVal(g2 *bls12381.G2Affine, input []byte) (*bls12381.G2Affine, error) {
if hasWrongG2Padding(input) {
return nil, ErrMalformedPointPadding
}

var g2 bls12381.G2Affine
err := g2.X.A0.SetBytesCanonical(input[16:64])
if err != nil {
return nil, err
Expand All @@ -522,7 +627,7 @@ func g2AffineDecodeInSubGroup(input []byte) (*bls12381.G2Affine, error) {
if (!g2.IsInSubGroup()) {
return nil, ErrSubgroupCheckFailed
}
return &g2, nil;
return g2, nil;
}

func g2AffineDecodeOnCurve(input []byte) (*bls12381.G2Affine, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public class LibGnarkEIP2537 implements Library {
@SuppressWarnings("WeakerAccess")
public static final boolean ENABLED;

// zero implies 'default' degree of parallelism, which is the number of cpu cores available
private static int degreeOfMSMParallelism = 0;

static {
boolean enabled;
try {
Expand Down Expand Up @@ -61,9 +64,10 @@ public static int eip2537_perform_operation(
o_len.setValue(128);
break;
case BLS12_G1MULTIEXP_OPERATION_SHIM_VALUE:
ret = eip2537blsG1MultiExp(i, output, err, i_len,
ret = eip2537blsG1MultiExpParallel(i, output, err, i_len,
EIP2537_PREALLOCATE_FOR_RESULT_BYTES,
EIP2537_PREALLOCATE_FOR_ERROR_BYTES);
EIP2537_PREALLOCATE_FOR_ERROR_BYTES,
degreeOfMSMParallelism);
o_len.setValue(128);
break;
case BLS12_G2ADD_OPERATION_SHIM_VALUE:
Expand All @@ -79,9 +83,10 @@ public static int eip2537_perform_operation(
o_len.setValue(256);
break;
case BLS12_G2MULTIEXP_OPERATION_SHIM_VALUE:
ret = eip2537blsG2MultiExp(i, output, err, i_len,
ret = eip2537blsG2MultiExpParallel(i, output, err, i_len,
EIP2537_PREALLOCATE_FOR_RESULT_BYTES,
EIP2537_PREALLOCATE_FOR_ERROR_BYTES);
EIP2537_PREALLOCATE_FOR_ERROR_BYTES,
degreeOfMSMParallelism);
o_len.setValue(256);
break;
case BLS12_PAIR_OPERATION_SHIM_VALUE:
Expand Down Expand Up @@ -134,6 +139,13 @@ public static native int eip2537blsG1MultiExp(
byte[] error,
int inputSize, int output_len, int err_len);

public static native int eip2537blsG1MultiExpParallel(
byte[] input,
byte[] output,
byte[] error,
int inputSize, int output_len, int err_len,
int nbTasks);

public static native int eip2537blsG2Add(
byte[] input,
byte[] output,
Expand All @@ -152,6 +164,13 @@ public static native int eip2537blsG2MultiExp(
byte[] error,
int inputSize, int output_len, int err_len);

public static native int eip2537blsG2MultiExpParallel(
byte[] input,
byte[] output,
byte[] error,
int inputSize, int output_len, int err_len,
int nbTasks);

public static native int eip2537blsPairing(
byte[] input,
byte[] output,
Expand All @@ -170,4 +189,7 @@ public static native int eip2537blsMapFp2ToG2(
byte[] error,
int inputSize, int output_len, int err_len);

public static void setDegreeOfMSMParallelism(int nbTasks) {
degreeOfMSMParallelism = nbTasks;
}
}

0 comments on commit 7581d1e

Please sign in to comment.