From 5e00636d28861cfb8c952b48efdebdb95ebd093b Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Thu, 30 Jan 2025 18:45:25 -0600 Subject: [PATCH 01/14] feat: merkle damgard and poseidon2 --- std/hash/hash.go | 42 ++++++++++++++++++++++++-- std/hash/poseidon2/posiedon2.go | 33 ++++++++++++++++++++ std/permutation/poseidon2/poseidon2.go | 9 ++++++ 3 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 std/hash/poseidon2/posiedon2.go diff --git a/std/hash/hash.go b/std/hash/hash.go index 80a1e5456..8b3c52d2b 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -5,7 +5,7 @@ package hash import ( - "errors" + "fmt" "sync" "github.com/consensys/gnark/frontend" @@ -58,7 +58,7 @@ func GetFieldHasher(name string, api frontend.API) (FieldHasher, error) { defer lock.RUnlock() builder, ok := builderRegistry[name] if !ok { - return nil, errors.New("hash function not found") + return nil, fmt.Errorf("hash function \"%s\" not registered", name) } return builder(api) } @@ -87,3 +87,41 @@ type BinaryFixedLengthHasher interface { // FixedLengthSum returns digest of the first length bytes. FixedLengthSum(length frontend.Variable) []uints.U8 } + +// CompressionFunction is a 2 to 1 function +type CompressionFunction interface { + Apply(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable // TODO @Tabaie @ThomasPiellard better name +} + +// Merkle-Damgard is a generic transformation that turns +type merkleDamgardHasher struct { + state frontend.Variable + iv frontend.Variable + f CompressionFunction + api frontend.API +} + +// NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash +// initialState is a value whose preimage is not known +func NewMerkleDamgardHasher(api frontend.API, f CompressionFunction, initialState frontend.Variable) FieldHasher { + return &merkleDamgardHasher{ + state: initialState, + iv: initialState, + f: f, + api: api, + } +} + +func (h *merkleDamgardHasher) Reset() { + h.state = h.iv +} + +func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { + for _, d := range data { + h.state = h.f.Apply(h.api, h.state, d) + } +} + +func (h *merkleDamgardHasher) Sum() frontend.Variable { + return h.state +} diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/posiedon2.go new file mode 100644 index 000000000..2712ea0fe --- /dev/null +++ b/std/hash/poseidon2/posiedon2.go @@ -0,0 +1,33 @@ +package poseidon2 + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/std/hash" + poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" +) + +func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { + curve := utils.FieldToCurve(api.Compiler().Field()) + params, ok := parameters[curve] + if !ok { + return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String()) + } + seed := fmt.Sprintf("Poseidon2 hash: curve=%s-rF=%d-rP=%d-t=2", curve.String(), params.rF, params.rP) + f := poseidon2.NewHash(2, params.d, params.rF, params.rP, seed, curve) + return hash.NewMerkleDamgardHasher(api, &f, 0), nil +} + +var parameters = map[ecc.ID]struct { + d int + rF int + rP int +}{ + ecc.BLS12_377: { + rF: 6, + rP: 26, + d: 17, + }, +} diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 00fb9306f..8f22d345f 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -282,3 +282,12 @@ func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error { return nil } + +// Apply aliases Permutation in the t=2 case +// implements hash.CompressionFunction +func (h *Hash) Apply(api frontend.API, l, r frontend.Variable) frontend.Variable { + if h.params.t != 2 { + panic("poseidon2: Apply can only be used when t=2") + } + return h.Permutation(api, []frontend.Variable{l, r}) +} From 63038e68e4b902c2f329d53c96ca9b4151037c00 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Fri, 31 Jan 2025 10:55:59 -0600 Subject: [PATCH 02/14] fix bad use of h.Permutation, and test --- std/hash/poseidon2/poseidon2_test.go | 19 ++++++++++ std/permutation/poseidon2/poseidon2.go | 6 ++- test/quick.go | 51 ++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 std/hash/poseidon2/poseidon2_test.go create mode 100644 test/quick.go diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go new file mode 100644 index 000000000..055650afe --- /dev/null +++ b/std/hash/poseidon2/poseidon2_test.go @@ -0,0 +1,19 @@ +package poseidon2 + +import ( + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" + "testing" +) + +func TestPoseidon2Hash(t *testing.T) { + test.SingleFunction(ecc.BLS12_377, func(api frontend.API) []frontend.Variable { + hsh, err := NewPoseidon2(api) + require.NoError(t, err) + hsh.Write(0, 1, 2, 3, 4) + api.AssertIsDifferent(hsh.Sum(), 0) // TODO add test vectors + return nil + })(t) +} diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 8f22d345f..51ba0553a 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -289,5 +289,9 @@ func (h *Hash) Apply(api frontend.API, l, r frontend.Variable) frontend.Variable if h.params.t != 2 { panic("poseidon2: Apply can only be used when t=2") } - return h.Permutation(api, []frontend.Variable{l, r}) + vars := [2]frontend.Variable{l, r} + if err := h.Permutation(api, vars[:]); err != nil { + panic(err) // this would never happen + } + return vars[1] } diff --git a/test/quick.go b/test/quick.go new file mode 100644 index 000000000..988103f5c --- /dev/null +++ b/test/quick.go @@ -0,0 +1,51 @@ +package test + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/require" + "testing" +) + +var snarkFunctionStore = make(map[uint64]func(frontend.API) []frontend.Variable) // todo make thread safe +type snarkFunctionTestCircuit struct { + Outs []frontend.Variable + funcId uint64 // this workaround is necessary because deepEquals fails on objects with function fields +} + +func (c *snarkFunctionTestCircuit) Define(api frontend.API) error { + outs := snarkFunctionStore[c.funcId](api) + delete(snarkFunctionStore, c.funcId) + + // todo replace with SliceEquals + if len(outs) != len(c.Outs) { + return errors.New("SingleFunction: unexpected number of output") + } + for i := range outs { + api.AssertIsEqual(outs[i], c.Outs[i]) + } + return nil +} + +// SingleFunction returns a test function that can run a simple circuit consisting of function f, and match its output with outs +func SingleFunction(curve ecc.ID, f func(frontend.API) []frontend.Variable, outs ...frontend.Variable) func(*testing.T) { + + return func(t *testing.T) { + c := snarkFunctionTestCircuit{ + Outs: make([]frontend.Variable, len(outs)), + } + var b [8]byte + _, err := rand.Read(b[:]) + require.NoError(t, err) + c.funcId = binary.BigEndian.Uint64(b[:]) + snarkFunctionStore[c.funcId] = f + + a := snarkFunctionTestCircuit{ + Outs: outs, + } + require.NoError(t, IsSolved(&c, &a, curve.ScalarField())) + } +} From a8e353f2dc5f6ad1e9e0e1559bf75f2cd855196f Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Sun, 2 Feb 2025 20:17:09 -0600 Subject: [PATCH 03/14] docs: remove incomplete comment --- std/hash/hash.go | 1 - 1 file changed, 1 deletion(-) diff --git a/std/hash/hash.go b/std/hash/hash.go index 8b3c52d2b..ed14d8c6d 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -93,7 +93,6 @@ type CompressionFunction interface { Apply(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable // TODO @Tabaie @ThomasPiellard better name } -// Merkle-Damgard is a generic transformation that turns type merkleDamgardHasher struct { state frontend.Variable iv frontend.Variable From 1b57817bca1d94ea1e986ec76f526482bb6aaf27 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Sun, 2 Feb 2025 21:34:56 -0600 Subject: [PATCH 04/14] test: against gnark-crypto --- go.mod | 2 +- go.sum | 4 ++-- std/hash/poseidon2/poseidon2_test.go | 14 +++++++++++--- std/hash/poseidon2/posiedon2.go | 3 ++- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index dbeef7830..0175a3a03 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.27 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.15.0 + github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 5f24bbd3a..94084a48c 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.27 h1:j6hKUrGAy/H+gpNrpLU3I26n1yc+VMGmd6ID5+gAh github.com/consensys/bavard v0.1.27/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.15.0 h1:OXsWnhheHV59eXIzhL5OIexa/vqTK8wtRYQCtwfMDtY= -github.com/consensys/gnark-crypto v0.15.0/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= +github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 h1:PuRSTn2hpFm+mqysWl/hjTU2AvXYMNZT1nvxQT5j5PY= +github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 055650afe..001521854 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -2,6 +2,7 @@ package poseidon2 import ( "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" @@ -9,11 +10,18 @@ import ( ) func TestPoseidon2Hash(t *testing.T) { + // prepare expected output + h := poseidon2.NewPoseidon2() + for i := range 5 { + _, err := h.Write([]byte{byte(i)}) + require.NoError(t, err) + } + res := h.Sum(nil) + test.SingleFunction(ecc.BLS12_377, func(api frontend.API) []frontend.Variable { hsh, err := NewPoseidon2(api) require.NoError(t, err) hsh.Write(0, 1, 2, 3, 4) - api.AssertIsDifferent(hsh.Sum(), 0) // TODO add test vectors - return nil - })(t) + return []frontend.Variable{hsh.Sum()} + }, res)(t) } diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/posiedon2.go index 2712ea0fe..8e19ebc47 100644 --- a/std/hash/poseidon2/posiedon2.go +++ b/std/hash/poseidon2/posiedon2.go @@ -7,6 +7,7 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/hash" poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" + "strings" ) func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { @@ -15,7 +16,7 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { if !ok { return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String()) } - seed := fmt.Sprintf("Poseidon2 hash: curve=%s-rF=%d-rP=%d-t=2", curve.String(), params.rF, params.rP) + seed := fmt.Sprintf("Poseidon2 hash for %s with t=2, rF=%d, rP=%d, d=%d", strings.ToUpper(curve.String()), params.rF, params.rP, params.d) f := poseidon2.NewHash(2, params.d, params.rF, params.rP, seed, curve) return hash.NewMerkleDamgardHasher(api, &f, 0), nil } From 37a79fa44d7674432acfedf159c792a858ba2ca1 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Wed, 5 Feb 2025 10:11:27 -0600 Subject: [PATCH 05/14] chore: integrate gnark-crypto refactors --- go.mod | 2 +- go.sum | 4 +- std/hash/hash.go | 4 +- std/hash/poseidon2/poseidon2_test.go | 2 +- std/hash/poseidon2/posiedon2.go | 4 +- std/permutation/poseidon2/poseidon2.go | 111 ++++++++++---------- std/permutation/poseidon2/poseidon2_test.go | 34 +++--- 7 files changed, 81 insertions(+), 80 deletions(-) diff --git a/go.mod b/go.mod index 0175a3a03..b7b70697e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.27 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 + github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 94084a48c..200c8c71e 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.27 h1:j6hKUrGAy/H+gpNrpLU3I26n1yc+VMGmd6ID5+gAh github.com/consensys/bavard v0.1.27/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 h1:PuRSTn2hpFm+mqysWl/hjTU2AvXYMNZT1nvxQT5j5PY= -github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= +github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca h1:u6iXwMBfbXODF+hDSwKSTBg6yfD3+eMX6o3PILAK474= +github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/hash/hash.go b/std/hash/hash.go index ed14d8c6d..66452a797 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -90,7 +90,7 @@ type BinaryFixedLengthHasher interface { // CompressionFunction is a 2 to 1 function type CompressionFunction interface { - Apply(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable // TODO @Tabaie @ThomasPiellard better name + Compress(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable } type merkleDamgardHasher struct { @@ -117,7 +117,7 @@ func (h *merkleDamgardHasher) Reset() { func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { for _, d := range data { - h.state = h.f.Apply(h.api, h.state, d) + h.state = h.f.Compress(h.api, h.state, d) } } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 001521854..72293926f 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -11,7 +11,7 @@ import ( func TestPoseidon2Hash(t *testing.T) { // prepare expected output - h := poseidon2.NewPoseidon2() + h := poseidon2.NewMerkleDamgardHasher() for i := range 5 { _, err := h.Write([]byte{byte(i)}) require.NoError(t, err) diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/posiedon2.go index 8e19ebc47..cad8cea68 100644 --- a/std/hash/poseidon2/posiedon2.go +++ b/std/hash/poseidon2/posiedon2.go @@ -7,7 +7,6 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/hash" poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" - "strings" ) func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { @@ -16,8 +15,7 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { if !ok { return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String()) } - seed := fmt.Sprintf("Poseidon2 hash for %s with t=2, rF=%d, rP=%d, d=%d", strings.ToUpper(curve.String()), params.rF, params.rP, params.d) - f := poseidon2.NewHash(2, params.d, params.rF, params.rP, seed, curve) + f := poseidon2.NewHash(2, params.d, params.rF, params.rP, curve) return hash.NewMerkleDamgardHasher(api, &f, 0), nil } diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 51ba0553a..ddfdbaa5d 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -2,6 +2,7 @@ package poseidon import ( "errors" + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -19,7 +20,7 @@ var ( ErrInvalidSizebuffer = errors.New("the size of the input should match the size of the hash buffer") ) -type Hash struct { +type Permutation struct { params parameters } @@ -45,77 +46,79 @@ type parameters struct { roundKeys [][]big.Int } -func NewHash(t, d, rf, rp int, seed string, curve ecc.ID) Hash { +func NewHash(t, d, rf, rp int, curve ecc.ID) Permutation { params := parameters{t: t, d: d, rF: rf, rP: rp} if curve == ecc.BN254 { - rc := poseidonbn254.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbn254.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS12_381 { - rc := poseidonbls12381.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls12381.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS12_377 { - rc := poseidonbls12377.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls12377.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BW6_761 { - rc := poseidonbw6761.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbw6761.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BW6_633 { - rc := poseidonbw6633.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbw6633.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS24_315 { - rc := poseidonbls24315.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls24315.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS24_317 { - rc := poseidonbls24317.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls24317.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } + } else { + panic(fmt.Errorf("curve %s not supported", curve.String())) } - return Hash{params: params} + return Permutation{params: params} } // sBox applies the sBox on buffer[index] -func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) { +func (h *Permutation) sBox(api frontend.API, index int, input []frontend.Variable) { tmp := input[index] if h.params.d == 3 { input[index] = api.Mul(input[index], input[index]) @@ -149,7 +152,7 @@ func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) { // (1 1 4 6) // on chunks of 4 elements on each part of the buffer // see https://eprint.iacr.org/2023/323.pdf appendix B for the addition chain -func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) { +func (h *Permutation) matMulM4InPlace(api frontend.API, s []frontend.Variable) { c := len(s) / 4 for i := 0; i < c; i++ { t0 := api.Add(s[4*i], s[4*i+1]) // s0+s1 @@ -176,7 +179,7 @@ func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) { // // when t=0[4], the buffer is multiplied by circ(2M4,M4,..,M4) // see https://eprint.iacr.org/2023/323.pdf -func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable) { +func (h *Permutation) matMulExternalInPlace(api frontend.API, input []frontend.Variable) { if h.params.t == 2 { tmp := api.Add(input[0], input[1]) @@ -213,7 +216,7 @@ func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable // when t=2,3 the matrix are respectibely [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, -func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable) { +func (h *Permutation) matMulInternalInPlace(api frontend.API, input []frontend.Variable) { if h.params.t == 2 { sum := api.Add(input[0], input[1]) input[0] = api.Add(input[0], sum) @@ -241,13 +244,13 @@ func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable } // addRoundKeyInPlace adds the round-th key to the buffer -func (h *Hash) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) { +func (h *Permutation) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) { for i := 0; i < len(h.params.roundKeys[round]); i++ { input[i] = api.Add(input[i], h.params.roundKeys[round][i]) } } -func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error { +func (h *Permutation) Permutation(api frontend.API, input []frontend.Variable) error { if len(input) != h.params.t { return ErrInvalidSizebuffer } @@ -283,11 +286,11 @@ func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error { return nil } -// Apply aliases Permutation in the t=2 case +// Compress aliases Permutation in the t=2 case // implements hash.CompressionFunction -func (h *Hash) Apply(api frontend.API, l, r frontend.Variable) frontend.Variable { +func (h *Permutation) Compress(api frontend.API, l, r frontend.Variable) frontend.Variable { if h.params.t != 2 { - panic("poseidon2: Apply can only be used when t=2") + panic("poseidon2: Compress can only be used when t=2") } vars := [2]frontend.Variable{l, r} if err := h.Permutation(api, vars[:]); err != nil { diff --git a/std/permutation/poseidon2/poseidon2_test.go b/std/permutation/poseidon2/poseidon2_test.go index f14c07813..66ed81488 100644 --- a/std/permutation/poseidon2/poseidon2_test.go +++ b/std/permutation/poseidon2/poseidon2_test.go @@ -44,7 +44,7 @@ type circuitParams struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.seed, c.params.id) + h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.id) h.Permutation(api, c.Input) for i := 0; i < len(c.Input); i++ { api.AssertIsEqual(c.Output[i], c.Input[i]) @@ -68,11 +68,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbn254.NewHash( + h := poseidonbn254.NewPermutation( params[ecc.BN254].t, params[ecc.BN254].rf, params[ecc.BN254].rp, - "seed") + ) var in, out [3]frbn254.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -101,11 +101,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls12377.NewHash( + h := poseidonbls12377.NewPermutation( params[ecc.BLS12_377].t, params[ecc.BLS12_377].rf, params[ecc.BLS12_377].rp, - "seed") + ) var in, out [3]frbls12377.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -134,11 +134,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls12381.NewHash( + h := poseidonbls12381.NewPermutation( params[ecc.BLS12_381].t, params[ecc.BLS12_381].rf, params[ecc.BLS12_381].rp, - "seed") + ) var in, out [3]frbls12381.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -167,11 +167,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6633.NewHash( + h := poseidonbw6633.NewPermutation( params[ecc.BW6_633].t, params[ecc.BW6_633].rf, params[ecc.BW6_633].rp, - "seed") + ) var in, out [3]frbw6633.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -200,11 +200,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6633.NewHash( + h := poseidonbw6633.NewPermutation( params[ecc.BW6_633].t, params[ecc.BW6_633].rf, params[ecc.BW6_633].rp, - "seed") + ) var in, out [3]frbw6633.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -233,11 +233,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6761.NewHash( + h := poseidonbw6761.NewPermutation( params[ecc.BW6_761].t, params[ecc.BW6_761].rf, params[ecc.BW6_761].rp, - "seed") + ) var in, out [3]frbw6761.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -266,11 +266,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls24315.NewHash( + h := poseidonbls24315.NewPermutation( params[ecc.BLS24_315].t, params[ecc.BLS24_315].rf, params[ecc.BLS24_315].rp, - "seed") + ) var in, out [3]frbls24315.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -299,11 +299,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls24317.NewHash( + h := poseidonbls24317.NewPermutation( params[ecc.BLS24_317].t, params[ecc.BLS24_317].rf, params[ecc.BLS24_317].rp, - "seed") + ) var in, out [3]frbls24317.Element for i := 0; i < 3; i++ { in[i].SetRandom() From 2473dea0423d0b178cb59ab418fc602e7140181e Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Wed, 5 Feb 2025 10:21:40 -0600 Subject: [PATCH 06/14] refactor: rename NewHash to NewPermutation --- std/hash/poseidon2/posiedon2.go | 2 +- std/permutation/poseidon2/poseidon2.go | 2 +- std/permutation/poseidon2/poseidon2_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/posiedon2.go index cad8cea68..3031955f1 100644 --- a/std/hash/poseidon2/posiedon2.go +++ b/std/hash/poseidon2/posiedon2.go @@ -15,7 +15,7 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { if !ok { return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String()) } - f := poseidon2.NewHash(2, params.d, params.rF, params.rP, curve) + f := poseidon2.NewPoseidon2(2, params.d, params.rF, params.rP, curve) return hash.NewMerkleDamgardHasher(api, &f, 0), nil } diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index ddfdbaa5d..d9341ac1c 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -46,7 +46,7 @@ type parameters struct { roundKeys [][]big.Int } -func NewHash(t, d, rf, rp int, curve ecc.ID) Permutation { +func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { params := parameters{t: t, d: d, rF: rf, rP: rp} if curve == ecc.BN254 { concreteParams := poseidonbn254.NewParameters(t, rf, rp) diff --git a/std/permutation/poseidon2/poseidon2_test.go b/std/permutation/poseidon2/poseidon2_test.go index 66ed81488..861e8b24f 100644 --- a/std/permutation/poseidon2/poseidon2_test.go +++ b/std/permutation/poseidon2/poseidon2_test.go @@ -44,7 +44,7 @@ type circuitParams struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.id) + h := NewPoseidon2(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.id) h.Permutation(api, c.Input) for i := 0; i < len(c.Input); i++ { api.AssertIsEqual(c.Output[i], c.Input[i]) From 254dd75edff3d7cbdd1f6a6d77227f295ce15ff1 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 12 Feb 2025 23:11:55 +0000 Subject: [PATCH 07/14] refactor: align implementation with gnark-crypto --- std/permutation/poseidon2/poseidon2.go | 299 +++++++++++--------- std/permutation/poseidon2/poseidon2_test.go | 32 ++- 2 files changed, 182 insertions(+), 149 deletions(-) diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index d9341ac1c..d19fac288 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -14,6 +14,7 @@ import ( poseidonbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/poseidon2" poseidonbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/poseidon2" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" ) var ( @@ -21,35 +22,51 @@ var ( ) type Permutation struct { + api frontend.API params parameters } // parameters describing the poseidon2 implementation type parameters struct { - // len(preimage)+len(digest)=len(preimage)+ceil(log(2*/r)) - t int + width int // sbox degree - d int + degreeSBox int // number of full rounds (even number) - rF int + nbFullRounds int // number of partial rounds - rP int - - // diagonal elements of the internal matrices, minus one - diagInternalMatrices []big.Int + nbPartialRounds int // round keys roundKeys [][]big.Int } -func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { - params := parameters{t: t, d: d, rF: rf, rP: rp} - if curve == ecc.BN254 { - concreteParams := poseidonbn254.NewParameters(t, rf, rp) +// NewPoseidon2 returns a new Poseidon2 hasher with default parameters as +// defined in the gnark-crypto library. +func NewPoseidon2(api frontend.API) (*Permutation, error) { + switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BLS12_377: + params := poseidonbls12377.NewDefaultParameters() + return NewPoseidon2FromParameters(api, 2, params.NbFullRounds, params.NbPartialRounds) + // TODO: we don't have default parameters for other curves yet. Update this when we do. + default: + return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) + } +} + +// NewPoseidon2FromParameters returns a new Poseidon2 hasher with the given parameters. +// The parameters are used to precompute the round keys. The round key computation +// is deterministic and depends on the curve ID. See the corresponding NewParameters +// function in the gnark-crypto library poseidon2 packages for more details. +func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartialRounds int) (*Permutation, error) { + params := parameters{width: width, nbFullRounds: nbFullRounds, nbPartialRounds: nbPartialRounds} + switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BN254: + params.degreeSBox = poseidonbn254.DegreeSBox() + concreteParams := poseidonbn254.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -57,8 +74,9 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS12_381 { - concreteParams := poseidonbls12381.NewParameters(t, rf, rp) + case ecc.BLS12_381: + params.degreeSBox = poseidonbls12381.DegreeSBox() + concreteParams := poseidonbls12381.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -66,8 +84,9 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS12_377 { - concreteParams := poseidonbls12377.NewParameters(t, rf, rp) + case ecc.BLS12_377: + params.degreeSBox = poseidonbls12377.DegreeSBox() + concreteParams := poseidonbls12377.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -75,8 +94,9 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BW6_761 { - concreteParams := poseidonbw6761.NewParameters(t, rf, rp) + case ecc.BW6_761: + params.degreeSBox = poseidonbw6761.DegreeSBox() + concreteParams := poseidonbw6761.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -84,8 +104,9 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BW6_633 { - concreteParams := poseidonbw6633.NewParameters(t, rf, rp) + case ecc.BW6_633: + params.degreeSBox = poseidonbw6633.DegreeSBox() + concreteParams := poseidonbw6633.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -93,8 +114,9 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS24_315 { - concreteParams := poseidonbls24315.NewParameters(t, rf, rp) + case ecc.BLS24_315: + params.degreeSBox = poseidonbls24315.DegreeSBox() + concreteParams := poseidonbls24315.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -102,8 +124,9 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS24_317 { - concreteParams := poseidonbls24317.NewParameters(t, rf, rp) + case ecc.BLS24_317: + params.degreeSBox = poseidonbls24317.DegreeSBox() + concreteParams := poseidonbls24317.NewParameters(width, nbFullRounds, nbPartialRounds) params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) for i := range params.roundKeys { params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) @@ -111,35 +134,35 @@ func NewPoseidon2(t, d, rf, rp int, curve ecc.ID) Permutation { concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else { - panic(fmt.Errorf("curve %s not supported", curve.String())) + default: + return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) } - return Permutation{params: params} + return &Permutation{api: api, params: params}, nil } // sBox applies the sBox on buffer[index] -func (h *Permutation) sBox(api frontend.API, index int, input []frontend.Variable) { +func (h *Permutation) sBox(index int, input []frontend.Variable) { tmp := input[index] - if h.params.d == 3 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(tmp, input[index]) - } else if h.params.d == 5 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - } else if h.params.d == 7 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - } else if h.params.d == 17 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - } else if h.params.d == -1 { - input[index] = api.Inverse(input[index]) + if h.params.degreeSBox == 3 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(tmp, input[index]) + } else if h.params.degreeSBox == 5 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + } else if h.params.degreeSBox == 7 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + } else if h.params.degreeSBox == 17 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + } else if h.params.degreeSBox == -1 { + input[index] = h.api.Inverse(input[index]) } } @@ -152,21 +175,21 @@ func (h *Permutation) sBox(api frontend.API, index int, input []frontend.Variabl // (1 1 4 6) // on chunks of 4 elements on each part of the buffer // see https://eprint.iacr.org/2023/323.pdf appendix B for the addition chain -func (h *Permutation) matMulM4InPlace(api frontend.API, s []frontend.Variable) { +func (h *Permutation) matMulM4InPlace(s []frontend.Variable) { c := len(s) / 4 for i := 0; i < c; i++ { - t0 := api.Add(s[4*i], s[4*i+1]) // s0+s1 - t1 := api.Add(s[4*i+2], s[4*i+3]) // s2+s3 - t2 := api.Mul(s[4*i+1], 2) - t2 = api.Add(t2, t1) // 2s1+t1 - t3 := api.Mul(s[4*i+3], 2) - t3 = api.Add(t3, t0) // 2s3+t0 - t4 := api.Mul(t1, 4) - t4 = api.Add(t4, t3) // 4t1+t3 - t5 := api.Mul(t0, 4) - t5 = api.Add(t5, t2) // 4t0+t2 - t6 := api.Add(t3, t5) // t3+t5 - t7 := api.Add(t2, t4) // t2+t4 + t0 := h.api.Add(s[4*i], s[4*i+1]) // s0+s1 + t1 := h.api.Add(s[4*i+2], s[4*i+3]) // s2+s3 + t2 := h.api.Mul(s[4*i+1], 2) + t2 = h.api.Add(t2, t1) // 2s1+t1 + t3 := h.api.Mul(s[4*i+3], 2) + t3 = h.api.Add(t3, t0) // 2s3+t0 + t4 := h.api.Mul(t1, 4) + t4 = h.api.Add(t4, t3) // 4t1+t3 + t5 := h.api.Mul(t0, 4) + t5 = h.api.Add(t5, t2) // 4t0+t2 + t6 := h.api.Add(t3, t5) // t3+t5 + t7 := h.api.Add(t2, t4) // t2+t4 s[4*i] = t6 s[4*i+1] = t5 s[4*i+2] = t7 @@ -179,121 +202,129 @@ func (h *Permutation) matMulM4InPlace(api frontend.API, s []frontend.Variable) { // // when t=0[4], the buffer is multiplied by circ(2M4,M4,..,M4) // see https://eprint.iacr.org/2023/323.pdf -func (h *Permutation) matMulExternalInPlace(api frontend.API, input []frontend.Variable) { +func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { - if h.params.t == 2 { - tmp := api.Add(input[0], input[1]) - input[0] = api.Add(tmp, input[0]) - input[1] = api.Add(tmp, input[1]) - } else if h.params.t == 3 { - var tmp frontend.Variable - tmp = api.Add(input[0], input[1]) - tmp = api.Add(tmp, input[2]) - input[0] = api.Add(input[0], tmp) - input[1] = api.Add(input[1], tmp) - input[2] = api.Add(input[2], tmp) - } else if h.params.t == 4 { - h.matMulM4InPlace(api, input) + if h.params.width == 2 { + tmp := h.api.Add(input[0], input[1]) + input[0] = h.api.Add(tmp, input[0]) + input[1] = h.api.Add(tmp, input[1]) + } else if h.params.width == 3 { + tmp := h.api.Add(input[0], input[1]) + tmp = h.api.Add(tmp, input[2]) + input[0] = h.api.Add(input[0], tmp) + input[1] = h.api.Add(input[1], tmp) + input[2] = h.api.Add(input[2], tmp) + } else if h.params.width == 4 { + h.matMulM4InPlace(input) } else { // at this stage t is supposed to be a multiple of 4 // the MDS matrix is circ(2M4,M4,..,M4) - h.matMulM4InPlace(api, input) + h.matMulM4InPlace(input) tmp := make([]frontend.Variable, 4) - for i := 0; i < h.params.t/4; i++ { - tmp[0] = api.Add(tmp[0], input[4*i]) - tmp[1] = api.Add(tmp[1], input[4*i+1]) - tmp[2] = api.Add(tmp[2], input[4*i+2]) - tmp[3] = api.Add(tmp[3], input[4*i+3]) + for i := 0; i < h.params.width/4; i++ { + tmp[0] = h.api.Add(tmp[0], input[4*i]) + tmp[1] = h.api.Add(tmp[1], input[4*i+1]) + tmp[2] = h.api.Add(tmp[2], input[4*i+2]) + tmp[3] = h.api.Add(tmp[3], input[4*i+3]) } - for i := 0; i < h.params.t/4; i++ { - input[4*i] = api.Add(input[4*i], tmp[0]) - input[4*i+1] = api.Add(input[4*i], tmp[1]) - input[4*i+2] = api.Add(input[4*i], tmp[2]) - input[4*i+3] = api.Add(input[4*i], tmp[3]) + for i := 0; i < h.params.width/4; i++ { + input[4*i] = h.api.Add(input[4*i], tmp[0]) + input[4*i+1] = h.api.Add(input[4*i], tmp[1]) + input[4*i+2] = h.api.Add(input[4*i], tmp[2]) + input[4*i+3] = h.api.Add(input[4*i], tmp[3]) } } } // when t=2,3 the matrix are respectibely [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, -func (h *Permutation) matMulInternalInPlace(api frontend.API, input []frontend.Variable) { - if h.params.t == 2 { - sum := api.Add(input[0], input[1]) - input[0] = api.Add(input[0], sum) - input[1] = api.Mul(2, input[1]) - input[1] = api.Add(input[1], sum) - } else if h.params.t == 3 { - var sum frontend.Variable - sum = api.Add(input[0], input[1]) - sum = api.Add(sum, input[2]) - input[0] = api.Add(input[0], sum) - input[1] = api.Add(input[1], sum) - input[2] = api.Mul(input[2], 2) - input[2] = api.Add(input[2], sum) +func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { + if h.params.width == 2 { + sum := h.api.Add(input[0], input[1]) + input[0] = h.api.Add(input[0], sum) + input[1] = h.api.Mul(2, input[1]) + input[1] = h.api.Add(input[1], sum) + } else if h.params.width == 3 { + sum := h.api.Add(input[0], input[1]) + sum = h.api.Add(sum, input[2]) + input[0] = h.api.Add(input[0], sum) + input[1] = h.api.Add(input[1], sum) + input[2] = h.api.Mul(input[2], 2) + input[2] = h.api.Add(input[2], sum) } else { - var sum frontend.Variable - sum = input[0] - for i := 1; i < h.params.t; i++ { - sum = api.Add(sum, input[i]) - } - for i := 0; i < h.params.t; i++ { - input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) - input[i] = api.Add(input[i], sum) - } + // TODO: we don't have general case implemented in gnark-crypto side. + // Currently we only have the hardcoded matrices for t=2,3. If we would + // use `h.params.diagInternalMatrices` we would need to set it, but + // currently they are nil. + + // var sum frontend.Variable + // sum = input[0] + // for i := 1; i < h.params.width; i++ { + // sum = api.Add(sum, input[i]) + // } + // for i := 0; i < h.params.width; i++ { + // input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) + // input[i] = api.Add(input[i], sum) + // } } } // addRoundKeyInPlace adds the round-th key to the buffer -func (h *Permutation) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) { +func (h *Permutation) addRoundKeyInPlace(round int, input []frontend.Variable) { for i := 0; i < len(h.params.roundKeys[round]); i++ { - input[i] = api.Add(input[i], h.params.roundKeys[round][i]) + input[i] = h.api.Add(input[i], h.params.roundKeys[round][i]) } } -func (h *Permutation) Permutation(api frontend.API, input []frontend.Variable) error { - if len(input) != h.params.t { +// Permutation applies the permutation on input, and stores the result in input. +func (h *Permutation) Permutation(input []frontend.Variable) error { + if len(input) != h.params.width { return ErrInvalidSizebuffer } // external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6) - h.matMulExternalInPlace(api, input) + h.matMulExternalInPlace(input) - rf := h.params.rF / 2 + rf := h.params.nbFullRounds / 2 for i := 0; i < rf; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) - h.addRoundKeyInPlace(api, i, input) - for j := 0; j < h.params.t; j++ { - h.sBox(api, j, input) + h.addRoundKeyInPlace(i, input) + for j := 0; j < h.params.width; j++ { + h.sBox(j, input) } - h.matMulExternalInPlace(api, input) + h.matMulExternalInPlace(input) } - for i := rf; i < rf+h.params.rP; i++ { + for i := rf; i < rf+h.params.nbPartialRounds; i++ { // one round = matMulInternal(sBox_sparse(addRoundKey)) - h.addRoundKeyInPlace(api, i, input) - h.sBox(api, 0, input) - h.matMulInternalInPlace(api, input) + h.addRoundKeyInPlace(i, input) + h.sBox(0, input) + h.matMulInternalInPlace(input) } - for i := rf + h.params.rP; i < h.params.rF+h.params.rP; i++ { + for i := rf + h.params.nbPartialRounds; i < h.params.nbFullRounds+h.params.nbPartialRounds; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) - h.addRoundKeyInPlace(api, i, input) - for j := 0; j < h.params.t; j++ { - h.sBox(api, j, input) + h.addRoundKeyInPlace(i, input) + for j := 0; j < h.params.width; j++ { + h.sBox(j, input) } - h.matMulExternalInPlace(api, input) + h.matMulExternalInPlace(input) } return nil } -// Compress aliases Permutation in the t=2 case -// implements hash.CompressionFunction -func (h *Permutation) Compress(api frontend.API, l, r frontend.Variable) frontend.Variable { - if h.params.t != 2 { +// Compress applies the permutation on left and right and returns the right lane +// of the result. Panics if the permutation instance is not initialized with a +// width of 2. +// +// Implements the [hash.Compressor] interface for building a Merkle-Damgard +// hash construction. +func (h *Permutation) Compress(left, right frontend.Variable) frontend.Variable { + if h.params.width != 2 { panic("poseidon2: Compress can only be used when t=2") } - vars := [2]frontend.Variable{l, r} - if err := h.Permutation(api, vars[:]); err != nil { + vars := [2]frontend.Variable{left, right} + if err := h.Permutation(vars[:]); err != nil { panic(err) // this would never happen } return vars[1] diff --git a/std/permutation/poseidon2/poseidon2_test.go b/std/permutation/poseidon2/poseidon2_test.go index 861e8b24f..ed0a3b807 100644 --- a/std/permutation/poseidon2/poseidon2_test.go +++ b/std/permutation/poseidon2/poseidon2_test.go @@ -1,6 +1,7 @@ package poseidon import ( + "fmt" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -35,17 +36,18 @@ type Poseidon2Circuit struct { } type circuitParams struct { - seed string - rf int - rp int - t int - d int - id ecc.ID + rf int + rp int + t int + id ecc.ID } func (c *Poseidon2Circuit) Define(api frontend.API) error { - h := NewPoseidon2(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.id) - h.Permutation(api, c.Input) + h, err := NewPoseidon2FromParameters(api, c.params.t, c.params.rf, c.params.rp) + if err != nil { + return fmt.Errorf("could not create poseidon2 hasher: %w", err) + } + h.Permutation(c.Input) for i := 0; i < len(c.Input); i++ { api.AssertIsEqual(c.Output[i], c.Input[i]) } @@ -57,13 +59,13 @@ func TestPoseidon2(t *testing.T) { assert := test.NewAssert(t) params := make(map[ecc.ID]circuitParams) - params[ecc.BN254] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BN254} - params[ecc.BLS12_381] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BLS12_381} - params[ecc.BLS12_377] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 17, id: ecc.BLS12_377} - params[ecc.BW6_761] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BW6_761} - params[ecc.BW6_633] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BW6_633} - params[ecc.BLS24_315] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BLS24_315} - params[ecc.BLS24_317] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 7, id: ecc.BLS24_317} + params[ecc.BN254] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BN254} + params[ecc.BLS12_381] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS12_381} + params[ecc.BLS12_377] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS12_377} + params[ecc.BW6_761] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BW6_761} + params[ecc.BW6_633] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BW6_633} + params[ecc.BLS24_315] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS24_315} + params[ecc.BLS24_317] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS24_317} { var circuit, validWitness Poseidon2Circuit From cc342a0399b71240099cc18c9a93213d97da2e60 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 12 Feb 2025 23:12:42 +0000 Subject: [PATCH 08/14] refactor: align naming with gnark-crypto --- std/hash/hash.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/std/hash/hash.go b/std/hash/hash.go index 66452a797..c1f6fe135 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -88,21 +88,29 @@ type BinaryFixedLengthHasher interface { FixedLengthSum(length frontend.Variable) []uints.U8 } -// CompressionFunction is a 2 to 1 function -type CompressionFunction interface { - Compress(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable +// Compressor is a 2-1 one-way function. It takes two inputs and compresses +// them into one output. +// +// NB! This is lossy compression, meaning that the output is not guaranteed to +// be unique for different inputs. The output is guaranteed to be the same for +// the same inputs. +// +// The Compressor is used in the Merkle-Damgard construction to build a hash +// function. +type Compressor interface { + Compress(frontend.Variable, frontend.Variable) frontend.Variable } type merkleDamgardHasher struct { state frontend.Variable iv frontend.Variable - f CompressionFunction + f Compressor api frontend.API } // NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash // initialState is a value whose preimage is not known -func NewMerkleDamgardHasher(api frontend.API, f CompressionFunction, initialState frontend.Variable) FieldHasher { +func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) FieldHasher { return &merkleDamgardHasher{ state: initialState, iv: initialState, @@ -117,7 +125,7 @@ func (h *merkleDamgardHasher) Reset() { func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { for _, d := range data { - h.state = h.f.Compress(h.api, h.state, d) + h.state = h.f.Compress(h.state, d) } } From 4b9aaf2fdec53a3d03c1e574ad092be65cd7a8fb Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 12 Feb 2025 23:16:11 +0000 Subject: [PATCH 09/14] feat: take default parameters from gnark-crypto --- std/hash/poseidon2/posiedon2.go | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/posiedon2.go index 3031955f1..6d5ce9735 100644 --- a/std/hash/poseidon2/posiedon2.go +++ b/std/hash/poseidon2/posiedon2.go @@ -2,31 +2,16 @@ package poseidon2 import ( "fmt" - "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/hash" poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" ) func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { - curve := utils.FieldToCurve(api.Compiler().Field()) - params, ok := parameters[curve] - if !ok { - return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String()) + f, err := poseidon2.NewPoseidon2(api) + if err != nil { + return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) } - f := poseidon2.NewPoseidon2(2, params.d, params.rF, params.rP, curve) - return hash.NewMerkleDamgardHasher(api, &f, 0), nil -} - -var parameters = map[ecc.ID]struct { - d int - rF int - rP int -}{ - ecc.BLS12_377: { - rF: 6, - rP: 26, - d: 17, - }, + return hash.NewMerkleDamgardHasher(api, f, 0), nil } From ffdba7f808616d7d9a593867023b1e326363d429 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 12 Feb 2025 23:34:12 +0000 Subject: [PATCH 10/14] chore: rename file --- std/hash/poseidon2/{posiedon2.go => poseidon2.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename std/hash/poseidon2/{posiedon2.go => poseidon2.go} (100%) diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/poseidon2.go similarity index 100% rename from std/hash/poseidon2/posiedon2.go rename to std/hash/poseidon2/poseidon2.go From d6fcd5d92e9239b265865cf477f43eebeb50e52e Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 12 Feb 2025 23:34:43 +0000 Subject: [PATCH 11/14] refactor: rename constructor to align with gnark-crypto --- std/hash/poseidon2/poseidon2.go | 4 +++- std/hash/poseidon2/poseidon2_test.go | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index 6d5ce9735..d0427d2ca 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -8,7 +8,9 @@ import ( poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" ) -func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { +// NewMerkleDamgardHasher returns a Poseidon2 hasher using the Merkle-Damgard +// construction with the default parameters. +func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 72293926f..307b473fb 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,12 +1,13 @@ package poseidon2 import ( + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" - "testing" ) func TestPoseidon2Hash(t *testing.T) { @@ -19,7 +20,7 @@ func TestPoseidon2Hash(t *testing.T) { res := h.Sum(nil) test.SingleFunction(ecc.BLS12_377, func(api frontend.API) []frontend.Variable { - hsh, err := NewPoseidon2(api) + hsh, err := NewMerkleDamgardHasher(api) require.NoError(t, err) hsh.Write(0, 1, 2, 3, 4) return []frontend.Variable{hsh.Sum()} From 5c46d09b1d62dbbd45a38bc1d0fe8afe82de3065 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Thu, 13 Feb 2025 00:32:59 +0000 Subject: [PATCH 12/14] fix: add danger when calling permute for t=2,3 for now --- std/permutation/poseidon2/poseidon2.go | 1 + 1 file changed, 1 insertion(+) diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index d19fac288..02a90aeb8 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -266,6 +266,7 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) // input[i] = api.Add(input[i], sum) // } + panic("only T=2,3 is supported") } } From d40d8d9afedf5ce05cdc2d33f19cb06ed2534113 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Fri, 14 Feb 2025 16:31:33 -0600 Subject: [PATCH 13/14] refactor: mature SingleFunctionTest --- std/hash/poseidon2/poseidon2_test.go | 9 +++--- test/quick.go | 47 ++++++++++++++-------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 307b473fb..70a13d820 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,9 +1,9 @@ package poseidon2 import ( + "github.com/consensys/gnark-crypto/ecc" "testing" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" @@ -19,10 +19,11 @@ func TestPoseidon2Hash(t *testing.T) { } res := h.Sum(nil) - test.SingleFunction(ecc.BLS12_377, func(api frontend.API) []frontend.Variable { + test.Function(func(api frontend.API) error { hsh, err := NewMerkleDamgardHasher(api) require.NoError(t, err) hsh.Write(0, 1, 2, 3, 4) - return []frontend.Variable{hsh.Sum()} - }, res)(t) + api.AssertIsEqual(hsh.Sum(), res) + return nil + }, test.WithCurves(ecc.BLS12_377))(t) } diff --git a/test/quick.go b/test/quick.go index 988103f5c..457245404 100644 --- a/test/quick.go +++ b/test/quick.go @@ -4,48 +4,47 @@ import ( "crypto/rand" "encoding/binary" "errors" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/stretchr/testify/require" + "sync" "testing" ) -var snarkFunctionStore = make(map[uint64]func(frontend.API) []frontend.Variable) // todo make thread safe +var snarkFunctionStore sync.Map + type snarkFunctionTestCircuit struct { - Outs []frontend.Variable - funcId uint64 // this workaround is necessary because deepEquals fails on objects with function fields + funcId uint64 // this workaround is necessary because deepEquals fails on objects with function fields + DummyInput frontend.Variable // to keep the Plonk backend from crashing } func (c *snarkFunctionTestCircuit) Define(api frontend.API) error { - outs := snarkFunctionStore[c.funcId](api) - delete(snarkFunctionStore, c.funcId) - // todo replace with SliceEquals - if len(outs) != len(c.Outs) { - return errors.New("SingleFunction: unexpected number of output") + f, ok := snarkFunctionStore.Load(c.funcId) + if !ok { + return errors.New("function not found") } - for i := range outs { - api.AssertIsEqual(outs[i], c.Outs[i]) + + F, ok := f.(func(frontend.API) error) + if !ok { + panic("unexpected entry type") } - return nil -} -// SingleFunction returns a test function that can run a simple circuit consisting of function f, and match its output with outs -func SingleFunction(curve ecc.ID, f func(frontend.API) []frontend.Variable, outs ...frontend.Variable) func(*testing.T) { + return F(api) +} +// Function returns a test function that can run a simple circuit consisting of function f +func Function(f func(frontend.API) error, opts ...TestingOption) func(*testing.T) { return func(t *testing.T) { - c := snarkFunctionTestCircuit{ - Outs: make([]frontend.Variable, len(outs)), - } - var b [8]byte + var ( + c snarkFunctionTestCircuit + b [8]byte + ) _, err := rand.Read(b[:]) require.NoError(t, err) c.funcId = binary.BigEndian.Uint64(b[:]) - snarkFunctionStore[c.funcId] = f + snarkFunctionStore.Store(c.funcId, f) - a := snarkFunctionTestCircuit{ - Outs: outs, - } - require.NoError(t, IsSolved(&c, &a, curve.ScalarField())) + NewAssert(t).SolvingSucceeded(&c, &snarkFunctionTestCircuit{DummyInput: 0}, opts...) + snarkFunctionStore.Delete(c.funcId) } } From 936b2656849fce9f9338d4affe33aa60da595d30 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 17 Feb 2025 00:32:49 +0000 Subject: [PATCH 14/14] test: use assert.CheckCircuit directly Until https://github.com/Consensys/gnark/pull/1422 is resolved --- std/hash/poseidon2/poseidon2_test.go | 36 +++++++++++++------- test/quick.go | 50 ---------------------------- 2 files changed, 24 insertions(+), 62 deletions(-) delete mode 100644 test/quick.go diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 70a13d820..1ce1d46fe 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,29 +1,41 @@ package poseidon2 import ( - "github.com/consensys/gnark-crypto/ecc" "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" - "github.com/stretchr/testify/require" ) +type Poseidon2Circuit struct { + Input []frontend.Variable + Expected frontend.Variable `gnark:",public"` +} + +func (c *Poseidon2Circuit) Define(api frontend.API) error { + hsh, err := NewMerkleDamgardHasher(api) + if err != nil { + return err + } + hsh.Write(c.Input...) + api.AssertIsEqual(hsh.Sum(), c.Expected) + return nil +} + func TestPoseidon2Hash(t *testing.T) { + assert := test.NewAssert(t) + + const nbInputs = 5 // prepare expected output h := poseidon2.NewMerkleDamgardHasher() - for i := range 5 { + circInput := make([]frontend.Variable, nbInputs) + for i := range nbInputs { _, err := h.Write([]byte{byte(i)}) - require.NoError(t, err) + assert.NoError(err) + circInput[i] = i } res := h.Sum(nil) - - test.Function(func(api frontend.API) error { - hsh, err := NewMerkleDamgardHasher(api) - require.NoError(t, err) - hsh.Write(0, 1, 2, 3, 4) - api.AssertIsEqual(hsh.Sum(), res) - return nil - }, test.WithCurves(ecc.BLS12_377))(t) + assert.CheckCircuit(&Poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&Poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 } diff --git a/test/quick.go b/test/quick.go deleted file mode 100644 index 457245404..000000000 --- a/test/quick.go +++ /dev/null @@ -1,50 +0,0 @@ -package test - -import ( - "crypto/rand" - "encoding/binary" - "errors" - "github.com/consensys/gnark/frontend" - "github.com/stretchr/testify/require" - "sync" - "testing" -) - -var snarkFunctionStore sync.Map - -type snarkFunctionTestCircuit struct { - funcId uint64 // this workaround is necessary because deepEquals fails on objects with function fields - DummyInput frontend.Variable // to keep the Plonk backend from crashing -} - -func (c *snarkFunctionTestCircuit) Define(api frontend.API) error { - - f, ok := snarkFunctionStore.Load(c.funcId) - if !ok { - return errors.New("function not found") - } - - F, ok := f.(func(frontend.API) error) - if !ok { - panic("unexpected entry type") - } - - return F(api) -} - -// Function returns a test function that can run a simple circuit consisting of function f -func Function(f func(frontend.API) error, opts ...TestingOption) func(*testing.T) { - return func(t *testing.T) { - var ( - c snarkFunctionTestCircuit - b [8]byte - ) - _, err := rand.Read(b[:]) - require.NoError(t, err) - c.funcId = binary.BigEndian.Uint64(b[:]) - snarkFunctionStore.Store(c.funcId, f) - - NewAssert(t).SolvingSucceeded(&c, &snarkFunctionTestCircuit{DummyInput: 0}, opts...) - snarkFunctionStore.Delete(c.funcId) - } -}