Skip to content

Commit 7232801

Browse files
committed
add Test_SHA3FixedLengthSum_WithMinLen_VS_Zero
1 parent cd0f4f7 commit 7232801

File tree

2 files changed

+48
-9
lines changed

2 files changed

+48
-9
lines changed

std/hash/sha2/sha2_test.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ type sha2FixedLengthCircuit struct {
5555
Expected [32]uints.U8
5656
}
5757

58+
const (
59+
minLen = 55
60+
maxLen = 144
61+
)
62+
5863
func (c *sha2FixedLengthCircuit) Define(api frontend.API) error {
5964
h, err := New(api)
6065
if err != nil {
@@ -65,7 +70,7 @@ func (c *sha2FixedLengthCircuit) Define(api frontend.API) error {
6570
return err
6671
}
6772
h.Write(c.In)
68-
res := h.FixedLengthSum(0, c.Length)
73+
res := h.FixedLengthSum(minLen, c.Length)
6974
if len(res) != 32 {
7075
return fmt.Errorf("not 32 bytes")
7176
}
@@ -76,15 +81,16 @@ func (c *sha2FixedLengthCircuit) Define(api frontend.API) error {
7681
}
7782

7883
func TestSHA2FixedLengthSum(t *testing.T) {
79-
bts := make([]byte, 144)
84+
circuit := &sha2FixedLengthCircuit{In: make([]uints.U8, maxLen)}
85+
bts := make([]byte, maxLen)
8086
length := 56
8187
dgst := sha256.Sum256(bts[:length])
8288
witness := sha2FixedLengthCircuit{
8389
In: uints.NewU8Array(bts),
8490
Length: length,
8591
}
8692
copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
87-
err := test.IsSolved(&sha2FixedLengthCircuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
93+
err := test.IsSolved(circuit, &witness, ecc.BN254.ScalarField())
8894
if err != nil {
8995
t.Fatal(err)
9096
}

std/hash/sha3/sha3_test.go

+39-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/consensys/gnark-crypto/ecc"
1212
"github.com/consensys/gnark/frontend"
13+
"github.com/consensys/gnark/frontend/cs/scs"
1314
zkhash "github.com/consensys/gnark/std/hash"
1415
"github.com/consensys/gnark/std/math/uints"
1516
"github.com/consensys/gnark/test"
@@ -158,12 +159,12 @@ func TestSHA3FixedLengthSum(t *testing.T) {
158159
}
159160

160161
const (
161-
minLen = 299
162-
maxLen = 310
162+
minLen = 1680
163+
maxLen = 1710
163164
)
164165

165166
type sha3FixedLengthSumWithMinLenCircuit struct {
166-
In [maxLen]uints.U8
167+
In []uints.U8
167168
Expected []uints.U8
168169
Length frontend.Variable
169170
hasher string
@@ -182,7 +183,7 @@ func (c *sha3FixedLengthSumWithMinLenCircuit) Define(api frontend.API) error {
182183
if err != nil {
183184
return err
184185
}
185-
h.Write(c.In[:])
186+
h.Write(c.In)
186187
res := h.FixedLengthSum(minLen, c.Length)
187188

188189
for i := range c.Expected {
@@ -207,13 +208,13 @@ func TestSHA3FixedLengthSumWithMinLen(t *testing.T) {
207208
h.Write(in[:length])
208209
expected := h.Sum(nil)
209210

210-
circuit := &sha3FixedLengthSumCircuit{
211+
circuit := &sha3FixedLengthSumWithMinLenCircuit{
211212
In: make([]uints.U8, maxLen),
212213
Expected: make([]uints.U8, len(expected)),
213214
hasher: name,
214215
}
215216

216-
witness := &sha3FixedLengthSumCircuit{
217+
witness := &sha3FixedLengthSumWithMinLenCircuit{
217218
In: uints.NewU8Array(in),
218219
Expected: uints.NewU8Array(expected),
219220
Length: length,
@@ -227,3 +228,35 @@ func TestSHA3FixedLengthSumWithMinLen(t *testing.T) {
227228
}, fmt.Sprintf("hash=%s", name))
228229
}
229230
}
231+
232+
func Test_SHA3FixedLengthSum_WithMinLen_VS_Zero(t *testing.T) {
233+
assert := test.NewAssert(t)
234+
235+
for name := range testCases {
236+
name := name
237+
strategy := testCases[name]
238+
h := strategy.native()
239+
sumLen := h.Size()
240+
241+
circuit1 := &sha3FixedLengthSumCircuit{
242+
In: make([]uints.U8, maxLen),
243+
Expected: make([]uints.U8, sumLen),
244+
hasher: name,
245+
}
246+
247+
cs1, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit1)
248+
assert.NoError(err)
249+
250+
circuit2 := &sha3FixedLengthSumWithMinLenCircuit{
251+
In: make([]uints.U8, maxLen),
252+
Expected: make([]uints.U8, sumLen),
253+
hasher: name,
254+
}
255+
256+
cs2, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit2)
257+
assert.NoError(err)
258+
259+
fmt.Printf("maxLen=%d, minLen=%d, hash=%s, nbConstraints: %d vs %d(withMinLen)\n",
260+
maxLen, minLen, name, cs1.GetNbConstraints(), cs2.GetNbConstraints())
261+
}
262+
}

0 commit comments

Comments
 (0)