Skip to content

Commit b152af1

Browse files
authored
feat: allow taking benchmarking instance in NewAssert (#1607)
1 parent 3f994db commit b152af1

File tree

3 files changed

+43
-16
lines changed

3 files changed

+43
-16
lines changed

test/assert.go

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,53 @@ var (
3232
// Assert is a helper to test circuits
3333
type Assert struct {
3434
t *testing.T
35+
b *testing.B
3536
*require.Assertions
3637
}
3738

38-
// NewAssert returns an Assert helper embedding a testify/require object for convenience
39+
// NewAssert returns an Assert helper embedding a testify/require object for convenience.
40+
// It accepts either a *testing.T or *testing.B object.
3941
//
40-
// The Assert object caches the compiled circuit:
41-
//
42-
// the first call to assert.ProverSucceeded/Failed will compile the circuit for n curves, m backends
43-
// and subsequent calls will re-use the result of the compilation, if available.
44-
func NewAssert(t *testing.T) *Assert {
45-
return &Assert{t: t, Assertions: require.New(t)}
42+
// The Assert object caches the compiled circuit. This means that the first call
43+
// to [Assert.CheckCircuit] will compile the circuit for n curves, m
44+
// backends and subsequent calls will re-use the result of the compilation, if
45+
// available. Be careful when benchmarking!
46+
func NewAssert(tb testing.TB) *Assert {
47+
switch t := (tb).(type) {
48+
case *testing.T:
49+
return &Assert{t: t, Assertions: require.New(t)}
50+
case *testing.B:
51+
return &Assert{b: t, Assertions: require.New(t)}
52+
default:
53+
panic("unknown testing type")
54+
}
4655
}
4756

4857
// Run runs the test function fn as a subtest. The subtest is parametrized by
4958
// the description strings descs.
5059
func (assert *Assert) Run(fn func(assert *Assert), descs ...string) {
5160
desc := strings.Join(descs, "/")
52-
assert.t.Run(desc, func(t *testing.T) {
53-
assert := &Assert{t, require.New(t)}
54-
fn(assert)
55-
})
61+
if assert.b != nil {
62+
assert.b.Run(desc, func(b *testing.B) {
63+
assert := &Assert{b: b, Assertions: require.New(b)}
64+
fn(assert)
65+
})
66+
} else {
67+
assert.t.Run(desc, func(t *testing.T) {
68+
assert := &Assert{t: t, Assertions: require.New(t)}
69+
fn(assert)
70+
})
71+
}
5672
}
5773

5874
// Log logs using the test instance logger.
5975
func (assert *Assert) Log(v ...interface{}) {
60-
assert.t.Log(v...)
76+
if assert.b != nil {
77+
assert.b.Log(v...)
78+
return
79+
} else {
80+
assert.t.Log(v...)
81+
}
6182
}
6283

6384
// ProverSucceeded is deprecated: use [Assert.CheckCircuit] instead

test/assert_checkcircuit.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ func (assert *Assert) CheckCircuit(circuit frontend.Circuit, opts ...TestingOpti
101101
}
102102

103103
// we need to run the setup, prove and verify and check serialization
104-
assert.t.Parallel()
104+
if assert.t != nil {
105+
assert.t.Parallel()
106+
}
105107

106108
var concreteBackend tBackend
107109

test/assert_solidity.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ func (assert *Assert) solidityVerification(b backend.ID, vk solidity.VerifyingKe
2424
if !SolcCheck || len(validPublicWitness.Vector().(fr_bn254.Vector)) == 0 {
2525
return // nothing to check, will make solc fail.
2626
}
27-
assert.t.Helper()
27+
if assert.b != nil {
28+
assert.b.Helper()
29+
} else {
30+
assert.t.Helper()
31+
}
2832

2933
// make temp dir
3034
tmpDir, err := os.MkdirTemp("", "gnark-solidity-check*")
@@ -44,7 +48,7 @@ func (assert *Assert) solidityVerification(b backend.ID, vk solidity.VerifyingKe
4448
// generate assets
4549
// gnark-solidity-checker generate --dir tmpdir --solidity contract_g16.sol
4650
cmd := exec.Command("gnark-solidity-checker", "generate", "--dir", tmpDir, "--solidity", "gnark_verifier.sol")
47-
assert.t.Log("running ", cmd.String())
51+
assert.Log("running ", cmd.String())
4852
out, err := cmd.CombinedOutput()
4953
assert.NoError(err, string(out))
5054

@@ -90,7 +94,7 @@ func (assert *Assert) solidityVerification(b backend.ID, vk solidity.VerifyingKe
9094
// verify proof
9195
// gnark-solidity-checker verify --dir tmdir --groth16 --nb-public-inputs 1 --proof 1234 --public-inputs dead
9296
cmd = exec.Command("gnark-solidity-checker", checkerOpts...)
93-
assert.t.Log("running ", cmd.String())
97+
assert.Log("running ", cmd.String())
9498
out, err = cmd.CombinedOutput()
9599
assert.NoError(err, string(out))
96100
}

0 commit comments

Comments
 (0)