-
Notifications
You must be signed in to change notification settings - Fork 419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Perf: optimize selector.Mux with recursive BinaryMux for various sizes #1420
Conversation
pull in latest master codes from upstream
so that we are adding a new function MuxCapped instead of replacing Mux
Sorry for the back-and-forth. I just realized that |
when n == 1, should assert sel == 0 |
1 similar comment
when n == 1, should assert sel == 0 |
Suggested edit: diff --git a/std/selector/multiplexer.go b/std/selector/multiplexer.go
index c3f6c79ee..b364dd70b 100644
--- a/std/selector/multiplexer.go
+++ b/std/selector/multiplexer.go
@@ -56,7 +56,7 @@ func Map(api frontend.API, queryKey frontend.Variable,
// inputs, otherwise the proof will fail.
func Mux(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable) frontend.Variable {
n := uint(len(inputs))
- nbBits := binary.Len(n - 1)
+ nbBits := binary.Len(n - 1) // we use n-1 as sel is 0-indexed
selBits := bits.ToBinary(api, sel, bits.WithNbDigits(nbBits)) // binary decomposition ensures sel < 2^nbBits
// We use BinaryMux when len(inputs) is a power of 2.
|
Suggested edit: diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv
index 10e9a3edf..a56593813 100644
--- a/internal/stats/latest_stats.csv
+++ b/internal/stats/latest_stats.csv
@@ -251,3 +251,73 @@ scalar_mul_secp256k1,bls24_315,plonk,0,0
scalar_mul_secp256k1,bls24_317,plonk,0,0
scalar_mul_secp256k1,bw6_761,plonk,0,0
scalar_mul_secp256k1,bw6_633,plonk,0,0
+selector/binaryMux_4,bn254,groth16,5,3
+selector/binaryMux_4,bls12_377,groth16,5,3
+selector/binaryMux_4,bls12_381,groth16,5,3
+selector/binaryMux_4,bls24_315,groth16,5,3
+selector/binaryMux_4,bls24_317,groth16,5,3
+selector/binaryMux_4,bw6_761,groth16,5,3
+selector/binaryMux_4,bw6_633,groth16,5,3
+selector/binaryMux_4,bn254,plonk,11,9
+selector/binaryMux_4,bls12_377,plonk,11,9
+selector/binaryMux_4,bls12_381,plonk,11,9
+selector/binaryMux_4,bls24_315,plonk,11,9
+selector/binaryMux_4,bls24_317,plonk,11,9
+selector/binaryMux_4,bw6_761,plonk,11,9
+selector/binaryMux_4,bw6_633,plonk,11,9
+selector/binaryMux_8,bn254,groth16,10,7
+selector/binaryMux_8,bls12_377,groth16,10,7
+selector/binaryMux_8,bls12_381,groth16,10,7
+selector/binaryMux_8,bls24_315,groth16,10,7
+selector/binaryMux_8,bls24_317,groth16,10,7
+selector/binaryMux_8,bw6_761,groth16,10,7
+selector/binaryMux_8,bw6_633,groth16,10,7
+selector/binaryMux_8,bn254,plonk,24,21
+selector/binaryMux_8,bls12_377,plonk,24,21
+selector/binaryMux_8,bls12_381,plonk,24,21
+selector/binaryMux_8,bls24_315,plonk,24,21
+selector/binaryMux_8,bls24_317,plonk,24,21
+selector/binaryMux_8,bw6_761,plonk,24,21
+selector/binaryMux_8,bw6_633,plonk,24,21
+selector/mux_3,bn254,groth16,8,6
+selector/mux_3,bls12_377,groth16,8,6
+selector/mux_3,bls12_381,groth16,8,6
+selector/mux_3,bls24_315,groth16,8,6
+selector/mux_3,bls24_317,groth16,8,6
+selector/mux_3,bw6_761,groth16,8,6
+selector/mux_3,bw6_633,groth16,8,6
+selector/mux_3,bn254,plonk,15,13
+selector/mux_3,bls12_377,plonk,15,13
+selector/mux_3,bls12_381,plonk,15,13
+selector/mux_3,bls24_315,plonk,15,13
+selector/mux_3,bls24_317,plonk,15,13
+selector/mux_3,bw6_761,plonk,15,13
+selector/mux_3,bw6_633,plonk,15,13
+selector/mux_4,bn254,groth16,6,5
+selector/mux_4,bls12_377,groth16,6,5
+selector/mux_4,bls12_381,groth16,6,5
+selector/mux_4,bls24_315,groth16,6,5
+selector/mux_4,bls24_317,groth16,6,5
+selector/mux_4,bw6_761,groth16,6,5
+selector/mux_4,bw6_633,groth16,6,5
+selector/mux_4,bn254,plonk,13,12
+selector/mux_4,bls12_377,plonk,13,12
+selector/mux_4,bls12_381,plonk,13,12
+selector/mux_4,bls24_315,plonk,13,12
+selector/mux_4,bls24_317,plonk,13,12
+selector/mux_4,bw6_761,plonk,13,12
+selector/mux_4,bw6_633,plonk,13,12
+selector/mux_5,bn254,groth16,12,10
+selector/mux_5,bls12_377,groth16,12,10
+selector/mux_5,bls12_381,groth16,12,10
+selector/mux_5,bls24_315,groth16,12,10
+selector/mux_5,bls24_317,groth16,12,10
+selector/mux_5,bw6_761,groth16,12,10
+selector/mux_5,bw6_633,groth16,12,10
+selector/mux_5,bn254,plonk,25,23
+selector/mux_5,bls12_377,plonk,25,23
+selector/mux_5,bls12_381,plonk,25,23
+selector/mux_5,bls24_315,plonk,25,23
+selector/mux_5,bls24_317,plonk,25,23
+selector/mux_5,bw6_761,plonk,25,23
+selector/mux_5,bw6_633,plonk,25,23
diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go
index 00cd92a1d..ccab6700a 100644
--- a/internal/stats/snippet.go
+++ b/internal/stats/snippet.go
@@ -16,6 +16,7 @@ import (
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
+ "github.com/consensys/gnark/std/selector"
)
var (
@@ -311,6 +312,26 @@ func initSnippets() {
}, ecc.BN254)
+ registerSnippet("selector/mux_3", func(api frontend.API, newVariable func() frontend.Variable) {
+ selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable())
+ })
+
+ registerSnippet("selector/mux_4", func(api frontend.API, newVariable func() frontend.Variable) {
+ selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
+ })
+
+ registerSnippet("selector/mux_5", func(api frontend.API, newVariable func() frontend.Variable) {
+ selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
+ })
+
+ registerSnippet("selector/binaryMux_4", func(api frontend.API, newVariable func() frontend.Variable) {
+ selector.BinaryMux(api, []frontend.Variable{newVariable(), newVariable()}, []frontend.Variable{newVariable(), newVariable(), newVariable(), newVariable()})
+ })
+
+ registerSnippet("selector/binaryMux_8", func(api frontend.API, newVariable func() frontend.Variable) {
+ selector.BinaryMux(api, []frontend.Variable{newVariable(), newVariable(), newVariable()}, []frontend.Variable{newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable()})
+ })
+
}
type snippetCircuit struct {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! See the few recommended edits, one is just one inline comment which made it clearer to me and another one adds mux to the statistics test so that we know when we have regression (or optimization).
Currently not yet approving, I want to test out one thing.
Suggested edit: diff --git a/std/selector/multiplexer2_test.go b/std/selector/multiplexer2_test.go
index e9ebc4fcf..dbff5cc57 100644
--- a/std/selector/multiplexer2_test.go
+++ b/std/selector/multiplexer2_test.go
@@ -4,9 +4,7 @@ import (
binary "math/bits"
"testing"
- "github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
- "github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/test"
)
@@ -108,57 +106,3 @@ func (c *largeCircuit3) Define(api frontend.API) error {
api.AssertIsEqual(s, c.Expected)
return nil
}
-
-func TestBenchMux1(t *testing.T) {
- for i := 2; i < 900; i++ {
- a, b, c := testBenchMux(t, i)
- if a > b || a > c {
- t.Logf("warning: %v, %v, %v, %v\n", i, a, b, c)
- }
- t.Logf("%v, %v, %v, %v\n", i, a, b, c)
- }
-}
-
-func TestBenchMux2(t *testing.T) {
- for i := 1; i < 20; i++ {
- a, b, c := testBenchMux(t, 1<<i)
- if a >= c {
- t.Logf("warning %v, %v, %v, %v\n", 1<<i, a, b, c)
- }
- t.Logf("%v, %v, %v, %v\n", 1<<i, a, b, c)
- }
-}
-
-func TestBenchMux3(t *testing.T) {
- a, b, c := testBenchMux(t, 0b111111111111111111111)
- t.Logf("%v, %v, %v\n", a, b, c)
-}
-
-func testBenchMux(t *testing.T, len int) (int, int, int) {
- assert := test.NewAssert(t)
- circuit := &muxCircuit{
- Length: len,
- Input: make([]frontend.Variable, len),
- }
-
- cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit)
- assert.NoError(err)
-
- circuit2 := &largeCircuit2{
- Length: len,
- Input: make([]frontend.Variable, len),
- }
-
- cs2, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit2)
- assert.NoError(err)
-
- circuit3 := &largeCircuit3{
- Length: len,
- Input: make([]frontend.Variable, len),
- }
-
- cs3, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit3)
- assert.NoError(err)
-
- return cs.GetNbConstraints(), cs2.GetNbConstraints(), cs3.GetNbConstraints()
-}
|
Suggested edit: diff --git a/std/selector/multiplexer2_test.go b/std/selector/multiplexer2_test.go
index e9ebc4fcf..347f09c1c 100644
--- a/std/selector/multiplexer2_test.go
+++ b/std/selector/multiplexer2_test.go
@@ -25,56 +25,6 @@ func mux3(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable)
return dotProduct(api, inputs, Decoder(api, len(inputs), sel))
}
-type muxCircuit struct {
- Sel frontend.Variable
- Input []frontend.Variable
- Expected frontend.Variable
-
- Length int
-}
-
-func (c *muxCircuit) Define(api frontend.API) error {
- if len(c.Input) != c.Length {
- panic("invalid length")
- }
-
- s := Mux(api, c.Sel, c.Input...)
- api.AssertIsEqual(s, c.Expected)
- return nil
-}
-
-func TestMux100(t *testing.T) {
- for i := 0; i < 100; i++ {
- testMux(t, 100, i)
- }
-}
-
-func testMux(t *testing.T, len int, sel int) {
- assert := test.NewAssert(t)
- circuit := &muxCircuit{
- Length: len,
- Input: make([]frontend.Variable, len),
- }
-
- inputs := make([]frontend.Variable, len)
- for i := 0; i < len; i++ {
- inputs[i] = frontend.Variable(i)
- }
-
- assert.CheckCircuit(circuit,
- test.WithValidAssignment(&muxCircuit{
- Sel: sel,
- Input: inputs,
- Expected: sel,
- }),
- test.WithInvalidAssignment(&muxCircuit{
- Sel: 3000,
- Input: inputs,
- Expected: sel,
- }),
- )
-}
-
type largeCircuit2 struct {
Sel frontend.Variable
Input []frontend.Variable
diff --git a/std/selector/multiplexer_test.go b/std/selector/multiplexer_test.go
index c89f6aefa..4ae0c387b 100644
--- a/std/selector/multiplexer_test.go
+++ b/std/selector/multiplexer_test.go
@@ -1,28 +1,26 @@
-package selector_test
+package selector
import (
+ "fmt"
+ "math/rand/v2"
"testing"
-
- "github.com/consensys/gnark-crypto/ecc"
- "github.com/consensys/gnark/frontend/cs/r1cs"
+ "time"
"github.com/consensys/gnark/frontend"
- "github.com/consensys/gnark/std/selector"
"github.com/consensys/gnark/test"
)
type muxCircuit struct {
- SEL frontend.Variable
- I0, I1, I2, I3, I4 frontend.Variable
- OUT frontend.Variable
+ Sel frontend.Variable
+ Input []frontend.Variable
+ Expected frontend.Variable
+
+ Length int
}
func (c *muxCircuit) Define(api frontend.API) error {
-
- out := selector.Mux(api, c.SEL, c.I0, c.I1, c.I2, c.I3, c.I4)
-
- api.AssertIsEqual(out, c.OUT)
-
+ s := Mux(api, c.Sel, c.Input...)
+ api.AssertIsEqual(s, c.Expected)
return nil
}
@@ -34,46 +32,61 @@ type ignoredOutputMuxCircuit struct {
func (c *ignoredOutputMuxCircuit) Define(api frontend.API) error {
// We ignore the output
- _ = selector.Mux(api, c.SEL, c.I0, c.I1, c.I2)
-
- return nil
-}
-
-type mux2to1Circuit struct {
- SEL frontend.Variable
- I0, I1 frontend.Variable
- OUT frontend.Variable
-}
+ _ = Mux(api, c.SEL, c.I0, c.I1, c.I2)
-func (c *mux2to1Circuit) Define(api frontend.API) error {
- // We ignore the output
- out := selector.Mux(api, c.SEL, c.I0, c.I1)
- api.AssertIsEqual(out, c.OUT)
return nil
}
-type mux4to1Circuit struct {
- SEL frontend.Variable
- In [4]frontend.Variable
- OUT frontend.Variable
-}
-
-func (c *mux4to1Circuit) Define(api frontend.API) error {
- out := selector.Mux(api, c.SEL, c.In[:]...)
- api.AssertIsEqual(out, c.OUT)
- return nil
+func testMux(assert *test.Assert, len int, sel int) {
+ rng := rand.New(rand.NewPCG(uint64(time.Now().Unix()), 1)) // seed the random generator
+ circuit := &muxCircuit{
+ Input: make([]frontend.Variable, len),
+ }
+
+ inputs := make([]frontend.Variable, len)
+ for i := range len {
+ inputs[i] = frontend.Variable(rng.Uint64())
+ }
+ // out-range invalid selector
+ outRangeSel := uint64(len) + rng.Uint64N(100)
+ opts := []test.TestingOption{
+ test.WithValidAssignment(&muxCircuit{
+ Sel: sel,
+ Input: inputs,
+ Expected: inputs[sel],
+ }),
+ test.WithInvalidAssignment(&muxCircuit{
+ Sel: outRangeSel,
+ Input: inputs,
+ Expected: sel,
+ }),
+ }
+
+ // in-range invalid selector
+ if len > 1 {
+ invalidSel := rng.Uint64N(uint64(len))
+ for invalidSel == uint64(sel) {
+ invalidSel = rng.Uint64N(uint64(len))
+ }
+ opts = append(opts, test.WithInvalidAssignment(&muxCircuit{
+ Sel: invalidSel,
+ Input: inputs,
+ Expected: sel,
+ }))
+ }
+
+ assert.CheckCircuit(circuit, opts...)
}
func TestMux(t *testing.T) {
assert := test.NewAssert(t)
-
- assert.CheckCircuit(&muxCircuit{},
- test.WithValidAssignment(&muxCircuit{SEL: 2, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 12}),
- test.WithValidAssignment(&muxCircuit{SEL: 0, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 10}),
- test.WithValidAssignment(&muxCircuit{SEL: 4, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}),
- test.WithInvalidAssignment(&muxCircuit{SEL: 5, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}),
- test.WithInvalidAssignment(&muxCircuit{SEL: 0, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 21}),
- )
+ for len := range 9 {
+ for sel := range len + 1 {
+ assert.Run(func(assert *test.Assert) {
+ testMux(assert, len+1, sel)
+ }, fmt.Sprintf("len=%d/sel=%d", len+1, sel))
+ }
+ }
assert.CheckCircuit(&ignoredOutputMuxCircuit{},
test.WithValidAssignment(&ignoredOutputMuxCircuit{SEL: 0, I0: 0, I1: 1, I2: 2}),
@@ -81,39 +94,6 @@ func TestMux(t *testing.T) {
test.WithInvalidAssignment(&ignoredOutputMuxCircuit{SEL: 3, I0: 0, I1: 1, I2: 2}),
test.WithInvalidAssignment(&ignoredOutputMuxCircuit{SEL: -1, I0: 0, I1: 1, I2: 2}),
)
-
- assert.CheckCircuit(&mux2to1Circuit{},
- test.WithValidAssignment(&mux2to1Circuit{SEL: 1, I0: 10, I1: 20, OUT: 20}),
- test.WithValidAssignment(&mux2to1Circuit{SEL: 0, I0: 10, I1: 20, OUT: 10}),
- test.WithInvalidAssignment(&mux2to1Circuit{SEL: 2, I0: 10, I1: 20, OUT: 20}),
- )
-
- assert.CheckCircuit(&mux4to1Circuit{},
- test.WithValidAssignment(&mux4to1Circuit{
- SEL: 3,
- In: [4]frontend.Variable{11, 22, 33, 44},
- OUT: 44,
- }),
- test.WithValidAssignment(&mux4to1Circuit{
- SEL: 1,
- In: [4]frontend.Variable{11, 22, 33, 44},
- OUT: 22,
- }),
- test.WithValidAssignment(&mux4to1Circuit{
- SEL: 0,
- In: [4]frontend.Variable{11, 22, 33, 44},
- OUT: 11,
- }),
- test.WithInvalidAssignment(&mux4to1Circuit{
- SEL: 4,
- In: [4]frontend.Variable{11, 22, 33, 44},
- OUT: 44,
- }),
- )
-
- cs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &mux4to1Circuit{})
- // (4 - 1) + (2 + 1) + 1 == 7
- assert.Equal(7, cs.GetNbConstraints())
}
// Map tests:
@@ -126,7 +106,7 @@ type mapCircuit struct {
func (c *mapCircuit) Define(api frontend.API) error {
- out := selector.Map(api, c.SEL,
+ out := Map(api, c.SEL,
[]frontend.Variable{c.K0, c.K1, c.K2, c.K3},
[]frontend.Variable{c.V0, c.V1, c.V2, c.V3})
@@ -143,7 +123,7 @@ type ignoredOutputMapCircuit struct {
func (c *ignoredOutputMapCircuit) Define(api frontend.API) error {
- _ = selector.Map(api, c.SEL,
+ _ = Map(api, c.SEL,
[]frontend.Variable{c.K0, c.K1},
[]frontend.Variable{c.V0, c.V1})
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see the comments and suggested edits. I wasn't able to push directly to the PR.
It looks very good, but please incorporate the changes and I'll merge.
All incorporated except for the range clause. Also, after the suggested editing, there is no test remaining in the In case you see further changes necessary, I have invited you to the repos https://github.com/lightec-xyz/gnark and https://github.com/lightec-xyz/gnark-crypto Please feel free to make changes. @ivokub |
Thanks for the access. I dropped multiplexer2_test.go and Imo there is still room for optimization, I don't have good idea right now though how to do it nicely. There is nbBits := binary.Len(n - 1) // we use n-1 as sel is 0-indexed
selBits := bits.ToBinary(api, sel, bits.WithNbDigits(nbBits)) // binary decomposition ensures sel < 2^nbBits and
and the second call duplicates a bit the first one, but instead it computes binary decomposition of But in practice we could do the check directly on I'll wait for the CI and then merge. I created new issue for it #1434 |
Description
(edited)
This PR is an optimized complement to the
selector.Mux
function. Instead of dot product, we recursively doselector.BinaryMux
even when the length of inputs is not 2's power, each time with a different size depending on how the length is decomposed into binary.A
bound
parameter is required to constructBoundedComparator
. Users are responsible for ensuring|sel - len(inputs)| <= absDiffUpp
as required bycmp.NewBoundedComparator
.How has this been tested?
How has this been benchmarked?
Added new tests showing the saved constraints. We moved the replaced implementation to test scope for benchmarking purpose, and count the constraint counts from 2 to 899. The constraint counts have been improved when the
inputs
array is large enough (> 16
), but this is subject to the bound passed to the test case (TestBenchMux1
), and users should have the option to choose based on benchmarking data.Additionally we compare
BinaryMux
againstdotProduct
when n is 2's power (TestBenchMux2
).BinaryMux
is strictly better.Checklist:
golangci-lint
does not output errors locally