Skip to content

Implement Flag-Based Verification for Recursive SNARK Proofs & In-Circuit Signatures #1432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion std/algebra/emulated/fields_bw6761/e6.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ func (e Ext6) Sub(x, y *E6) *E6 {
}
}

func (e Ext6) IsZero(x *E6) frontend.Variable {
isZero := e.fp.IsZero(&x.A0)
isZero = e.api.And(isZero, e.fp.IsZero(&x.A1))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A2))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A3))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A4))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A5))
return isZero
}

func (e Ext6) Double(x *E6) *E6 {
two := big.NewInt(2)
a0 := e.fp.MulConst(&x.A0, two)
Expand Down Expand Up @@ -1106,7 +1116,6 @@ func (e Ext6) AssertIsEqual(a, b *E6) {
e.fp.AssertIsEqual(&a.A3, &b.A3)
e.fp.AssertIsEqual(&a.A4, &b.A4)
e.fp.AssertIsEqual(&a.A5, &b.A5)

}

func (e Ext6) IsEqual(x, y *E6) frontend.Variable {
Expand Down
14 changes: 14 additions & 0 deletions std/algebra/native/fields_bls12377/e12.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ func (e *E12) Mul(api frontend.API, e1, e2 E12) *E12 {
return e
}

func (e *E12) IsZero(api frontend.API) frontend.Variable {
isZero := e.C0.B0.IsZero(api)
isZero = api.And(isZero, e.C0.B1.IsZero(api))
isZero = api.And(isZero, e.C0.B2.IsZero(api))
isZero = api.And(isZero, e.C1.B0.IsZero(api))
isZero = api.And(isZero, e.C1.B1.IsZero(api))
isZero = api.And(isZero, e.C1.B2.IsZero(api))
return isZero
}

func (e *E12) IsEqual(api frontend.API, x, y *E12) frontend.Variable {
return e.Sub(api, *x, *y).IsZero(api)
}

// Square squares an element in Fp12
func (e *E12) Square(api frontend.API, x E12) *E12 {

Expand Down
11 changes: 11 additions & 0 deletions std/algebra/native/fields_bls24315/e24.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,17 @@ func (e *E24) Mul(api frontend.API, e1, e2 E24) *E24 {
return e
}

func (e *E24) IsZero(api frontend.API) frontend.Variable {
isZero := e.D0.C0.IsZero(api)
isZero = api.And(isZero, e.D0.C1.IsZero(api))
isZero = api.And(isZero, e.D0.C2.IsZero(api))
isZero = api.And(isZero, e.D1.C0.IsZero(api))
isZero = api.And(isZero, e.D1.C1.IsZero(api))
isZero = api.And(isZero, e.D1.C2.IsZero(api))
return isZero
}


// Square squares an element in Fp24
func (e *E24) Square(api frontend.API, x E24) *E24 {

Expand Down
40 changes: 4 additions & 36 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,42 +106,6 @@ func (c *Curve) AssertIsEqual(P, Q *G1Affine) {
P.AssertIsEqual(c.api, *Q)
}

func (c *Pairing) IsEqual(x, y *GT) frontend.Variable {
diff0 := c.api.Sub(&x.C0.B0.A0, &y.C0.B0.A0)
diff1 := c.api.Sub(&x.C0.B0.A1, &y.C0.B0.A1)
diff2 := c.api.Sub(&x.C0.B0.A0, &y.C0.B0.A0)
diff3 := c.api.Sub(&x.C0.B1.A1, &y.C0.B1.A1)
diff4 := c.api.Sub(&x.C0.B1.A0, &y.C0.B1.A0)
diff5 := c.api.Sub(&x.C0.B1.A1, &y.C0.B1.A1)
diff6 := c.api.Sub(&x.C1.B0.A0, &y.C1.B0.A0)
diff7 := c.api.Sub(&x.C1.B0.A1, &y.C1.B0.A1)
diff8 := c.api.Sub(&x.C1.B0.A0, &y.C1.B0.A0)
diff9 := c.api.Sub(&x.C1.B1.A1, &y.C1.B1.A1)
diff10 := c.api.Sub(&x.C1.B1.A0, &y.C1.B1.A0)
diff11 := c.api.Sub(&x.C1.B1.A1, &y.C1.B1.A1)

isZero0 := c.api.IsZero(diff0)
isZero1 := c.api.IsZero(diff1)
isZero2 := c.api.IsZero(diff2)
isZero3 := c.api.IsZero(diff3)
isZero4 := c.api.IsZero(diff4)
isZero5 := c.api.IsZero(diff5)
isZero6 := c.api.IsZero(diff6)
isZero7 := c.api.IsZero(diff7)
isZero8 := c.api.IsZero(diff8)
isZero9 := c.api.IsZero(diff9)
isZero10 := c.api.IsZero(diff10)
isZero11 := c.api.IsZero(diff11)

return c.api.And(
c.api.And(
c.api.And(c.api.And(isZero0, isZero1), c.api.And(isZero2, isZero3)),
c.api.And(c.api.And(isZero4, isZero5), c.api.And(isZero6, isZero7)),
),
c.api.And(c.api.And(isZero8, isZero9), c.api.And(isZero10, isZero11)),
)
}

// Neg negates P and returns the result. Does not modify P.
func (c *Curve) Neg(P *G1Affine) *G1Affine {
res := &G1Affine{
Expand Down Expand Up @@ -362,6 +326,10 @@ func (p *Pairing) AssertIsEqual(e1, e2 *GT) {
e1.AssertIsEqual(p.api, *e2)
}

func (pr *Pairing) IsEqual(e1, e2 *GT) frontend.Variable {
return e1.IsEqual(pr.api, e1, e2)
}

func (pr Pairing) MuxG2(sel frontend.Variable, inputs ...*G2Affine) *G2Affine {
if len(inputs) == 0 {
return nil
Expand Down
73 changes: 4 additions & 69 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,75 +255,6 @@ func NewPairing(api frontend.API) *Pairing {
}
}

func (c *Pairing) IsEqual(x, y *GT) frontend.Variable {
diff0 := c.api.Sub(&x.D0.C0.B0.A0, &y.D0.C0.B0.A0)
diff1 := c.api.Sub(&x.D0.C0.B0.A1, &y.D0.C0.B0.A1)
diff2 := c.api.Sub(&x.D0.C0.B0.A0, &y.D0.C0.B0.A0)
diff3 := c.api.Sub(&x.D0.C0.B1.A1, &y.D0.C0.B1.A1)
diff4 := c.api.Sub(&x.D0.C0.B1.A0, &y.D0.C0.B1.A0)
diff5 := c.api.Sub(&x.D0.C0.B1.A1, &y.D0.C0.B1.A1)
diff6 := c.api.Sub(&x.D0.C1.B0.A0, &y.D0.C1.B0.A0)
diff7 := c.api.Sub(&x.D0.C1.B0.A1, &y.D0.C1.B0.A1)
diff8 := c.api.Sub(&x.D0.C1.B0.A0, &y.D0.C1.B0.A0)
diff9 := c.api.Sub(&x.D0.C1.B1.A1, &y.D0.C1.B1.A1)
diff10 := c.api.Sub(&x.D0.C1.B1.A0, &y.D0.C1.B1.A0)
diff11 := c.api.Sub(&x.D0.C1.B1.A1, &y.D0.C1.B1.A1)
diff12 := c.api.Sub(&x.D1.C0.B0.A0, &y.D1.C0.B0.A0)
diff13 := c.api.Sub(&x.D1.C0.B0.A1, &y.D1.C0.B0.A1)
diff14 := c.api.Sub(&x.D1.C0.B0.A0, &y.D1.C0.B0.A0)
diff15 := c.api.Sub(&x.D1.C0.B1.A1, &y.D1.C0.B1.A1)
diff16 := c.api.Sub(&x.D1.C0.B1.A0, &y.D1.C0.B1.A0)
diff17 := c.api.Sub(&x.D1.C0.B1.A1, &y.D1.C0.B1.A1)
diff18 := c.api.Sub(&x.D1.C1.B0.A0, &y.D1.C1.B0.A0)
diff19 := c.api.Sub(&x.D1.C1.B0.A1, &y.D1.C1.B0.A1)
diff20 := c.api.Sub(&x.D1.C1.B0.A0, &y.D1.C1.B0.A0)
diff21 := c.api.Sub(&x.D1.C1.B1.A1, &y.D1.C1.B1.A1)
diff22 := c.api.Sub(&x.D1.C1.B1.A0, &y.D1.C1.B1.A0)
diff23 := c.api.Sub(&x.D1.C1.B1.A1, &y.D1.C1.B1.A1)

isZero0 := c.api.IsZero(diff0)
isZero1 := c.api.IsZero(diff1)
isZero2 := c.api.IsZero(diff2)
isZero3 := c.api.IsZero(diff3)
isZero4 := c.api.IsZero(diff4)
isZero5 := c.api.IsZero(diff5)
isZero6 := c.api.IsZero(diff6)
isZero7 := c.api.IsZero(diff7)
isZero8 := c.api.IsZero(diff8)
isZero9 := c.api.IsZero(diff9)
isZero10 := c.api.IsZero(diff10)
isZero11 := c.api.IsZero(diff11)
isZero12 := c.api.IsZero(diff12)
isZero13 := c.api.IsZero(diff13)
isZero14 := c.api.IsZero(diff14)
isZero15 := c.api.IsZero(diff15)
isZero16 := c.api.IsZero(diff16)
isZero17 := c.api.IsZero(diff17)
isZero18 := c.api.IsZero(diff18)
isZero19 := c.api.IsZero(diff19)
isZero20 := c.api.IsZero(diff20)
isZero21 := c.api.IsZero(diff21)
isZero22 := c.api.IsZero(diff22)
isZero23 := c.api.IsZero(diff23)

return c.api.And(
c.api.And(
c.api.And(
c.api.And(c.api.And(isZero0, isZero1), c.api.And(isZero2, isZero3)),
c.api.And(c.api.And(isZero4, isZero5), c.api.And(isZero6, isZero7)),
),
c.api.And(
c.api.And(c.api.And(isZero8, isZero9), c.api.And(isZero10, isZero11)),
c.api.And(c.api.And(isZero12, isZero13), c.api.And(isZero14, isZero15)),
),
),
c.api.And(
c.api.And(c.api.And(isZero16, isZero17), c.api.And(isZero18, isZero19)),
c.api.And(c.api.And(isZero20, isZero21), c.api.And(isZero22, isZero23)),
),
)
}

// MillerLoop computes the Miller loop between the pairs of inputs. It doesn't
// modify the inputs. It returns an error if there is a mismatch between the
// lengths of the inputs.
Expand Down Expand Up @@ -570,6 +501,10 @@ func (pr Pairing) MuxGt(sel frontend.Variable, inputs ...*GT) *GT {
return &ret
}

func (pr *Pairing) IsEqual(e1, e2 *GT) frontend.Variable {
return e1.Sub(pr.api, *e1, *e2).IsZero(pr.api)
}

func (p *Pairing) AssertIsOnG1(P *G1Affine) {
panic("not implemented")
}
Expand Down
5 changes: 5 additions & 0 deletions std/algebra/native/twistededwards/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func (c *curve) Neg(p1 Point) Point {
p.neg(c.api, &p1)
return p
}

func (c *curve) IsOnCurve(p1 Point) frontend.Variable {
return p1.isOnCurve(c.api, c.params)
}

func (c *curve) AssertIsOnCurve(p1 Point) {
p1.assertIsOnCurve(c.api, c.params)
}
Expand Down
10 changes: 8 additions & 2 deletions std/algebra/native/twistededwards/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ func (p *Point) neg(api frontend.API, p1 *Point) *Point {
// assertIsOnCurve checks if a point is on the reduced twisted Edwards curve
// a*x² + y² = 1 + d*x²*y².
func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) {
flag := p.isOnCurve(api, curve)
api.AssertIsEqual(flag, 1)
}

// isOnCurve returns 1 if a point is on the reduced twisted Edwards curve
// a*x² + y² = 1 + d*x²*y², 0 otherwise.
func (p *Point) isOnCurve(api frontend.API, curve *CurveParams) frontend.Variable {

xx := api.Mul(p.X, p.X)
yy := api.Mul(p.Y, p.Y)
Expand All @@ -25,8 +32,7 @@ func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) {
dxxyy := api.Mul(dxx, yy)
rhs := api.Add(dxxyy, 1)

api.AssertIsEqual(lhs, rhs)

return api.IsZero(api.Sub(lhs, rhs))
}

// add Adds two points on a twisted edwards curve (eg jubjub)
Expand Down
1 change: 1 addition & 0 deletions std/algebra/native/twistededwards/twistededwards.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Curve interface {
Add(p1, p2 Point) Point
Double(p1 Point) Point
Neg(p1 Point) Point
IsOnCurve(p1 Point) frontend.Variable
AssertIsOnCurve(p1 Point)
ScalarMul(p1 Point, scalar frontend.Variable) Point
DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point
Expand Down
32 changes: 21 additions & 11 deletions std/recursion/groth16/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,16 +607,27 @@ func NewVerifier[FR emulated.FieldParams, G1El algebra.G1ElementT, G2El algebra.
// AssertProof asserts that the SNARK proof holds for the given witness and
// verifying key.
func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El, GtEl], proof Proof[G1El, G2El], witness Witness[FR], opts ...VerifierOption) error {
flag, err := v.ProofIsValid(vk, proof, witness, opts...)
if err != nil {
return err
}
v.api.AssertIsEqual(flag, 1)
return nil
}

// ProofIsValid returns 1 if the SNARK proof holds for the given witness and
// verifying key, and 0 otherwise.
func (v *Verifier[FR, G1El, G2El, GtEl]) ProofIsValid(vk VerifyingKey[G1El, G2El, GtEl], proof Proof[G1El, G2El], witness Witness[FR], opts ...VerifierOption) (frontend.Variable, error) {
if len(vk.CommitmentKeys) != len(proof.Commitments) {
return fmt.Errorf("invalid number of commitments, got %d, expected %d", len(proof.Commitments), len(vk.CommitmentKeys))
return 0, fmt.Errorf("invalid number of commitments, got %d, expected %d", len(proof.Commitments), len(vk.CommitmentKeys))
}
if len(vk.CommitmentKeys) != len(vk.PublicAndCommitmentCommitted) {
return fmt.Errorf("invalid number of commitment keys, got %d, expected %d", len(vk.CommitmentKeys), len(vk.PublicAndCommitmentCommitted))
return 0, fmt.Errorf("invalid number of commitment keys, got %d, expected %d", len(vk.CommitmentKeys), len(vk.PublicAndCommitmentCommitted))
}
var fr FR
nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted)
if len(witness.Public) != nbPublicVars-1 {
return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(witness.Public), len(vk.G1.K)-1)
return 0, fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(witness.Public), len(vk.G1.K)-1)
}

inP := make([]*G1El, len(vk.G1.K)-1) // first is for the one wire, we add it manually after MSM
Expand All @@ -630,11 +641,11 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El,

opt, err := newCfg(opts...)
if err != nil {
return fmt.Errorf("apply options: %w", err)
return 0, fmt.Errorf("apply options: %w", err)
}
hashToField, err := recursion.NewHash(v.api, fr.Modulus(), true)
if err != nil {
return fmt.Errorf("hash to field: %w", err)
return 0, fmt.Errorf("hash to field: %w", err)
}

maxNbPublicCommitted := 0
Expand Down Expand Up @@ -663,16 +674,16 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El,
// explicitly do not verify the commitment as there is nothing
case 1:
if err = v.commitment.AssertCommitment(proof.Commitments[0], proof.CommitmentPok, vk.CommitmentKeys[0], opt.pedopt...); err != nil {
return fmt.Errorf("assert commitment: %w", err)
return 0, fmt.Errorf("assert commitment: %w", err)
}
default:
// TODO: we support only a single commitment in the recursion for now
return fmt.Errorf("multiple commitments are not supported")
return 0, fmt.Errorf("multiple commitments are not supported")
}

kSum, err := v.curve.MultiScalarMul(inP, inS, opt.algopt...)
if err != nil {
return fmt.Errorf("multi scalar mul: %w", err)
return 0, fmt.Errorf("multi scalar mul: %w", err)
}
kSum = v.curve.Add(kSum, &vk.G1.K[0])

Expand All @@ -687,10 +698,9 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El,
}
pairing, err := v.pairing.Pair([]*G1El{kSum, &proof.Krs, &proof.Ar}, []*G2El{&vk.G2.GammaNeg, &vk.G2.DeltaNeg, &proof.Bs})
if err != nil {
return fmt.Errorf("pairing: %w", err)
return 0, fmt.Errorf("pairing: %w", err)
}
v.pairing.AssertIsEqual(pairing, &vk.E)
return nil
return v.pairing.IsEqual(pairing, &vk.E), nil
}

// SwitchVerification key switches the verification key based on the provided
Expand Down
19 changes: 18 additions & 1 deletion std/signature/ecdsa/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base]
//
// We assume that the message msg is already hashed to the scalar field.
func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) {
flag := pk.SignIsValid(api, params, msg, sig)
api.AssertIsEqual(flag, 1)
}

// SignIsValid returns 1 if the signature sig verifies for the message msg and
// public key pk or 0 if not. The curve parameters params define the elliptic
// curve.
//
// We assume that the message msg is already hashed to the scalar field.
func (pk PublicKey[T, S]) SignIsValid(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) frontend.Variable {
cr, err := sw_emulated.New[T, S](api, params)
if err != nil {
panic(err)
Expand All @@ -43,7 +53,14 @@ func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParam
if len(rbits) != len(qxBits) {
panic("non-equal lengths")
}
// store 1 to expect equality
res := frontend.Variable(1)
for i := range rbits {
api.AssertIsEqual(rbits[i], qxBits[i])
// calc the difference between the bits
diff := api.Sub(rbits[i], qxBits[i])
// update the result with the AND of the previous result and the
// equality between the bits (diff == 0)
res = api.And(res, api.IsZero(diff))
}
return res
}
Loading