diff --git a/std/algebra/emulated/fields_bw6761/e6.go b/std/algebra/emulated/fields_bw6761/e6.go index 0c3a81bdc2..804a7d3546 100644 --- a/std/algebra/emulated/fields_bw6761/e6.go +++ b/std/algebra/emulated/fields_bw6761/e6.go @@ -118,6 +118,16 @@ func (e Ext6) Sub(x, y *E6) *E6 { } } +func (e Ext6) IsZero(x *E6) frontend.Variable { + isZero := e.fp.IsZero(&x.A0) + isZero = e.api.And(isZero, e.fp.IsZero(&x.A1)) + isZero = e.api.And(isZero, e.fp.IsZero(&x.A2)) + isZero = e.api.And(isZero, e.fp.IsZero(&x.A3)) + isZero = e.api.And(isZero, e.fp.IsZero(&x.A4)) + isZero = e.api.And(isZero, e.fp.IsZero(&x.A5)) + return isZero +} + func (e Ext6) Double(x *E6) *E6 { two := big.NewInt(2) a0 := e.fp.MulConst(&x.A0, two) @@ -1106,7 +1116,6 @@ func (e Ext6) AssertIsEqual(a, b *E6) { e.fp.AssertIsEqual(&a.A3, &b.A3) e.fp.AssertIsEqual(&a.A4, &b.A4) e.fp.AssertIsEqual(&a.A5, &b.A5) - } func (e Ext6) IsEqual(x, y *E6) frontend.Variable { diff --git a/std/algebra/native/fields_bls12377/e12.go b/std/algebra/native/fields_bls12377/e12.go index d04890bab8..ec977f14b1 100644 --- a/std/algebra/native/fields_bls12377/e12.go +++ b/std/algebra/native/fields_bls12377/e12.go @@ -153,6 +153,20 @@ func (e *E12) Mul(api frontend.API, e1, e2 E12) *E12 { return e } +func (e *E12) IsZero(api frontend.API) frontend.Variable { + isZero := e.C0.B0.IsZero(api) + isZero = api.And(isZero, e.C0.B1.IsZero(api)) + isZero = api.And(isZero, e.C0.B2.IsZero(api)) + isZero = api.And(isZero, e.C1.B0.IsZero(api)) + isZero = api.And(isZero, e.C1.B1.IsZero(api)) + isZero = api.And(isZero, e.C1.B2.IsZero(api)) + return isZero +} + +func (e *E12) IsEqual(api frontend.API, x, y *E12) frontend.Variable { + return e.Sub(api, *x, *y).IsZero(api) +} + // Square squares an element in Fp12 func (e *E12) Square(api frontend.API, x E12) *E12 { diff --git a/std/algebra/native/fields_bls24315/e24.go b/std/algebra/native/fields_bls24315/e24.go index 362c5536a0..10e35c13e6 100644 --- a/std/algebra/native/fields_bls24315/e24.go +++ b/std/algebra/native/fields_bls24315/e24.go @@ -152,6 +152,17 @@ func (e *E24) Mul(api frontend.API, e1, e2 E24) *E24 { return e } +func (e *E24) IsZero(api frontend.API) frontend.Variable { + isZero := e.D0.C0.IsZero(api) + isZero = api.And(isZero, e.D0.C1.IsZero(api)) + isZero = api.And(isZero, e.D0.C2.IsZero(api)) + isZero = api.And(isZero, e.D1.C0.IsZero(api)) + isZero = api.And(isZero, e.D1.C1.IsZero(api)) + isZero = api.And(isZero, e.D1.C2.IsZero(api)) + return isZero +} + + // Square squares an element in Fp24 func (e *E24) Square(api frontend.API, x E24) *E24 { diff --git a/std/algebra/native/sw_bls12377/pairing2.go b/std/algebra/native/sw_bls12377/pairing2.go index 50f2366dca..46db52b84b 100644 --- a/std/algebra/native/sw_bls12377/pairing2.go +++ b/std/algebra/native/sw_bls12377/pairing2.go @@ -106,42 +106,6 @@ func (c *Curve) AssertIsEqual(P, Q *G1Affine) { P.AssertIsEqual(c.api, *Q) } -func (c *Pairing) IsEqual(x, y *GT) frontend.Variable { - diff0 := c.api.Sub(&x.C0.B0.A0, &y.C0.B0.A0) - diff1 := c.api.Sub(&x.C0.B0.A1, &y.C0.B0.A1) - diff2 := c.api.Sub(&x.C0.B0.A0, &y.C0.B0.A0) - diff3 := c.api.Sub(&x.C0.B1.A1, &y.C0.B1.A1) - diff4 := c.api.Sub(&x.C0.B1.A0, &y.C0.B1.A0) - diff5 := c.api.Sub(&x.C0.B1.A1, &y.C0.B1.A1) - diff6 := c.api.Sub(&x.C1.B0.A0, &y.C1.B0.A0) - diff7 := c.api.Sub(&x.C1.B0.A1, &y.C1.B0.A1) - diff8 := c.api.Sub(&x.C1.B0.A0, &y.C1.B0.A0) - diff9 := c.api.Sub(&x.C1.B1.A1, &y.C1.B1.A1) - diff10 := c.api.Sub(&x.C1.B1.A0, &y.C1.B1.A0) - diff11 := c.api.Sub(&x.C1.B1.A1, &y.C1.B1.A1) - - isZero0 := c.api.IsZero(diff0) - isZero1 := c.api.IsZero(diff1) - isZero2 := c.api.IsZero(diff2) - isZero3 := c.api.IsZero(diff3) - isZero4 := c.api.IsZero(diff4) - isZero5 := c.api.IsZero(diff5) - isZero6 := c.api.IsZero(diff6) - isZero7 := c.api.IsZero(diff7) - isZero8 := c.api.IsZero(diff8) - isZero9 := c.api.IsZero(diff9) - isZero10 := c.api.IsZero(diff10) - isZero11 := c.api.IsZero(diff11) - - return c.api.And( - c.api.And( - c.api.And(c.api.And(isZero0, isZero1), c.api.And(isZero2, isZero3)), - c.api.And(c.api.And(isZero4, isZero5), c.api.And(isZero6, isZero7)), - ), - c.api.And(c.api.And(isZero8, isZero9), c.api.And(isZero10, isZero11)), - ) -} - // Neg negates P and returns the result. Does not modify P. func (c *Curve) Neg(P *G1Affine) *G1Affine { res := &G1Affine{ @@ -362,6 +326,10 @@ func (p *Pairing) AssertIsEqual(e1, e2 *GT) { e1.AssertIsEqual(p.api, *e2) } +func (pr *Pairing) IsEqual(e1, e2 *GT) frontend.Variable { + return e1.IsEqual(pr.api, e1, e2) +} + func (pr Pairing) MuxG2(sel frontend.Variable, inputs ...*G2Affine) *G2Affine { if len(inputs) == 0 { return nil diff --git a/std/algebra/native/sw_bls24315/pairing2.go b/std/algebra/native/sw_bls24315/pairing2.go index bdc3d21fa5..926bdd5790 100644 --- a/std/algebra/native/sw_bls24315/pairing2.go +++ b/std/algebra/native/sw_bls24315/pairing2.go @@ -255,75 +255,6 @@ func NewPairing(api frontend.API) *Pairing { } } -func (c *Pairing) IsEqual(x, y *GT) frontend.Variable { - diff0 := c.api.Sub(&x.D0.C0.B0.A0, &y.D0.C0.B0.A0) - diff1 := c.api.Sub(&x.D0.C0.B0.A1, &y.D0.C0.B0.A1) - diff2 := c.api.Sub(&x.D0.C0.B0.A0, &y.D0.C0.B0.A0) - diff3 := c.api.Sub(&x.D0.C0.B1.A1, &y.D0.C0.B1.A1) - diff4 := c.api.Sub(&x.D0.C0.B1.A0, &y.D0.C0.B1.A0) - diff5 := c.api.Sub(&x.D0.C0.B1.A1, &y.D0.C0.B1.A1) - diff6 := c.api.Sub(&x.D0.C1.B0.A0, &y.D0.C1.B0.A0) - diff7 := c.api.Sub(&x.D0.C1.B0.A1, &y.D0.C1.B0.A1) - diff8 := c.api.Sub(&x.D0.C1.B0.A0, &y.D0.C1.B0.A0) - diff9 := c.api.Sub(&x.D0.C1.B1.A1, &y.D0.C1.B1.A1) - diff10 := c.api.Sub(&x.D0.C1.B1.A0, &y.D0.C1.B1.A0) - diff11 := c.api.Sub(&x.D0.C1.B1.A1, &y.D0.C1.B1.A1) - diff12 := c.api.Sub(&x.D1.C0.B0.A0, &y.D1.C0.B0.A0) - diff13 := c.api.Sub(&x.D1.C0.B0.A1, &y.D1.C0.B0.A1) - diff14 := c.api.Sub(&x.D1.C0.B0.A0, &y.D1.C0.B0.A0) - diff15 := c.api.Sub(&x.D1.C0.B1.A1, &y.D1.C0.B1.A1) - diff16 := c.api.Sub(&x.D1.C0.B1.A0, &y.D1.C0.B1.A0) - diff17 := c.api.Sub(&x.D1.C0.B1.A1, &y.D1.C0.B1.A1) - diff18 := c.api.Sub(&x.D1.C1.B0.A0, &y.D1.C1.B0.A0) - diff19 := c.api.Sub(&x.D1.C1.B0.A1, &y.D1.C1.B0.A1) - diff20 := c.api.Sub(&x.D1.C1.B0.A0, &y.D1.C1.B0.A0) - diff21 := c.api.Sub(&x.D1.C1.B1.A1, &y.D1.C1.B1.A1) - diff22 := c.api.Sub(&x.D1.C1.B1.A0, &y.D1.C1.B1.A0) - diff23 := c.api.Sub(&x.D1.C1.B1.A1, &y.D1.C1.B1.A1) - - isZero0 := c.api.IsZero(diff0) - isZero1 := c.api.IsZero(diff1) - isZero2 := c.api.IsZero(diff2) - isZero3 := c.api.IsZero(diff3) - isZero4 := c.api.IsZero(diff4) - isZero5 := c.api.IsZero(diff5) - isZero6 := c.api.IsZero(diff6) - isZero7 := c.api.IsZero(diff7) - isZero8 := c.api.IsZero(diff8) - isZero9 := c.api.IsZero(diff9) - isZero10 := c.api.IsZero(diff10) - isZero11 := c.api.IsZero(diff11) - isZero12 := c.api.IsZero(diff12) - isZero13 := c.api.IsZero(diff13) - isZero14 := c.api.IsZero(diff14) - isZero15 := c.api.IsZero(diff15) - isZero16 := c.api.IsZero(diff16) - isZero17 := c.api.IsZero(diff17) - isZero18 := c.api.IsZero(diff18) - isZero19 := c.api.IsZero(diff19) - isZero20 := c.api.IsZero(diff20) - isZero21 := c.api.IsZero(diff21) - isZero22 := c.api.IsZero(diff22) - isZero23 := c.api.IsZero(diff23) - - return c.api.And( - c.api.And( - c.api.And( - c.api.And(c.api.And(isZero0, isZero1), c.api.And(isZero2, isZero3)), - c.api.And(c.api.And(isZero4, isZero5), c.api.And(isZero6, isZero7)), - ), - c.api.And( - c.api.And(c.api.And(isZero8, isZero9), c.api.And(isZero10, isZero11)), - c.api.And(c.api.And(isZero12, isZero13), c.api.And(isZero14, isZero15)), - ), - ), - c.api.And( - c.api.And(c.api.And(isZero16, isZero17), c.api.And(isZero18, isZero19)), - c.api.And(c.api.And(isZero20, isZero21), c.api.And(isZero22, isZero23)), - ), - ) -} - // MillerLoop computes the Miller loop between the pairs of inputs. It doesn't // modify the inputs. It returns an error if there is a mismatch between the // lengths of the inputs. @@ -570,6 +501,10 @@ func (pr Pairing) MuxGt(sel frontend.Variable, inputs ...*GT) *GT { return &ret } +func (pr *Pairing) IsEqual(e1, e2 *GT) frontend.Variable { + return e1.Sub(pr.api, *e1, *e2).IsZero(pr.api) +} + func (p *Pairing) AssertIsOnG1(P *G1Affine) { panic("not implemented") } diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index 9349e276e5..ee53c05cb1 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -41,6 +41,11 @@ func (c *curve) Neg(p1 Point) Point { p.neg(c.api, &p1) return p } + +func (c *curve) IsOnCurve(p1 Point) frontend.Variable { + return p1.isOnCurve(c.api, c.params) +} + func (c *curve) AssertIsOnCurve(p1 Point) { p1.assertIsOnCurve(c.api, c.params) } diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index fde04e192c..f5f8850f04 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -15,6 +15,13 @@ func (p *Point) neg(api frontend.API, p1 *Point) *Point { // assertIsOnCurve checks if a point is on the reduced twisted Edwards curve // a*x² + y² = 1 + d*x²*y². func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) { + flag := p.isOnCurve(api, curve) + api.AssertIsEqual(flag, 1) +} + +// isOnCurve returns 1 if a point is on the reduced twisted Edwards curve +// a*x² + y² = 1 + d*x²*y², 0 otherwise. +func (p *Point) isOnCurve(api frontend.API, curve *CurveParams) frontend.Variable { xx := api.Mul(p.X, p.X) yy := api.Mul(p.Y, p.Y) @@ -25,8 +32,7 @@ func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) { dxxyy := api.Mul(dxx, yy) rhs := api.Add(dxxyy, 1) - api.AssertIsEqual(lhs, rhs) - + return api.IsZero(api.Sub(lhs, rhs)) } // add Adds two points on a twisted edwards curve (eg jubjub) diff --git a/std/algebra/native/twistededwards/twistededwards.go b/std/algebra/native/twistededwards/twistededwards.go index f50ae666d4..268355c495 100644 --- a/std/algebra/native/twistededwards/twistededwards.go +++ b/std/algebra/native/twistededwards/twistededwards.go @@ -27,6 +27,7 @@ type Curve interface { Add(p1, p2 Point) Point Double(p1 Point) Point Neg(p1 Point) Point + IsOnCurve(p1 Point) frontend.Variable AssertIsOnCurve(p1 Point) ScalarMul(p1 Point, scalar frontend.Variable) Point DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point diff --git a/std/recursion/groth16/verifier.go b/std/recursion/groth16/verifier.go index 1859d17b69..95de8fbd40 100644 --- a/std/recursion/groth16/verifier.go +++ b/std/recursion/groth16/verifier.go @@ -607,16 +607,27 @@ func NewVerifier[FR emulated.FieldParams, G1El algebra.G1ElementT, G2El algebra. // AssertProof asserts that the SNARK proof holds for the given witness and // verifying key. func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El, GtEl], proof Proof[G1El, G2El], witness Witness[FR], opts ...VerifierOption) error { + flag, err := v.ProofIsValid(vk, proof, witness, opts...) + if err != nil { + return err + } + v.api.AssertIsEqual(flag, 1) + return nil +} + +// ProofIsValid returns 1 if the SNARK proof holds for the given witness and +// verifying key, and 0 otherwise. +func (v *Verifier[FR, G1El, G2El, GtEl]) ProofIsValid(vk VerifyingKey[G1El, G2El, GtEl], proof Proof[G1El, G2El], witness Witness[FR], opts ...VerifierOption) (frontend.Variable, error) { if len(vk.CommitmentKeys) != len(proof.Commitments) { - return fmt.Errorf("invalid number of commitments, got %d, expected %d", len(proof.Commitments), len(vk.CommitmentKeys)) + return 0, fmt.Errorf("invalid number of commitments, got %d, expected %d", len(proof.Commitments), len(vk.CommitmentKeys)) } if len(vk.CommitmentKeys) != len(vk.PublicAndCommitmentCommitted) { - return fmt.Errorf("invalid number of commitment keys, got %d, expected %d", len(vk.CommitmentKeys), len(vk.PublicAndCommitmentCommitted)) + return 0, fmt.Errorf("invalid number of commitment keys, got %d, expected %d", len(vk.CommitmentKeys), len(vk.PublicAndCommitmentCommitted)) } var fr FR nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) if len(witness.Public) != nbPublicVars-1 { - return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(witness.Public), len(vk.G1.K)-1) + return 0, fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(witness.Public), len(vk.G1.K)-1) } inP := make([]*G1El, len(vk.G1.K)-1) // first is for the one wire, we add it manually after MSM @@ -630,11 +641,11 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El, opt, err := newCfg(opts...) if err != nil { - return fmt.Errorf("apply options: %w", err) + return 0, fmt.Errorf("apply options: %w", err) } hashToField, err := recursion.NewHash(v.api, fr.Modulus(), true) if err != nil { - return fmt.Errorf("hash to field: %w", err) + return 0, fmt.Errorf("hash to field: %w", err) } maxNbPublicCommitted := 0 @@ -663,16 +674,16 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El, // explicitly do not verify the commitment as there is nothing case 1: if err = v.commitment.AssertCommitment(proof.Commitments[0], proof.CommitmentPok, vk.CommitmentKeys[0], opt.pedopt...); err != nil { - return fmt.Errorf("assert commitment: %w", err) + return 0, fmt.Errorf("assert commitment: %w", err) } default: // TODO: we support only a single commitment in the recursion for now - return fmt.Errorf("multiple commitments are not supported") + return 0, fmt.Errorf("multiple commitments are not supported") } kSum, err := v.curve.MultiScalarMul(inP, inS, opt.algopt...) if err != nil { - return fmt.Errorf("multi scalar mul: %w", err) + return 0, fmt.Errorf("multi scalar mul: %w", err) } kSum = v.curve.Add(kSum, &vk.G1.K[0]) @@ -687,10 +698,9 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El, } pairing, err := v.pairing.Pair([]*G1El{kSum, &proof.Krs, &proof.Ar}, []*G2El{&vk.G2.GammaNeg, &vk.G2.DeltaNeg, &proof.Bs}) if err != nil { - return fmt.Errorf("pairing: %w", err) + return 0, fmt.Errorf("pairing: %w", err) } - v.pairing.AssertIsEqual(pairing, &vk.E) - return nil + return v.pairing.IsEqual(pairing, &vk.E), nil } // SwitchVerification key switches the verification key based on the provided diff --git a/std/signature/ecdsa/ecdsa.go b/std/signature/ecdsa/ecdsa.go index 201a802323..31ccb1a4da 100644 --- a/std/signature/ecdsa/ecdsa.go +++ b/std/signature/ecdsa/ecdsa.go @@ -19,6 +19,16 @@ type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base] // // We assume that the message msg is already hashed to the scalar field. func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { + flag := pk.SignIsValid(api, params, msg, sig) + api.AssertIsEqual(flag, 1) +} + +// SignIsValid returns 1 if the signature sig verifies for the message msg and +// public key pk or 0 if not. The curve parameters params define the elliptic +// curve. +// +// We assume that the message msg is already hashed to the scalar field. +func (pk PublicKey[T, S]) SignIsValid(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) frontend.Variable { cr, err := sw_emulated.New[T, S](api, params) if err != nil { panic(err) @@ -43,7 +53,14 @@ func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParam if len(rbits) != len(qxBits) { panic("non-equal lengths") } + // store 1 to expect equality + res := frontend.Variable(1) for i := range rbits { - api.AssertIsEqual(rbits[i], qxBits[i]) + // calc the difference between the bits + diff := api.Sub(rbits[i], qxBits[i]) + // update the result with the AND of the previous result and the + // equality between the bits (diff == 0) + res = api.And(res, api.IsZero(diff)) } + return res } diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index e4a3a41f77..5c49cb9ac6 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -36,10 +36,22 @@ type Signature struct { S frontend.Variable } -// Verify verifies an eddsa signature using MiMC hash function +// Verify checks that an eddsa signature verifies for the message msg and +// public key pk provided using MiMC hash function. // cf https://en.wikipedia.org/wiki/EdDSA func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.FieldHasher) error { + isValid, err := SignIsValid(curve, sig, msg, pubKey, hash) + if err != nil { + return err + } + curve.API().AssertIsEqual(isValid, 1) + return nil +} +// SignIsValid returns 1 if the signature sig verifies an eddsa signature +// using MiMC hash function for the message msg and public key pk or 0 if not. +// cf https://en.wikipedia.org/wiki/EdDSA +func SignIsValid(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.FieldHasher) (frontend.Variable, error) { // compute H(R, A, M) hash.Write(sig.R.X) hash.Write(sig.R.Y) @@ -56,7 +68,9 @@ func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pu //[S]G-[H(R,A,M)]*A _A := curve.Neg(pubKey.A) Q := curve.DoubleBaseScalarMul(base, _A, sig.S, hRAM) - curve.AssertIsOnCurve(Q) + // check if Q is on the curve, if not multiply by 0 + isOnCurve := curve.IsOnCurve(Q) + Q = curve.ScalarMul(Q, isOnCurve) //[S]G-[H(R,A,M)]*A-R Q = curve.Add(curve.Neg(Q), sig.R) @@ -66,7 +80,7 @@ func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pu if !curve.Params().Cofactor.IsUint64() { err := errors.New("invalid cofactor") log.Err(err).Str("cofactor", curve.Params().Cofactor.String()).Send() - return err + return nil, err } cofactor := curve.Params().Cofactor.Uint64() switch cofactor { @@ -78,10 +92,10 @@ func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pu log.Warn().Str("cofactor", curve.Params().Cofactor.String()).Msg("curve cofactor is not implemented") } - curve.API().AssertIsEqual(Q.X, 0) - curve.API().AssertIsEqual(Q.Y, 1) - - return nil + zeroX := curve.API().IsZero(Q.X) + oneY := curve.API().IsZero(curve.API().Sub(Q.Y, 1)) + expectedPoint := curve.API().And(zeroX, oneY) + return curve.API().And(isOnCurve, expectedPoint), nil } // Assign is a helper to assigned a compressed binary public key representation into its uncompressed form