diff --git a/go.mod b/go.mod index dbeef7830..b7b70697e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.27 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.15.0 + github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 5f24bbd3a..200c8c71e 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.27 h1:j6hKUrGAy/H+gpNrpLU3I26n1yc+VMGmd6ID5+gAh github.com/consensys/bavard v0.1.27/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.15.0 h1:OXsWnhheHV59eXIzhL5OIexa/vqTK8wtRYQCtwfMDtY= -github.com/consensys/gnark-crypto v0.15.0/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= +github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca h1:u6iXwMBfbXODF+hDSwKSTBg6yfD3+eMX6o3PILAK474= +github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/hash/hash.go b/std/hash/hash.go index 80a1e5456..c1f6fe135 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -5,7 +5,7 @@ package hash import ( - "errors" + "fmt" "sync" "github.com/consensys/gnark/frontend" @@ -58,7 +58,7 @@ func GetFieldHasher(name string, api frontend.API) (FieldHasher, error) { defer lock.RUnlock() builder, ok := builderRegistry[name] if !ok { - return nil, errors.New("hash function not found") + return nil, fmt.Errorf("hash function \"%s\" not registered", name) } return builder(api) } @@ -87,3 +87,48 @@ type BinaryFixedLengthHasher interface { // FixedLengthSum returns digest of the first length bytes. FixedLengthSum(length frontend.Variable) []uints.U8 } + +// Compressor is a 2-1 one-way function. It takes two inputs and compresses +// them into one output. +// +// NB! This is lossy compression, meaning that the output is not guaranteed to +// be unique for different inputs. The output is guaranteed to be the same for +// the same inputs. +// +// The Compressor is used in the Merkle-Damgard construction to build a hash +// function. +type Compressor interface { + Compress(frontend.Variable, frontend.Variable) frontend.Variable +} + +type merkleDamgardHasher struct { + state frontend.Variable + iv frontend.Variable + f Compressor + api frontend.API +} + +// NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash +// initialState is a value whose preimage is not known +func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) FieldHasher { + return &merkleDamgardHasher{ + state: initialState, + iv: initialState, + f: f, + api: api, + } +} + +func (h *merkleDamgardHasher) Reset() { + h.state = h.iv +} + +func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { + for _, d := range data { + h.state = h.f.Compress(h.state, d) + } +} + +func (h *merkleDamgardHasher) Sum() frontend.Variable { + return h.state +} diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go new file mode 100644 index 000000000..d0427d2ca --- /dev/null +++ b/std/hash/poseidon2/poseidon2.go @@ -0,0 +1,19 @@ +package poseidon2 + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" +) + +// NewMerkleDamgardHasher returns a Poseidon2 hasher using the Merkle-Damgard +// construction with the default parameters. +func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { + f, err := poseidon2.NewPoseidon2(api) + if err != nil { + return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go new file mode 100644 index 000000000..1ce1d46fe --- /dev/null +++ b/std/hash/poseidon2/poseidon2_test.go @@ -0,0 +1,41 @@ +package poseidon2 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type Poseidon2Circuit struct { + Input []frontend.Variable + Expected frontend.Variable `gnark:",public"` +} + +func (c *Poseidon2Circuit) Define(api frontend.API) error { + hsh, err := NewMerkleDamgardHasher(api) + if err != nil { + return err + } + hsh.Write(c.Input...) + api.AssertIsEqual(hsh.Sum(), c.Expected) + return nil +} + +func TestPoseidon2Hash(t *testing.T) { + assert := test.NewAssert(t) + + const nbInputs = 5 + // prepare expected output + h := poseidon2.NewMerkleDamgardHasher() + circInput := make([]frontend.Variable, nbInputs) + for i := range nbInputs { + _, err := h.Write([]byte{byte(i)}) + assert.NoError(err) + circInput[i] = i + } + res := h.Sum(nil) + assert.CheckCircuit(&Poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&Poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 +} diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 00fb9306f..02a90aeb8 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -2,6 +2,7 @@ package poseidon import ( "errors" + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -13,130 +14,155 @@ import ( poseidonbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/poseidon2" poseidonbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/poseidon2" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" ) var ( ErrInvalidSizebuffer = errors.New("the size of the input should match the size of the hash buffer") ) -type Hash struct { +type Permutation struct { + api frontend.API params parameters } // parameters describing the poseidon2 implementation type parameters struct { - // len(preimage)+len(digest)=len(preimage)+ceil(log(2*/r)) - t int + width int // sbox degree - d int + degreeSBox int // number of full rounds (even number) - rF int + nbFullRounds int // number of partial rounds - rP int - - // diagonal elements of the internal matrices, minus one - diagInternalMatrices []big.Int + nbPartialRounds int // round keys roundKeys [][]big.Int } -func NewHash(t, d, rf, rp int, seed string, curve ecc.ID) Hash { - params := parameters{t: t, d: d, rF: rf, rP: rp} - if curve == ecc.BN254 { - rc := poseidonbn254.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) +// NewPoseidon2 returns a new Poseidon2 hasher with default parameters as +// defined in the gnark-crypto library. +func NewPoseidon2(api frontend.API) (*Permutation, error) { + switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BLS12_377: + params := poseidonbls12377.NewDefaultParameters() + return NewPoseidon2FromParameters(api, 2, params.NbFullRounds, params.NbPartialRounds) + // TODO: we don't have default parameters for other curves yet. Update this when we do. + default: + return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) + } +} + +// NewPoseidon2FromParameters returns a new Poseidon2 hasher with the given parameters. +// The parameters are used to precompute the round keys. The round key computation +// is deterministic and depends on the curve ID. See the corresponding NewParameters +// function in the gnark-crypto library poseidon2 packages for more details. +func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartialRounds int) (*Permutation, error) { + params := parameters{width: width, nbFullRounds: nbFullRounds, nbPartialRounds: nbPartialRounds} + switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BN254: + params.degreeSBox = poseidonbn254.DegreeSBox() + concreteParams := poseidonbn254.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS12_381 { - rc := poseidonbls12381.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + case ecc.BLS12_381: + params.degreeSBox = poseidonbls12381.DegreeSBox() + concreteParams := poseidonbls12381.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS12_377 { - rc := poseidonbls12377.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + case ecc.BLS12_377: + params.degreeSBox = poseidonbls12377.DegreeSBox() + concreteParams := poseidonbls12377.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BW6_761 { - rc := poseidonbw6761.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + case ecc.BW6_761: + params.degreeSBox = poseidonbw6761.DegreeSBox() + concreteParams := poseidonbw6761.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BW6_633 { - rc := poseidonbw6633.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + case ecc.BW6_633: + params.degreeSBox = poseidonbw6633.DegreeSBox() + concreteParams := poseidonbw6633.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS24_315 { - rc := poseidonbls24315.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + case ecc.BLS24_315: + params.degreeSBox = poseidonbls24315.DegreeSBox() + concreteParams := poseidonbls24315.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } - } else if curve == ecc.BLS24_317 { - rc := poseidonbls24317.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + case ecc.BLS24_317: + params.degreeSBox = poseidonbls24317.DegreeSBox() + concreteParams := poseidonbls24317.NewParameters(width, nbFullRounds, nbPartialRounds) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } + default: + return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) } - return Hash{params: params} + return &Permutation{api: api, params: params}, nil } // sBox applies the sBox on buffer[index] -func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) { +func (h *Permutation) sBox(index int, input []frontend.Variable) { tmp := input[index] - if h.params.d == 3 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(tmp, input[index]) - } else if h.params.d == 5 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - } else if h.params.d == 7 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - } else if h.params.d == 17 { - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], input[index]) - input[index] = api.Mul(input[index], tmp) - } else if h.params.d == -1 { - input[index] = api.Inverse(input[index]) + if h.params.degreeSBox == 3 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(tmp, input[index]) + } else if h.params.degreeSBox == 5 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + } else if h.params.degreeSBox == 7 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + } else if h.params.degreeSBox == 17 { + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], input[index]) + input[index] = h.api.Mul(input[index], tmp) + } else if h.params.degreeSBox == -1 { + input[index] = h.api.Inverse(input[index]) } } @@ -149,21 +175,21 @@ func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) { // (1 1 4 6) // on chunks of 4 elements on each part of the buffer // see https://eprint.iacr.org/2023/323.pdf appendix B for the addition chain -func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) { +func (h *Permutation) matMulM4InPlace(s []frontend.Variable) { c := len(s) / 4 for i := 0; i < c; i++ { - t0 := api.Add(s[4*i], s[4*i+1]) // s0+s1 - t1 := api.Add(s[4*i+2], s[4*i+3]) // s2+s3 - t2 := api.Mul(s[4*i+1], 2) - t2 = api.Add(t2, t1) // 2s1+t1 - t3 := api.Mul(s[4*i+3], 2) - t3 = api.Add(t3, t0) // 2s3+t0 - t4 := api.Mul(t1, 4) - t4 = api.Add(t4, t3) // 4t1+t3 - t5 := api.Mul(t0, 4) - t5 = api.Add(t5, t2) // 4t0+t2 - t6 := api.Add(t3, t5) // t3+t5 - t7 := api.Add(t2, t4) // t2+t4 + t0 := h.api.Add(s[4*i], s[4*i+1]) // s0+s1 + t1 := h.api.Add(s[4*i+2], s[4*i+3]) // s2+s3 + t2 := h.api.Mul(s[4*i+1], 2) + t2 = h.api.Add(t2, t1) // 2s1+t1 + t3 := h.api.Mul(s[4*i+3], 2) + t3 = h.api.Add(t3, t0) // 2s3+t0 + t4 := h.api.Mul(t1, 4) + t4 = h.api.Add(t4, t3) // 4t1+t3 + t5 := h.api.Mul(t0, 4) + t5 = h.api.Add(t5, t2) // 4t0+t2 + t6 := h.api.Add(t3, t5) // t3+t5 + t7 := h.api.Add(t2, t4) // t2+t4 s[4*i] = t6 s[4*i+1] = t5 s[4*i+2] = t7 @@ -176,109 +202,131 @@ func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) { // // when t=0[4], the buffer is multiplied by circ(2M4,M4,..,M4) // see https://eprint.iacr.org/2023/323.pdf -func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable) { +func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { - if h.params.t == 2 { - tmp := api.Add(input[0], input[1]) - input[0] = api.Add(tmp, input[0]) - input[1] = api.Add(tmp, input[1]) - } else if h.params.t == 3 { - var tmp frontend.Variable - tmp = api.Add(input[0], input[1]) - tmp = api.Add(tmp, input[2]) - input[0] = api.Add(input[0], tmp) - input[1] = api.Add(input[1], tmp) - input[2] = api.Add(input[2], tmp) - } else if h.params.t == 4 { - h.matMulM4InPlace(api, input) + if h.params.width == 2 { + tmp := h.api.Add(input[0], input[1]) + input[0] = h.api.Add(tmp, input[0]) + input[1] = h.api.Add(tmp, input[1]) + } else if h.params.width == 3 { + tmp := h.api.Add(input[0], input[1]) + tmp = h.api.Add(tmp, input[2]) + input[0] = h.api.Add(input[0], tmp) + input[1] = h.api.Add(input[1], tmp) + input[2] = h.api.Add(input[2], tmp) + } else if h.params.width == 4 { + h.matMulM4InPlace(input) } else { // at this stage t is supposed to be a multiple of 4 // the MDS matrix is circ(2M4,M4,..,M4) - h.matMulM4InPlace(api, input) + h.matMulM4InPlace(input) tmp := make([]frontend.Variable, 4) - for i := 0; i < h.params.t/4; i++ { - tmp[0] = api.Add(tmp[0], input[4*i]) - tmp[1] = api.Add(tmp[1], input[4*i+1]) - tmp[2] = api.Add(tmp[2], input[4*i+2]) - tmp[3] = api.Add(tmp[3], input[4*i+3]) + for i := 0; i < h.params.width/4; i++ { + tmp[0] = h.api.Add(tmp[0], input[4*i]) + tmp[1] = h.api.Add(tmp[1], input[4*i+1]) + tmp[2] = h.api.Add(tmp[2], input[4*i+2]) + tmp[3] = h.api.Add(tmp[3], input[4*i+3]) } - for i := 0; i < h.params.t/4; i++ { - input[4*i] = api.Add(input[4*i], tmp[0]) - input[4*i+1] = api.Add(input[4*i], tmp[1]) - input[4*i+2] = api.Add(input[4*i], tmp[2]) - input[4*i+3] = api.Add(input[4*i], tmp[3]) + for i := 0; i < h.params.width/4; i++ { + input[4*i] = h.api.Add(input[4*i], tmp[0]) + input[4*i+1] = h.api.Add(input[4*i], tmp[1]) + input[4*i+2] = h.api.Add(input[4*i], tmp[2]) + input[4*i+3] = h.api.Add(input[4*i], tmp[3]) } } } // when t=2,3 the matrix are respectibely [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, -func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable) { - if h.params.t == 2 { - sum := api.Add(input[0], input[1]) - input[0] = api.Add(input[0], sum) - input[1] = api.Mul(2, input[1]) - input[1] = api.Add(input[1], sum) - } else if h.params.t == 3 { - var sum frontend.Variable - sum = api.Add(input[0], input[1]) - sum = api.Add(sum, input[2]) - input[0] = api.Add(input[0], sum) - input[1] = api.Add(input[1], sum) - input[2] = api.Mul(input[2], 2) - input[2] = api.Add(input[2], sum) +func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { + if h.params.width == 2 { + sum := h.api.Add(input[0], input[1]) + input[0] = h.api.Add(input[0], sum) + input[1] = h.api.Mul(2, input[1]) + input[1] = h.api.Add(input[1], sum) + } else if h.params.width == 3 { + sum := h.api.Add(input[0], input[1]) + sum = h.api.Add(sum, input[2]) + input[0] = h.api.Add(input[0], sum) + input[1] = h.api.Add(input[1], sum) + input[2] = h.api.Mul(input[2], 2) + input[2] = h.api.Add(input[2], sum) } else { - var sum frontend.Variable - sum = input[0] - for i := 1; i < h.params.t; i++ { - sum = api.Add(sum, input[i]) - } - for i := 0; i < h.params.t; i++ { - input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) - input[i] = api.Add(input[i], sum) - } + // TODO: we don't have general case implemented in gnark-crypto side. + // Currently we only have the hardcoded matrices for t=2,3. If we would + // use `h.params.diagInternalMatrices` we would need to set it, but + // currently they are nil. + + // var sum frontend.Variable + // sum = input[0] + // for i := 1; i < h.params.width; i++ { + // sum = api.Add(sum, input[i]) + // } + // for i := 0; i < h.params.width; i++ { + // input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) + // input[i] = api.Add(input[i], sum) + // } + panic("only T=2,3 is supported") } } // addRoundKeyInPlace adds the round-th key to the buffer -func (h *Hash) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) { +func (h *Permutation) addRoundKeyInPlace(round int, input []frontend.Variable) { for i := 0; i < len(h.params.roundKeys[round]); i++ { - input[i] = api.Add(input[i], h.params.roundKeys[round][i]) + input[i] = h.api.Add(input[i], h.params.roundKeys[round][i]) } } -func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error { - if len(input) != h.params.t { +// Permutation applies the permutation on input, and stores the result in input. +func (h *Permutation) Permutation(input []frontend.Variable) error { + if len(input) != h.params.width { return ErrInvalidSizebuffer } // external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6) - h.matMulExternalInPlace(api, input) + h.matMulExternalInPlace(input) - rf := h.params.rF / 2 + rf := h.params.nbFullRounds / 2 for i := 0; i < rf; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) - h.addRoundKeyInPlace(api, i, input) - for j := 0; j < h.params.t; j++ { - h.sBox(api, j, input) + h.addRoundKeyInPlace(i, input) + for j := 0; j < h.params.width; j++ { + h.sBox(j, input) } - h.matMulExternalInPlace(api, input) + h.matMulExternalInPlace(input) } - for i := rf; i < rf+h.params.rP; i++ { + for i := rf; i < rf+h.params.nbPartialRounds; i++ { // one round = matMulInternal(sBox_sparse(addRoundKey)) - h.addRoundKeyInPlace(api, i, input) - h.sBox(api, 0, input) - h.matMulInternalInPlace(api, input) + h.addRoundKeyInPlace(i, input) + h.sBox(0, input) + h.matMulInternalInPlace(input) } - for i := rf + h.params.rP; i < h.params.rF+h.params.rP; i++ { + for i := rf + h.params.nbPartialRounds; i < h.params.nbFullRounds+h.params.nbPartialRounds; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) - h.addRoundKeyInPlace(api, i, input) - for j := 0; j < h.params.t; j++ { - h.sBox(api, j, input) + h.addRoundKeyInPlace(i, input) + for j := 0; j < h.params.width; j++ { + h.sBox(j, input) } - h.matMulExternalInPlace(api, input) + h.matMulExternalInPlace(input) } return nil } + +// Compress applies the permutation on left and right and returns the right lane +// of the result. Panics if the permutation instance is not initialized with a +// width of 2. +// +// Implements the [hash.Compressor] interface for building a Merkle-Damgard +// hash construction. +func (h *Permutation) Compress(left, right frontend.Variable) frontend.Variable { + if h.params.width != 2 { + panic("poseidon2: Compress can only be used when t=2") + } + vars := [2]frontend.Variable{left, right} + if err := h.Permutation(vars[:]); err != nil { + panic(err) // this would never happen + } + return vars[1] +} diff --git a/std/permutation/poseidon2/poseidon2_test.go b/std/permutation/poseidon2/poseidon2_test.go index f14c07813..ed0a3b807 100644 --- a/std/permutation/poseidon2/poseidon2_test.go +++ b/std/permutation/poseidon2/poseidon2_test.go @@ -1,6 +1,7 @@ package poseidon import ( + "fmt" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -35,17 +36,18 @@ type Poseidon2Circuit struct { } type circuitParams struct { - seed string - rf int - rp int - t int - d int - id ecc.ID + rf int + rp int + t int + id ecc.ID } func (c *Poseidon2Circuit) Define(api frontend.API) error { - h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.seed, c.params.id) - h.Permutation(api, c.Input) + h, err := NewPoseidon2FromParameters(api, c.params.t, c.params.rf, c.params.rp) + if err != nil { + return fmt.Errorf("could not create poseidon2 hasher: %w", err) + } + h.Permutation(c.Input) for i := 0; i < len(c.Input); i++ { api.AssertIsEqual(c.Output[i], c.Input[i]) } @@ -57,22 +59,22 @@ func TestPoseidon2(t *testing.T) { assert := test.NewAssert(t) params := make(map[ecc.ID]circuitParams) - params[ecc.BN254] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BN254} - params[ecc.BLS12_381] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BLS12_381} - params[ecc.BLS12_377] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 17, id: ecc.BLS12_377} - params[ecc.BW6_761] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BW6_761} - params[ecc.BW6_633] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BW6_633} - params[ecc.BLS24_315] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 5, id: ecc.BLS24_315} - params[ecc.BLS24_317] = circuitParams{seed: "seed", rf: 8, rp: 56, t: 3, d: 7, id: ecc.BLS24_317} + params[ecc.BN254] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BN254} + params[ecc.BLS12_381] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS12_381} + params[ecc.BLS12_377] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS12_377} + params[ecc.BW6_761] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BW6_761} + params[ecc.BW6_633] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BW6_633} + params[ecc.BLS24_315] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS24_315} + params[ecc.BLS24_317] = circuitParams{rf: 8, rp: 56, t: 3, id: ecc.BLS24_317} { var circuit, validWitness Poseidon2Circuit - h := poseidonbn254.NewHash( + h := poseidonbn254.NewPermutation( params[ecc.BN254].t, params[ecc.BN254].rf, params[ecc.BN254].rp, - "seed") + ) var in, out [3]frbn254.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -101,11 +103,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls12377.NewHash( + h := poseidonbls12377.NewPermutation( params[ecc.BLS12_377].t, params[ecc.BLS12_377].rf, params[ecc.BLS12_377].rp, - "seed") + ) var in, out [3]frbls12377.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -134,11 +136,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls12381.NewHash( + h := poseidonbls12381.NewPermutation( params[ecc.BLS12_381].t, params[ecc.BLS12_381].rf, params[ecc.BLS12_381].rp, - "seed") + ) var in, out [3]frbls12381.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -167,11 +169,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6633.NewHash( + h := poseidonbw6633.NewPermutation( params[ecc.BW6_633].t, params[ecc.BW6_633].rf, params[ecc.BW6_633].rp, - "seed") + ) var in, out [3]frbw6633.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -200,11 +202,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6633.NewHash( + h := poseidonbw6633.NewPermutation( params[ecc.BW6_633].t, params[ecc.BW6_633].rf, params[ecc.BW6_633].rp, - "seed") + ) var in, out [3]frbw6633.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -233,11 +235,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6761.NewHash( + h := poseidonbw6761.NewPermutation( params[ecc.BW6_761].t, params[ecc.BW6_761].rf, params[ecc.BW6_761].rp, - "seed") + ) var in, out [3]frbw6761.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -266,11 +268,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls24315.NewHash( + h := poseidonbls24315.NewPermutation( params[ecc.BLS24_315].t, params[ecc.BLS24_315].rf, params[ecc.BLS24_315].rp, - "seed") + ) var in, out [3]frbls24315.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -299,11 +301,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls24317.NewHash( + h := poseidonbls24317.NewPermutation( params[ecc.BLS24_317].t, params[ecc.BLS24_317].rf, params[ecc.BLS24_317].rp, - "seed") + ) var in, out [3]frbls24317.Element for i := 0; i < 3; i++ { in[i].SetRandom()