Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf: optimize selector.Mux with recursive BinaryMux for various sizes #1420

Merged
merged 19 commits into from
Feb 28, 2025

Conversation

weijiguo
Copy link
Contributor

@weijiguo weijiguo commented Feb 13, 2025

Description

(edited)

This PR is an optimized complement to the selector.Mux function. Instead of dot product, we recursively do selector.BinaryMux even when the length of inputs is not 2's power, each time with a different size depending on how the length is decomposed into binary.

A bound parameter is required to construct BoundedComparator. Users are responsible for ensuring |sel - len(inputs)| <= absDiffUpp as required by cmp.NewBoundedComparator.

How has this been tested?

  • All existing tests

How has this been benchmarked?

  • Added new tests showing the saved constraints. We moved the replaced implementation to test scope for benchmarking purpose, and count the constraint counts from 2 to 899. The constraint counts have been improved when the inputs array is large enough (> 16), but this is subject to the bound passed to the test case (TestBenchMux1), and users should have the option to choose based on benchmarking data.

  • Additionally we compare BinaryMux against dotProduct when n is 2's power (TestBenchMux2). BinaryMux is strictly better.

Checklist:

  • I have performed a self-review of my code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • I did not modify files generated from templates
  • golangci-lint does not output errors locally
  • New and existing unit tests pass locally with my changes
  • Any dependent changes have been merged and published in downstream modules

@weijiguo weijiguo marked this pull request as draft February 13, 2025 06:21
so that we are adding a new function MuxCapped instead of replacing Mux
@weijiguo weijiguo changed the title Perf: optimize selector.Mux with recursive BinaryMux for various sizes Perf: added selector.MuxCapped with recursive BinaryMux for various sizes Feb 13, 2025
@weijiguo weijiguo marked this pull request as ready for review February 13, 2025 08:30
@weijiguo weijiguo marked this pull request as draft February 13, 2025 12:09
@weijiguo weijiguo changed the title Perf: added selector.MuxCapped with recursive BinaryMux for various sizes Perf: added selector.MuxBounded with recursive BinaryMux for various sizes Feb 14, 2025
@weijiguo weijiguo marked this pull request as ready for review February 14, 2025 14:32
@weijiguo weijiguo changed the title Perf: added selector.MuxBounded with recursive BinaryMux for various sizes Perf: optimize selector.Mux with recursive BinaryMux for various sizes Feb 14, 2025
@weijiguo
Copy link
Contributor Author

Sorry for the back-and-forth. I just realized that bits.ToBinary ensures some bound, which could be used by a BoundedComparator. I have changed the PR to update selector.Mux instead of adding a new function. The unit test TestBenchMux1 shows that this PR provides a strict improvement for all the cases with inputs length from 2 to 899.

@ggq89
Copy link
Contributor

ggq89 commented Feb 20, 2025

when n == 1, should assert sel == 0

1 similar comment
@ggq89
Copy link
Contributor

ggq89 commented Feb 20, 2025

when n == 1, should assert sel == 0

@ivokub
Copy link
Collaborator

ivokub commented Feb 27, 2025

Suggested edit:

diff --git a/std/selector/multiplexer.go b/std/selector/multiplexer.go
index c3f6c79ee..b364dd70b 100644
--- a/std/selector/multiplexer.go
+++ b/std/selector/multiplexer.go
@@ -56,7 +56,7 @@ func Map(api frontend.API, queryKey frontend.Variable,
 // inputs, otherwise the proof will fail.
 func Mux(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable) frontend.Variable {
 	n := uint(len(inputs))
-	nbBits := binary.Len(n - 1)
+	nbBits := binary.Len(n - 1)                                   // we use n-1 as sel is 0-indexed
 	selBits := bits.ToBinary(api, sel, bits.WithNbDigits(nbBits)) // binary decomposition ensures sel < 2^nbBits
 
 	// We use BinaryMux when len(inputs) is a power of 2.

@ivokub
Copy link
Collaborator

ivokub commented Feb 27, 2025

Suggested edit:

diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv
index 10e9a3edf..a56593813 100644
--- a/internal/stats/latest_stats.csv
+++ b/internal/stats/latest_stats.csv
@@ -251,3 +251,73 @@ scalar_mul_secp256k1,bls24_315,plonk,0,0
 scalar_mul_secp256k1,bls24_317,plonk,0,0
 scalar_mul_secp256k1,bw6_761,plonk,0,0
 scalar_mul_secp256k1,bw6_633,plonk,0,0
+selector/binaryMux_4,bn254,groth16,5,3
+selector/binaryMux_4,bls12_377,groth16,5,3
+selector/binaryMux_4,bls12_381,groth16,5,3
+selector/binaryMux_4,bls24_315,groth16,5,3
+selector/binaryMux_4,bls24_317,groth16,5,3
+selector/binaryMux_4,bw6_761,groth16,5,3
+selector/binaryMux_4,bw6_633,groth16,5,3
+selector/binaryMux_4,bn254,plonk,11,9
+selector/binaryMux_4,bls12_377,plonk,11,9
+selector/binaryMux_4,bls12_381,plonk,11,9
+selector/binaryMux_4,bls24_315,plonk,11,9
+selector/binaryMux_4,bls24_317,plonk,11,9
+selector/binaryMux_4,bw6_761,plonk,11,9
+selector/binaryMux_4,bw6_633,plonk,11,9
+selector/binaryMux_8,bn254,groth16,10,7
+selector/binaryMux_8,bls12_377,groth16,10,7
+selector/binaryMux_8,bls12_381,groth16,10,7
+selector/binaryMux_8,bls24_315,groth16,10,7
+selector/binaryMux_8,bls24_317,groth16,10,7
+selector/binaryMux_8,bw6_761,groth16,10,7
+selector/binaryMux_8,bw6_633,groth16,10,7
+selector/binaryMux_8,bn254,plonk,24,21
+selector/binaryMux_8,bls12_377,plonk,24,21
+selector/binaryMux_8,bls12_381,plonk,24,21
+selector/binaryMux_8,bls24_315,plonk,24,21
+selector/binaryMux_8,bls24_317,plonk,24,21
+selector/binaryMux_8,bw6_761,plonk,24,21
+selector/binaryMux_8,bw6_633,plonk,24,21
+selector/mux_3,bn254,groth16,8,6
+selector/mux_3,bls12_377,groth16,8,6
+selector/mux_3,bls12_381,groth16,8,6
+selector/mux_3,bls24_315,groth16,8,6
+selector/mux_3,bls24_317,groth16,8,6
+selector/mux_3,bw6_761,groth16,8,6
+selector/mux_3,bw6_633,groth16,8,6
+selector/mux_3,bn254,plonk,15,13
+selector/mux_3,bls12_377,plonk,15,13
+selector/mux_3,bls12_381,plonk,15,13
+selector/mux_3,bls24_315,plonk,15,13
+selector/mux_3,bls24_317,plonk,15,13
+selector/mux_3,bw6_761,plonk,15,13
+selector/mux_3,bw6_633,plonk,15,13
+selector/mux_4,bn254,groth16,6,5
+selector/mux_4,bls12_377,groth16,6,5
+selector/mux_4,bls12_381,groth16,6,5
+selector/mux_4,bls24_315,groth16,6,5
+selector/mux_4,bls24_317,groth16,6,5
+selector/mux_4,bw6_761,groth16,6,5
+selector/mux_4,bw6_633,groth16,6,5
+selector/mux_4,bn254,plonk,13,12
+selector/mux_4,bls12_377,plonk,13,12
+selector/mux_4,bls12_381,plonk,13,12
+selector/mux_4,bls24_315,plonk,13,12
+selector/mux_4,bls24_317,plonk,13,12
+selector/mux_4,bw6_761,plonk,13,12
+selector/mux_4,bw6_633,plonk,13,12
+selector/mux_5,bn254,groth16,12,10
+selector/mux_5,bls12_377,groth16,12,10
+selector/mux_5,bls12_381,groth16,12,10
+selector/mux_5,bls24_315,groth16,12,10
+selector/mux_5,bls24_317,groth16,12,10
+selector/mux_5,bw6_761,groth16,12,10
+selector/mux_5,bw6_633,groth16,12,10
+selector/mux_5,bn254,plonk,25,23
+selector/mux_5,bls12_377,plonk,25,23
+selector/mux_5,bls12_381,plonk,25,23
+selector/mux_5,bls24_315,plonk,25,23
+selector/mux_5,bls24_317,plonk,25,23
+selector/mux_5,bw6_761,plonk,25,23
+selector/mux_5,bw6_633,plonk,25,23
diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go
index 00cd92a1d..ccab6700a 100644
--- a/internal/stats/snippet.go
+++ b/internal/stats/snippet.go
@@ -16,6 +16,7 @@ import (
 	"github.com/consensys/gnark/std/hash/mimc"
 	"github.com/consensys/gnark/std/math/bits"
 	"github.com/consensys/gnark/std/math/emulated"
+	"github.com/consensys/gnark/std/selector"
 )
 
 var (
@@ -311,6 +312,26 @@ func initSnippets() {
 
 	}, ecc.BN254)
 
+	registerSnippet("selector/mux_3", func(api frontend.API, newVariable func() frontend.Variable) {
+		selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable())
+	})
+
+	registerSnippet("selector/mux_4", func(api frontend.API, newVariable func() frontend.Variable) {
+		selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
+	})
+
+	registerSnippet("selector/mux_5", func(api frontend.API, newVariable func() frontend.Variable) {
+		selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
+	})
+
+	registerSnippet("selector/binaryMux_4", func(api frontend.API, newVariable func() frontend.Variable) {
+		selector.BinaryMux(api, []frontend.Variable{newVariable(), newVariable()}, []frontend.Variable{newVariable(), newVariable(), newVariable(), newVariable()})
+	})
+
+	registerSnippet("selector/binaryMux_8", func(api frontend.API, newVariable func() frontend.Variable) {
+		selector.BinaryMux(api, []frontend.Variable{newVariable(), newVariable(), newVariable()}, []frontend.Variable{newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable()})
+	})
+
 }
 
 type snippetCircuit struct {

Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! See the few recommended edits, one is just one inline comment which made it clearer to me and another one adds mux to the statistics test so that we know when we have regression (or optimization).

Currently not yet approving, I want to test out one thing.

@ivokub ivokub self-requested a review February 27, 2025 13:42
@ivokub
Copy link
Collaborator

ivokub commented Feb 27, 2025

Suggested edit:

diff --git a/std/selector/multiplexer2_test.go b/std/selector/multiplexer2_test.go
index e9ebc4fcf..dbff5cc57 100644
--- a/std/selector/multiplexer2_test.go
+++ b/std/selector/multiplexer2_test.go
@@ -4,9 +4,7 @@ import (
 	binary "math/bits"
 	"testing"
 
-	"github.com/consensys/gnark-crypto/ecc"
 	"github.com/consensys/gnark/frontend"
-	"github.com/consensys/gnark/frontend/cs/scs"
 	"github.com/consensys/gnark/std/math/bits"
 	"github.com/consensys/gnark/test"
 )
@@ -108,57 +106,3 @@ func (c *largeCircuit3) Define(api frontend.API) error {
 	api.AssertIsEqual(s, c.Expected)
 	return nil
 }
-
-func TestBenchMux1(t *testing.T) {
-	for i := 2; i < 900; i++ {
-		a, b, c := testBenchMux(t, i)
-		if a > b || a > c {
-			t.Logf("warning: %v, %v, %v, %v\n", i, a, b, c)
-		}
-		t.Logf("%v, %v, %v, %v\n", i, a, b, c)
-	}
-}
-
-func TestBenchMux2(t *testing.T) {
-	for i := 1; i < 20; i++ {
-		a, b, c := testBenchMux(t, 1<<i)
-		if a >= c {
-			t.Logf("warning %v, %v, %v, %v\n", 1<<i, a, b, c)
-		}
-		t.Logf("%v, %v, %v, %v\n", 1<<i, a, b, c)
-	}
-}
-
-func TestBenchMux3(t *testing.T) {
-	a, b, c := testBenchMux(t, 0b111111111111111111111)
-	t.Logf("%v, %v, %v\n", a, b, c)
-}
-
-func testBenchMux(t *testing.T, len int) (int, int, int) {
-	assert := test.NewAssert(t)
-	circuit := &muxCircuit{
-		Length: len,
-		Input:  make([]frontend.Variable, len),
-	}
-
-	cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit)
-	assert.NoError(err)
-
-	circuit2 := &largeCircuit2{
-		Length: len,
-		Input:  make([]frontend.Variable, len),
-	}
-
-	cs2, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit2)
-	assert.NoError(err)
-
-	circuit3 := &largeCircuit3{
-		Length: len,
-		Input:  make([]frontend.Variable, len),
-	}
-
-	cs3, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit3)
-	assert.NoError(err)
-
-	return cs.GetNbConstraints(), cs2.GetNbConstraints(), cs3.GetNbConstraints()
-}

@ivokub
Copy link
Collaborator

ivokub commented Feb 27, 2025

Suggested edit:

diff --git a/std/selector/multiplexer2_test.go b/std/selector/multiplexer2_test.go
index e9ebc4fcf..347f09c1c 100644
--- a/std/selector/multiplexer2_test.go
+++ b/std/selector/multiplexer2_test.go
@@ -25,56 +25,6 @@ func mux3(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable)
 	return dotProduct(api, inputs, Decoder(api, len(inputs), sel))
 }
 
-type muxCircuit struct {
-	Sel      frontend.Variable
-	Input    []frontend.Variable
-	Expected frontend.Variable
-
-	Length int
-}
-
-func (c *muxCircuit) Define(api frontend.API) error {
-	if len(c.Input) != c.Length {
-		panic("invalid length")
-	}
-
-	s := Mux(api, c.Sel, c.Input...)
-	api.AssertIsEqual(s, c.Expected)
-	return nil
-}
-
-func TestMux100(t *testing.T) {
-	for i := 0; i < 100; i++ {
-		testMux(t, 100, i)
-	}
-}
-
-func testMux(t *testing.T, len int, sel int) {
-	assert := test.NewAssert(t)
-	circuit := &muxCircuit{
-		Length: len,
-		Input:  make([]frontend.Variable, len),
-	}
-
-	inputs := make([]frontend.Variable, len)
-	for i := 0; i < len; i++ {
-		inputs[i] = frontend.Variable(i)
-	}
-
-	assert.CheckCircuit(circuit,
-		test.WithValidAssignment(&muxCircuit{
-			Sel:      sel,
-			Input:    inputs,
-			Expected: sel,
-		}),
-		test.WithInvalidAssignment(&muxCircuit{
-			Sel:      3000,
-			Input:    inputs,
-			Expected: sel,
-		}),
-	)
-}
-
 type largeCircuit2 struct {
 	Sel      frontend.Variable
 	Input    []frontend.Variable
diff --git a/std/selector/multiplexer_test.go b/std/selector/multiplexer_test.go
index c89f6aefa..4ae0c387b 100644
--- a/std/selector/multiplexer_test.go
+++ b/std/selector/multiplexer_test.go
@@ -1,28 +1,26 @@
-package selector_test
+package selector
 
 import (
+	"fmt"
+	"math/rand/v2"
 	"testing"
-
-	"github.com/consensys/gnark-crypto/ecc"
-	"github.com/consensys/gnark/frontend/cs/r1cs"
+	"time"
 
 	"github.com/consensys/gnark/frontend"
-	"github.com/consensys/gnark/std/selector"
 	"github.com/consensys/gnark/test"
 )
 
 type muxCircuit struct {
-	SEL                frontend.Variable
-	I0, I1, I2, I3, I4 frontend.Variable
-	OUT                frontend.Variable
+	Sel      frontend.Variable
+	Input    []frontend.Variable
+	Expected frontend.Variable
+
+	Length int
 }
 
 func (c *muxCircuit) Define(api frontend.API) error {
-
-	out := selector.Mux(api, c.SEL, c.I0, c.I1, c.I2, c.I3, c.I4)
-
-	api.AssertIsEqual(out, c.OUT)
-
+	s := Mux(api, c.Sel, c.Input...)
+	api.AssertIsEqual(s, c.Expected)
 	return nil
 }
 
@@ -34,46 +32,61 @@ type ignoredOutputMuxCircuit struct {
 
 func (c *ignoredOutputMuxCircuit) Define(api frontend.API) error {
 	// We ignore the output
-	_ = selector.Mux(api, c.SEL, c.I0, c.I1, c.I2)
-
-	return nil
-}
-
-type mux2to1Circuit struct {
-	SEL    frontend.Variable
-	I0, I1 frontend.Variable
-	OUT    frontend.Variable
-}
+	_ = Mux(api, c.SEL, c.I0, c.I1, c.I2)
 
-func (c *mux2to1Circuit) Define(api frontend.API) error {
-	// We ignore the output
-	out := selector.Mux(api, c.SEL, c.I0, c.I1)
-	api.AssertIsEqual(out, c.OUT)
 	return nil
 }
 
-type mux4to1Circuit struct {
-	SEL frontend.Variable
-	In  [4]frontend.Variable
-	OUT frontend.Variable
-}
-
-func (c *mux4to1Circuit) Define(api frontend.API) error {
-	out := selector.Mux(api, c.SEL, c.In[:]...)
-	api.AssertIsEqual(out, c.OUT)
-	return nil
+func testMux(assert *test.Assert, len int, sel int) {
+	rng := rand.New(rand.NewPCG(uint64(time.Now().Unix()), 1)) // seed the random generator
+	circuit := &muxCircuit{
+		Input: make([]frontend.Variable, len),
+	}
+
+	inputs := make([]frontend.Variable, len)
+	for i := range len {
+		inputs[i] = frontend.Variable(rng.Uint64())
+	}
+	// out-range invalid selector
+	outRangeSel := uint64(len) + rng.Uint64N(100)
+	opts := []test.TestingOption{
+		test.WithValidAssignment(&muxCircuit{
+			Sel:      sel,
+			Input:    inputs,
+			Expected: inputs[sel],
+		}),
+		test.WithInvalidAssignment(&muxCircuit{
+			Sel:      outRangeSel,
+			Input:    inputs,
+			Expected: sel,
+		}),
+	}
+
+	// in-range invalid selector
+	if len > 1 {
+		invalidSel := rng.Uint64N(uint64(len))
+		for invalidSel == uint64(sel) {
+			invalidSel = rng.Uint64N(uint64(len))
+		}
+		opts = append(opts, test.WithInvalidAssignment(&muxCircuit{
+			Sel:      invalidSel,
+			Input:    inputs,
+			Expected: sel,
+		}))
+	}
+
+	assert.CheckCircuit(circuit, opts...)
 }
 
 func TestMux(t *testing.T) {
 	assert := test.NewAssert(t)
-
-	assert.CheckCircuit(&muxCircuit{},
-		test.WithValidAssignment(&muxCircuit{SEL: 2, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 12}),
-		test.WithValidAssignment(&muxCircuit{SEL: 0, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 10}),
-		test.WithValidAssignment(&muxCircuit{SEL: 4, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}),
-		test.WithInvalidAssignment(&muxCircuit{SEL: 5, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}),
-		test.WithInvalidAssignment(&muxCircuit{SEL: 0, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 21}),
-	)
+	for len := range 9 {
+		for sel := range len + 1 {
+			assert.Run(func(assert *test.Assert) {
+				testMux(assert, len+1, sel)
+			}, fmt.Sprintf("len=%d/sel=%d", len+1, sel))
+		}
+	}
 
 	assert.CheckCircuit(&ignoredOutputMuxCircuit{},
 		test.WithValidAssignment(&ignoredOutputMuxCircuit{SEL: 0, I0: 0, I1: 1, I2: 2}),
@@ -81,39 +94,6 @@ func TestMux(t *testing.T) {
 		test.WithInvalidAssignment(&ignoredOutputMuxCircuit{SEL: 3, I0: 0, I1: 1, I2: 2}),
 		test.WithInvalidAssignment(&ignoredOutputMuxCircuit{SEL: -1, I0: 0, I1: 1, I2: 2}),
 	)
-
-	assert.CheckCircuit(&mux2to1Circuit{},
-		test.WithValidAssignment(&mux2to1Circuit{SEL: 1, I0: 10, I1: 20, OUT: 20}),
-		test.WithValidAssignment(&mux2to1Circuit{SEL: 0, I0: 10, I1: 20, OUT: 10}),
-		test.WithInvalidAssignment(&mux2to1Circuit{SEL: 2, I0: 10, I1: 20, OUT: 20}),
-	)
-
-	assert.CheckCircuit(&mux4to1Circuit{},
-		test.WithValidAssignment(&mux4to1Circuit{
-			SEL: 3,
-			In:  [4]frontend.Variable{11, 22, 33, 44},
-			OUT: 44,
-		}),
-		test.WithValidAssignment(&mux4to1Circuit{
-			SEL: 1,
-			In:  [4]frontend.Variable{11, 22, 33, 44},
-			OUT: 22,
-		}),
-		test.WithValidAssignment(&mux4to1Circuit{
-			SEL: 0,
-			In:  [4]frontend.Variable{11, 22, 33, 44},
-			OUT: 11,
-		}),
-		test.WithInvalidAssignment(&mux4to1Circuit{
-			SEL: 4,
-			In:  [4]frontend.Variable{11, 22, 33, 44},
-			OUT: 44,
-		}),
-	)
-
-	cs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &mux4to1Circuit{})
-	// (4 - 1) + (2 + 1) + 1 == 7
-	assert.Equal(7, cs.GetNbConstraints())
 }
 
 // Map tests:
@@ -126,7 +106,7 @@ type mapCircuit struct {
 
 func (c *mapCircuit) Define(api frontend.API) error {
 
-	out := selector.Map(api, c.SEL,
+	out := Map(api, c.SEL,
 		[]frontend.Variable{c.K0, c.K1, c.K2, c.K3},
 		[]frontend.Variable{c.V0, c.V1, c.V2, c.V3})
 
@@ -143,7 +123,7 @@ type ignoredOutputMapCircuit struct {
 
 func (c *ignoredOutputMapCircuit) Define(api frontend.API) error {
 
-	_ = selector.Map(api, c.SEL,
+	_ = Map(api, c.SEL,
 		[]frontend.Variable{c.K0, c.K1},
 		[]frontend.Variable{c.V0, c.V1})
 

Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see the comments and suggested edits. I wasn't able to push directly to the PR.

It looks very good, but please incorporate the changes and I'll merge.

@weijiguo
Copy link
Contributor Author

It looks very good, but please incorporate the changes and I'll merge.

All incorporated except for the range clause. Also, after the suggested editing, there is no test remaining in the multiplexer2_test.go file. If the current test coverage is sufficient, maybe we can drop it? binarymux (mux2) and dotProductCircuit (mux3) were originally coded for benchmarking purpose.

In case you see further changes necessary, I have invited you to the repos https://github.com/lightec-xyz/gnark and https://github.com/lightec-xyz/gnark-crypto Please feel free to make changes. @ivokub

@weijiguo weijiguo requested a review from ivokub February 28, 2025 01:59
@ivokub
Copy link
Collaborator

ivokub commented Feb 28, 2025

It looks very good, but please incorporate the changes and I'll merge.

All incorporated except for the range clause. Also, after the suggested editing, there is no test remaining in the multiplexer2_test.go file. If the current test coverage is sufficient, maybe we can drop it? binarymux (mux2) and dotProductCircuit (mux3) were originally coded for benchmarking purpose.

In case you see further changes necessary, I have invited you to the repos https://github.com/lightec-xyz/gnark and https://github.com/lightec-xyz/gnark-crypto Please feel free to make changes. @ivokub

Thanks for the access. I dropped multiplexer2_test.go and Length field in test circuit which I think is not necessary and made some testing more difficult.

Imo there is still room for optimization, I don't have good idea right now though how to do it nicely. There is

	nbBits := binary.Len(n - 1)                                   // we use n-1 as sel is 0-indexed
	selBits := bits.ToBinary(api, sel, bits.WithNbDigits(nbBits)) // binary decomposition ensures sel < 2^nbBits

and

	bcmp := cmp.NewBoundedComparator(api, big.NewInt((1<<nbBits)-1), false)
	bcmp.AssertIsLessEq(sel, n-1)

and the second call duplicates a bit the first one, but instead it computes binary decomposition of (n-1)-sel. That second binary decomposition also requires adding checks for the bits and that the bits compose up to (n-1)-sel.

But in practice we could do the check directly on selBits. Actually there is already mechanism inside bits.ToBinary but it only checks against the constant scalar field modulus to enforce uniqueness, maybe we could modify it to allow using any constant value to compare against, not only the modulus, but its a bit more involved.

I'll wait for the CI and then merge. I created new issue for it #1434

@ivokub ivokub merged commit d2fa6c3 into Consensys:master Feb 28, 2025
3 checks passed
@weijiguo weijiguo deleted the pert/mux branch February 28, 2025 12:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants