From 1e208500ab2139899e9c42090eba912e54b73358 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:31:18 -0500 Subject: [PATCH 01/25] perf: single-elem pool for bls12-377 --- internal/gkr/bls12-377/gate_testing.go | 5 ++ internal/gkr/bls12-377/gkr.go | 98 ++++++++++++++++---------- internal/gkr/bls12-377/solver_hints.go | 5 +- 3 files changed, 68 insertions(+), 40 deletions(-) diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 9e5a3868f3..6088644ea0 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index b8ef9ea973..16e7af5c12 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -104,8 +104,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) + api.freeElements() } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +270,13 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +674,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +731,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +796,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +804,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index c977d4997a..96b6636151 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From 4c69b6518416ed6e5b1af41c89831d7ce9a1f9ab Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:47:13 -0500 Subject: [PATCH 02/25] build: generify changes --- .../backend/template/gkr/gate_testing.go.tmpl | 5 + .../backend/template/gkr/gkr.go.tmpl | 96 +++++++++++-------- .../backend/template/gkr/solver_hints.go.tmpl | 3 +- internal/gkr/bls12-377/gkr.go | 2 - internal/gkr/bls12-377/solver_hints.go | 2 +- internal/gkr/bls12-381/gate_testing.go | 5 + internal/gkr/bls12-381/gkr.go | 96 +++++++++++-------- internal/gkr/bls12-381/solver_hints.go | 3 +- internal/gkr/bls24-315/gate_testing.go | 5 + internal/gkr/bls24-315/gkr.go | 96 +++++++++++-------- internal/gkr/bls24-315/solver_hints.go | 3 +- internal/gkr/bls24-317/gate_testing.go | 5 + internal/gkr/bls24-317/gkr.go | 96 +++++++++++-------- internal/gkr/bls24-317/solver_hints.go | 3 +- internal/gkr/bn254/gate_testing.go | 5 + internal/gkr/bn254/gkr.go | 96 +++++++++++-------- internal/gkr/bn254/solver_hints.go | 3 +- internal/gkr/bw6-633/gate_testing.go | 5 + internal/gkr/bw6-633/gkr.go | 96 +++++++++++-------- internal/gkr/bw6-633/solver_hints.go | 3 +- internal/gkr/bw6-761/gate_testing.go | 5 + internal/gkr/bw6-761/gkr.go | 96 +++++++++++-------- internal/gkr/bw6-761/solver_hints.go | 3 +- internal/gkr/small_rational/gate_testing.go | 5 + internal/gkr/small_rational/gkr.go | 96 +++++++++++-------- 25 files changed, 519 insertions(+), 314 deletions(-) diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index a782015cfa..7c24b0d27a 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -15,6 +15,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -130,6 +131,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -139,11 +141,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -157,6 +161,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make({{.FieldPackageName}}.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 16d5eb970b..c5c5b21dc2 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -97,8 +97,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []{{ .ElementType for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*{{ .ElementType }})) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*{{ .ElementType }})) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -230,7 +230,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]{{ .ElementType }}, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step {{ .ElementType }} + var ( + step {{ .ElementType }} + api gateAPI + ) res := make([]{{ .ElementType }}, degGJ) @@ -260,11 +263,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*{{ .ElementType }}) + summand := wire.Gate.Evaluate(&api, gateInput...).(*{{ .ElementType }}) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -663,6 +667,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]{{ .ElementType }}, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -720,52 +725,54 @@ func frToBigInts(dst []*big.Int, src []{{ .ElementType }}) { // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*{{ .ElementType }} + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod {{ .ElementType }} - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x {{ .ElementType }} @@ -783,7 +790,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -791,22 +798,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .E return f(api, inVar...).(*{{ .ElementType }}) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *{{ .ElementType }} { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new({{ .ElementType }})) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...{{ .ElementType }}) *{{ .ElementType }} // convertFunc turns f into a function that accepts and returns {{ .ElementType }}. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...{{ .ElementType }}) *{{ .ElementType }} { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *{{ .ElementType }} { +func (api *gateAPI) cast(v frontend.Variable) *{{ .ElementType }} { if x, ok := v.(*{{ .ElementType }}); ok { // fast path, no extra heap allocation return x } - var x {{ .ElementType }} + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 04873fdcb8..eebc96850f 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -97,7 +97,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 16e7af5c12..c31d447691 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -106,7 +106,6 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb } var api gateAPI gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) - api.freeElements() } evaluation.Mul(&evaluation, &gateEvaluation) @@ -275,7 +274,6 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } - api.freeElements() } mu.Lock() diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 96b6636151..6c504f7692 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index 5b281fd634..275ce2efb0 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 8f72898737..12b5aff144 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 81572a4ac4..372d3f2811 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 058b53cc06..c25c46bb4d 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7aee277ba4..c93b8a3c95 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 783cc964c8..aa7d9cd19d 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index ed418ff1b0..7aac990c1b 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 7c679216fc..c697f94a7e 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 234a327324..dc4fe325e4 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index e9311a3ea5..5d5260ee19 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 0174caa564..f5291406cb 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index f855222636..ab61f1ef43 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 8074b9621c..70d8aa7d4b 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 2c2bda2037..5f3acb6842 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 19d347d099..de4a5c49c2 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 0bae6258dc..f534002d83 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 099b015b02..f063ca4fa0 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 09e9c13f0f..41ebcaf4c1 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go index 93c4ca4191..11c60d8e9c 100644 --- a/internal/gkr/small_rational/gate_testing.go +++ b/internal/gkr/small_rational/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -117,6 +118,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -126,11 +128,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -144,6 +148,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(small_rational.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index d085c6305f..3be9191db4 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []small_rational.S for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*small_rational.SmallRational)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*small_rational.SmallRational)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]small_rational.SmallRational, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step small_rational.SmallRational + var ( + step small_rational.SmallRational + api gateAPI + ) res := make([]small_rational.SmallRational, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*small_rational.SmallRational) + summand := wire.Gate.Evaluate(&api, gateInput...).(*small_rational.SmallRational) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]small_rational.SmallRational, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*small_rational.SmallRational + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod small_rational.SmallRational - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x small_rational.SmallRational @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRation return f(api, inVar...).(*small_rational.SmallRational) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *small_rational.SmallRational { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(small_rational.SmallRational)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...small_rational.SmallRational) *small_rational.SmallRational // convertFunc turns f into a function that accepts and returns small_rational.SmallRational. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...small_rational.SmallRational) *small_rational.SmallRational { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *small_rational.SmallRational { +func (api *gateAPI) cast(v frontend.Variable) *small_rational.SmallRational { if x, ok := v.(*small_rational.SmallRational); ok { // fast path, no extra heap allocation return x } - var x small_rational.SmallRational + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } From b12c0ee95a5536865a0e9f99a657af0bba58e2d0 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:56:40 -0500 Subject: [PATCH 03/25] fix: api pointer receiver --- internal/generator/backend/template/gkr/solver_hints.go.tmpl | 2 +- internal/gkr/bls12-377/solver_hints.go | 2 +- internal/gkr/bls12-381/solver_hints.go | 2 +- internal/gkr/bls24-315/solver_hints.go | 2 +- internal/gkr/bls24-317/solver_hints.go | 2 +- internal/gkr/bn254/solver_hints.go | 2 +- internal/gkr/bw6-633/solver_hints.go | 2 +- internal/gkr/bw6-761/solver_hints.go | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index eebc96850f..82b5c8927d 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -111,7 +111,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 6c504f7692..96b6636151 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 372d3f2811..416eb334e8 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index aa7d9cd19d..ff0267ad5f 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index dc4fe325e4..f2ebc7a410 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index ab61f1ef43..048895e003 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index de4a5c49c2..1e2b9aae00 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 41ebcaf4c1..a64e8ad154 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From 466f16bbab16ca8754220073959dad62d86b9fcd Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 19:43:26 +0000 Subject: [PATCH 04/25] refactor: remove loadCs --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 32 ++----------------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 8a9381bfd7..751ecaa473 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -3,13 +3,10 @@ package gkr_mimc import ( "errors" "fmt" - "os" "slices" "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/plonk" - "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -116,32 +113,6 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { - f, err := os.Open(filename) - - if os.IsNotExist(err) { - // actually compile - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) - require.NoError(t, err) - f, err = os.Create(filename) - require.NoError(t, err) - defer f.Close() - _, err = cs.WriteTo(f) - require.NoError(t, err) - return cs - } - - defer f.Close() - require.NoError(t, err) - - cs := plonk.NewCS(ecc.BLS12_377) - - _, err = cs.ReadFrom(f) - require.NoError(t, err) - - return cs -} - func BenchmarkHashTree(b *testing.B) { const size = 1 << 15 // about 2 ^ 16 total hashes @@ -156,7 +127,8 @@ func BenchmarkHashTree(b *testing.B) { assignment.Leaves[i] = i } - cs := loadCs(b, "gkrmimc_hashtree.cs", &circuit) + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) From 57e95bcf7a621505abe3503a0a5ac1b087cec159 Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 20:01:43 +0000 Subject: [PATCH 05/25] perf: reset api in gkr solver --- internal/gkr/bls12-377/solver_hints.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 96b6636151..1b28a97da7 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From 6755e03ad4ef5c888dab63ad7ad405b1b5039c26 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 15:36:55 -0500 Subject: [PATCH 06/25] build: generify api reset --- internal/generator/backend/template/gkr/solver_hints.go.tmpl | 1 + internal/gkr/bls12-381/solver_hints.go | 1 + internal/gkr/bls24-315/solver_hints.go | 1 + internal/gkr/bls24-317/solver_hints.go | 1 + internal/gkr/bn254/solver_hints.go | 1 + internal/gkr/bw6-633/solver_hints.go | 1 + internal/gkr/bw6-761/solver_hints.go | 1 + 7 files changed, 7 insertions(+) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 82b5c8927d..bfa2d3114a 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -112,6 +112,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 416eb334e8..e576d3994d 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index ff0267ad5f..c122606692 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index f2ebc7a410..256d6cf9dc 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 048895e003..164d353e9e 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 1e2b9aae00..4c57f6e651 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index a64e8ad154..679fc6270f 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From b1de7358250e74baf1fad62c5aca5db3b6d4fac4 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 15:42:15 -0500 Subject: [PATCH 07/25] `hashTree` -> `merkleTree` --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 10 +++++----- std/permutation/gkr-mimc/gkr-mimc_test.go | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 751ecaa473..95956ff6c4 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -79,11 +79,11 @@ func TestGkrMiMCCompiles(t *testing.T) { fmt.Println(cs.GetNbConstraints(), "constraints") } -type hashTreeCircuit struct { +type merkleTreeCircuit struct { Leaves []frontend.Variable } -func (c hashTreeCircuit) Define(api frontend.API) error { +func (c merkleTreeCircuit) Define(api frontend.API) error { if len(c.Leaves) == 0 { return errors.New("no hashing to do") } @@ -113,13 +113,13 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func BenchmarkHashTree(b *testing.B) { +func BenchmarkMerkleTree(b *testing.B) { const size = 1 << 15 // about 2 ^ 16 total hashes - circuit := hashTreeCircuit{ + circuit := merkleTreeCircuit{ Leaves: make([]frontend.Variable, size), } - assignment := hashTreeCircuit{ + assignment := merkleTreeCircuit{ Leaves: make([]frontend.Variable, size), } diff --git a/std/permutation/gkr-mimc/gkr-mimc_test.go b/std/permutation/gkr-mimc/gkr-mimc_test.go index 93143b1279..6cc1bde714 100644 --- a/std/permutation/gkr-mimc/gkr-mimc_test.go +++ b/std/permutation/gkr-mimc/gkr-mimc_test.go @@ -11,11 +11,11 @@ import ( "github.com/stretchr/testify/require" ) -type hashTreeCircuit struct { +type merkleTreeCircuit struct { Leaves []frontend.Variable } -func (c hashTreeCircuit) Define(api frontend.API) error { +func (c merkleTreeCircuit) Define(api frontend.API) error { if len(c.Leaves) == 0 { return errors.New("no hashing to do") } @@ -43,7 +43,7 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func BenchmarkGkrPermutations(b *testing.B) { +func BenchmarkMerkleTree(b *testing.B) { circuit, assignment := hashTreeCircuits(50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) @@ -56,15 +56,15 @@ func BenchmarkGkrPermutations(b *testing.B) { require.NoError(b, err) } -func hashTreeCircuits(n int) (circuit, assignment hashTreeCircuit) { +func hashTreeCircuits(n int) (circuit, assignment merkleTreeCircuit) { leaves := make([]frontend.Variable, n) for i := range n { leaves[i] = i } - return hashTreeCircuit{ + return merkleTreeCircuit{ Leaves: make([]frontend.Variable, len(leaves)), - }, hashTreeCircuit{ + }, merkleTreeCircuit{ Leaves: leaves, } } From 6c4c5c665974c8140219a8f65012e7ca2093a559 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 15:55:28 -0500 Subject: [PATCH 08/25] docs: copilot-inspired explanation for `freeElements` --- internal/generator/backend/template/gkr/gkr.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/small_rational/gkr.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index c5c5b21dc2..c2d2901aab 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -798,7 +798,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ . return f(api, inVar...).(*{{ .ElementType }}) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index c31d447691..af149027cb 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 12b5aff144..9f2b371057 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index c93b8a3c95..b9dab1c6fe 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index c697f94a7e..ecb1a6bcb8 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index f5291406cb..77cb8d085b 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 5f3acb6842..a443549110 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index f063ca4fa0..0ae90d3c41 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 3be9191db4..0557f938df 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRatio return f(api, inVar...).(*small_rational.SmallRational) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } From 8d0b353e41734e9c6f740e1f93cba831850c97fd Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 22:55:44 +0000 Subject: [PATCH 09/25] perf: more pool freeing --- internal/generator/backend/template/gkr/gkr.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/small_rational/gkr.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index c2d2901aab..2131aea723 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -267,8 +267,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index af149027cb..c28e3df8e3 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 9f2b371057..5c87cdd85d 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index b9dab1c6fe..acbebf56cf 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index ecb1a6bcb8..3d4a80d557 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 77cb8d085b..bed1d89329 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index a443549110..4962d1a4f9 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 0ae90d3c41..6ea38abf29 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 0557f938df..e65c94752e 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { From 5a048b6b4fefe37ba271210086b3b33ccd47ad7b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 14:10:47 -0500 Subject: [PATCH 10/25] perf: dedicated exp function --- .../backend/template/gkr/gkr.go.tmpl | 26 ++++++++++++++++++ internal/gkr/bls12-377/gkr.go | 27 +++++++++++++++++++ internal/gkr/bls12-381/gkr.go | 26 ++++++++++++++++++ internal/gkr/bls24-315/gkr.go | 26 ++++++++++++++++++ internal/gkr/bls24-317/gkr.go | 26 ++++++++++++++++++ internal/gkr/bn254/gkr.go | 26 ++++++++++++++++++ internal/gkr/bw6-633/gkr.go | 26 ++++++++++++++++++ internal/gkr/bw6-761/gkr.go | 26 ++++++++++++++++++ internal/gkr/engine_hints.go | 12 ++++++--- internal/gkr/gkr.go | 23 +++++++++++++++- internal/gkr/small_rational/gkr.go | 26 ++++++++++++++++++ std/gkrapi/compile.go | 2 +- std/gkrapi/gkr/types.go | 3 +++ std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 5 +++- std/permutation/gkr-mimc/gkr-mimc.go | 7 +---- .../gkr-poseidon2/gkr-poseidon2_test.go | 16 +++++++++++ 16 files changed, 291 insertions(+), 12 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 2131aea723..32accd6dca 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -790,6 +790,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { + var res *{{ .ElementType }} + x := api.cast(i) + + if n % 2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + n /= 2 + + // square and multiply + for n != 0 { + res.Mul(res, res) + + if n % 2 != 0 { + res.Mul(res, x) + } + + n /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index c28e3df8e3..d96f91d7d1 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -794,6 +795,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + if e == 0 { + return 1 + } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) + + res := api.newElement() + x := api.cast(i) + *res = *x + + // square and multiply + for n != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + n-- + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 5c87cdd85d..0941a05deb 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index acbebf56cf..7f70f6a12c 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 3d4a80d557..0320deb5de 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index bed1d89329..a380373951 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4962d1a4f9..66d957f8f4 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 6ea38abf29..35d6f86bad 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 8c8bc1b797..0e272f12a4 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ *big.Int } +type gateAPI struct{ mod *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,7 +178,13 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce + return &x +} + +func (g gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + x := utils.FromInterface(i) + x.Exp(&x, big.NewInt(int64(e)), g.mod) return &x } diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 955ad8a354..6b8b5ebd4f 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(FrontendApiWrapper{api}, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,3 +383,24 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } + +// FrontendApiWrapper implements additional functions to satisfy the gkr.GateAPI interface. +type FrontendApiWrapper struct { + frontend.API +} + +func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variable { + res := frontend.Variable(1) + if e%2 != 0 { + res = i + } + e /= 2 + for e != 0 { + res = api.Mul(res, res) + if e%2 != 0 { + res = api.Mul(res, i) + } + e /= 2 + } + return res +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e65c94752e..73cf93a146 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *small_rational.SmallRational + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 64205b80b8..5aca90cad2 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(gadget.FrontendApiWrapper{API: api}, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index af8a40fcd6..c9b8678b7b 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -35,6 +35,9 @@ type GateAPI interface { // Mul returns res = i1 * i2 * ... in Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable + // Exp returns res = iᵉ + Exp(i frontend.Variable, e uint8) frontend.Variable + // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 95956ff6c4..3ad685e791 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -6,6 +6,7 @@ import ( "slices" "testing" + "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" @@ -29,7 +30,9 @@ func TestGkrMiMC(t *testing.T) { In: slices.Clone(vals[:length*2]), } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment)) + allCurves := gnark.Curves() + allCurves = []ecc.ID{ecc.BLS12_377} // TODO REMOVE + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(allCurves[0], allCurves[1:]...)) } } diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 266ee00e67..c98d100303 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,12 +208,7 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Mul(t, s) // s¹⁶ × s + return api.Exp(api.Add(in[0], in[1], key), 17) } } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index b224bf1414..b76024c888 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -2,6 +2,8 @@ package gkr_poseidon2 import ( "fmt" + "math/bits" + "strings" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -79,3 +81,17 @@ func BenchmarkGkrCompressions(b *testing.B) { _, err = cs.Solve(witness) require.NoError(b, err) } + +func TestGenerateTable(t *testing.T) { + var sb strings.Builder + for n := range 256 { + if n%16 == 0 { + sb.WriteString("\"+\n\"") + } + b := uint8(n) + b <<= bits.LeadingZeros8(b) + b = bits.Reverse8(b) + sb.WriteString(fmt.Sprintf("\\x%x", b)) + } + fmt.Println(sb.String()) +} From c6b830c3fdf9137884b4015b4bdff72e3658884a Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 14:14:37 -0500 Subject: [PATCH 11/25] build: generify exp changes --- .../backend/template/gkr/gkr.go.tmpl | 23 ++++++++++--------- internal/gkr/bls12-381/gkr.go | 19 +++++++-------- internal/gkr/bls24-315/gkr.go | 19 +++++++-------- internal/gkr/bls24-317/gkr.go | 19 +++++++-------- internal/gkr/bn254/gkr.go | 19 +++++++-------- internal/gkr/bw6-633/gkr.go | 19 +++++++-------- internal/gkr/bw6-761/gkr.go | 19 +++++++-------- internal/gkr/small_rational/gkr.go | 19 +++++++-------- 8 files changed, 82 insertions(+), 74 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 32accd6dca..0d97343a33 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -6,6 +6,7 @@ import ( fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "math/big" + "math/bits" "strconv" "sync" "github.com/consensys/gnark/frontend" @@ -790,27 +791,27 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { - var res *{{ .ElementType }} - x := api.cast(i) - - if n % 2 == 0 { - res = api.cast(1) - } else { - *res = *x +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - n /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply for n != 0 { res.Mul(res, res) - if n % 2 != 0 { + if e % 2 != 0 { res.Mul(res, x) } - n /= 2 + e /= 2 + n-- } return res diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0941a05deb..05560b9014 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7f70f6a12c..8f6feb48b6 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 0320deb5de..495658b89d 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a380373951..01d2a5e9de 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 66d957f8f4..af578e59d7 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 35d6f86bad..63dabac899 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 73cf93a146..ffd730e69e 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *small_rational.SmallRational - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res From 7bb78e7e8e9dcf8653021b8be566bad1d7e7cd82 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 14:18:48 -0500 Subject: [PATCH 12/25] fix: test for all curves --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 3ad685e791..1cd6579e0c 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -31,8 +32,12 @@ func TestGkrMiMC(t *testing.T) { } allCurves := gnark.Curves() - allCurves = []ecc.ID{ecc.BLS12_377} // TODO REMOVE - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(allCurves[0], allCurves[1:]...)) + test.NewAssert(t).CheckCircuit( + circuit, + test.WithValidAssignment(assignment), + test.WithCurves(allCurves[0], allCurves[1:]...), + test.WithBackends(backend.PLONK), + ) } } From 58a0fdb0083252917943b1e4f432b14ceee834dc Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Mon, 22 Sep 2025 21:33:41 +0000 Subject: [PATCH 13/25] perf: fastpath for ^17 --- internal/gkr/bls12-377/gkr.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index d96f91d7d1..efaa8f9a6f 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -795,15 +795,29 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +var seventeen = big.NewInt(17) + func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + + res := api.newElement() + x := api.cast(i) + + if e == 17 { + res.Mul(x, x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + res.Mul(res, x) // x¹⁷ + + return res + } + if e == 0 { return 1 } n := bits.Len8(e) - 1 e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) *res = *x // square and multiply From 4b2c10207523d766c4ea3a45711f6e745a78b687 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:36:56 -0500 Subject: [PATCH 14/25] fix: exp --- internal/gkr/gkr.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 6b8b5ebd4f..b115c56aa6 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -391,16 +391,13 @@ type FrontendApiWrapper struct { func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variable { res := frontend.Variable(1) - if e%2 != 0 { - res = i - } - e /= 2 - for e != 0 { + + for range 8 { res = api.Mul(res, res) - if e%2 != 0 { + if e%128 != 0 { res = api.Mul(res, i) } - e /= 2 + e <<= 1 } return res } From 30c39b4fee11c6fa13ba67474451ea0d056ebd8d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:38:31 -0500 Subject: [PATCH 15/25] fix: exp. really --- internal/gkr/gkr.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index b115c56aa6..ee43864d7a 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -394,7 +394,7 @@ func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variabl for range 8 { res = api.Mul(res, res) - if e%128 != 0 { + if e&128 != 0 { res = api.Mul(res, i) } e <<= 1 From 6c7a7fccb4365761f517044555c32d755b29f40b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:43:43 -0500 Subject: [PATCH 16/25] test: modernize benchmark --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 1cd6579e0c..569201ffa1 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -2,7 +2,6 @@ package gkr_mimc import ( "errors" - "fmt" "slices" "testing" @@ -77,16 +76,6 @@ func (c *testGkrMiMCCircuit) Define(api frontend.API) error { return nil } -func TestGkrMiMCCompiles(t *testing.T) { - const n = 52000 - circuit := testGkrMiMCCircuit{ - In: make([]frontend.Variable, n), - } - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(27_000_000)) - require.NoError(t, err) - fmt.Println(cs.GetNbConstraints(), "constraints") -} - type merkleTreeCircuit struct { Leaves []frontend.Variable } @@ -141,5 +130,9 @@ func BenchmarkMerkleTree(b *testing.B) { w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) - require.NoError(b, cs.IsSolved(w)) + for b.Loop() { + s, err := cs.Solve(w) + require.NoError(b, err) + _ = s + } } From 266c5f6d057498a714b715fda19bd8b22bae968b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:44:53 -0500 Subject: [PATCH 17/25] Revert "build: generify exp changes" This reverts commit c6b830c3fdf9137884b4015b4bdff72e3658884a. --- .../backend/template/gkr/gkr.go.tmpl | 23 +++++++++---------- internal/gkr/bls12-381/gkr.go | 19 ++++++++------- internal/gkr/bls24-315/gkr.go | 19 ++++++++------- internal/gkr/bls24-317/gkr.go | 19 ++++++++------- internal/gkr/bn254/gkr.go | 19 ++++++++------- internal/gkr/bw6-633/gkr.go | 19 ++++++++------- internal/gkr/bw6-761/gkr.go | 19 ++++++++------- internal/gkr/small_rational/gkr.go | 19 ++++++++------- 8 files changed, 74 insertions(+), 82 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 0d97343a33..32accd6dca 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -6,7 +6,6 @@ import ( fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "math/big" - "math/bits" "strconv" "sync" "github.com/consensys/gnark/frontend" @@ -791,27 +790,27 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 +func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { + var res *{{ .ElementType }} + x := api.cast(i) + + if n % 2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + n /= 2 // square and multiply for n != 0 { res.Mul(res, res) - if e % 2 != 0 { + if n % 2 != 0 { res.Mul(res, x) } - e /= 2 - n-- + n /= 2 } return res diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 05560b9014..0941a05deb 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 8f6feb48b6..7f70f6a12c 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 495658b89d..0320deb5de 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 01d2a5e9de..a380373951 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index af578e59d7..66d957f8f4 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 63dabac899..35d6f86bad 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index ffd730e69e..73cf93a146 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *small_rational.SmallRational + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res From 5dd1d61d207a854c477b53a886d99f379c28b956 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:47:02 -0500 Subject: [PATCH 18/25] revert: exp func --- .../backend/template/gkr/gkr.go.tmpl | 26 ------------ internal/gkr/bls12-377/gkr.go | 41 ------------------- internal/gkr/bls12-381/gkr.go | 26 ------------ internal/gkr/bls24-315/gkr.go | 26 ------------ internal/gkr/bls24-317/gkr.go | 26 ------------ internal/gkr/bn254/gkr.go | 26 ------------ internal/gkr/bw6-633/gkr.go | 26 ------------ internal/gkr/bw6-761/gkr.go | 26 ------------ internal/gkr/engine_hints.go | 12 ++---- internal/gkr/small_rational/gkr.go | 26 ------------ std/gkrapi/compile.go | 2 +- std/gkrapi/gkr/types.go | 3 -- std/permutation/gkr-mimc/gkr-mimc.go | 7 +++- .../gkr-poseidon2/gkr-poseidon2_test.go | 16 -------- 14 files changed, 10 insertions(+), 279 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 32accd6dca..2131aea723 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -790,32 +790,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { - var res *{{ .ElementType }} - x := api.cast(i) - - if n % 2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - n /= 2 - - // square and multiply - for n != 0 { - res.Mul(res, res) - - if n % 2 != 0 { - res.Mul(res, x) - } - - n /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index efaa8f9a6f..c28e3df8e3 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -795,46 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -var seventeen = big.NewInt(17) - -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - - res := api.newElement() - x := api.cast(i) - - if e == 17 { - res.Mul(x, x) // x² - res.Mul(res, res) // x⁴ - res.Mul(res, res) // x⁸ - res.Mul(res, res) // x¹⁶ - res.Mul(res, x) // x¹⁷ - - return res - } - - if e == 0 { - return 1 - } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - - *res = *x - - // square and multiply - for n != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - n-- - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0941a05deb..5c87cdd85d 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7f70f6a12c..acbebf56cf 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 0320deb5de..3d4a80d557 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a380373951..bed1d89329 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 66d957f8f4..4962d1a4f9 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 35d6f86bad..6ea38abf29 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 0e272f12a4..8c8bc1b797 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ mod *big.Int } +type gateAPI struct{ *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.mod) // reduce + x.Mod(&x, g.Int) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,13 +178,7 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.mod) // reduce - return &x -} - -func (g gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - x := utils.FromInterface(i) - x.Exp(&x, big.NewInt(int64(e)), g.mod) + x.Mod(&x, g.Int) // reduce return &x } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 73cf93a146..e65c94752e 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *small_rational.SmallRational - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 5aca90cad2..64205b80b8 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(gadget.FrontendApiWrapper{API: api}, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index c9b8678b7b..af8a40fcd6 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -35,9 +35,6 @@ type GateAPI interface { // Mul returns res = i1 * i2 * ... in Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable - // Exp returns res = iᵉ - Exp(i frontend.Variable, e uint8) frontend.Variable - // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index c98d100303..266ee00e67 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,7 +208,12 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - return api.Exp(api.Add(in[0], in[1], key), 17) + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Mul(t, s) // s¹⁶ × s } } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index b76024c888..b224bf1414 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -2,8 +2,6 @@ package gkr_poseidon2 import ( "fmt" - "math/bits" - "strings" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -81,17 +79,3 @@ func BenchmarkGkrCompressions(b *testing.B) { _, err = cs.Solve(witness) require.NoError(b, err) } - -func TestGenerateTable(t *testing.T) { - var sb strings.Builder - for n := range 256 { - if n%16 == 0 { - sb.WriteString("\"+\n\"") - } - b := uint8(n) - b <<= bits.LeadingZeros8(b) - b = bits.Reverse8(b) - sb.WriteString(fmt.Sprintf("\\x%x", b)) - } - fmt.Println(sb.String()) -} From 1374bab9805cdb95bc199db8af8f98d188d59131 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:49:36 -0500 Subject: [PATCH 19/25] refactor: modulus name --- internal/gkr/engine_hints.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 8c8bc1b797..9e44dc7a3e 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ *big.Int } +type gateAPI struct{ mod *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,7 +178,7 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce return &x } From 9959f8a45cbc26be5dfb1d94a1c48880fe58d5be Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:52:05 -0500 Subject: [PATCH 20/25] revert: remove FrontendApiWrapper --- internal/gkr/gkr.go | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index ee43864d7a..955ad8a354 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(FrontendApiWrapper{api}, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,21 +383,3 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } - -// FrontendApiWrapper implements additional functions to satisfy the gkr.GateAPI interface. -type FrontendApiWrapper struct { - frontend.API -} - -func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variable { - res := frontend.Variable(1) - - for range 8 { - res = api.Mul(res, res) - if e&128 != 0 { - res = api.Mul(res, i) - } - e <<= 1 - } - return res -} From 271cff9225786d4e1f27046a090fee99cc8f3e68 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 23 Sep 2025 18:12:53 -0500 Subject: [PATCH 21/25] perf: dedicated exp17 func --- .../generator/backend/template/gkr/gkr.go.tmpl | 10 ++++++++++ internal/gkr/bls12-377/gkr.go | 10 ++++++++++ internal/gkr/bls12-381/gkr.go | 10 ++++++++++ internal/gkr/bls24-315/gkr.go | 10 ++++++++++ internal/gkr/bls24-317/gkr.go | 10 ++++++++++ internal/gkr/bn254/gkr.go | 10 ++++++++++ internal/gkr/bw6-633/gkr.go | 10 ++++++++++ internal/gkr/bw6-761/gkr.go | 10 ++++++++++ internal/gkr/engine_hints.go | 7 +++++++ internal/gkr/gkr.go | 14 +++++++++++++- internal/gkr/small_rational/gkr.go | 10 ++++++++++ std/gkrapi/compile.go | 2 +- std/gkrapi/gkr/types.go | 2 ++ std/permutation/gkr-mimc/gkr-mimc.go | 7 +------ 14 files changed, 114 insertions(+), 8 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 2131aea723..31819961e0 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -772,6 +772,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x {{ .ElementType }} diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index c28e3df8e3..996f4a6b4b 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 5c87cdd85d..863a2f71c8 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index acbebf56cf..dc6eb854ad 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 3d4a80d557..6333662f83 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index bed1d89329..f2361557c6 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4962d1a4f9..e8b1f9d27d 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 6ea38abf29..27315ad58d 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 9e44dc7a3e..35f8cd5c7c 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -182,6 +182,13 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend return &x } +func (g gateAPI) Exp17(i frontend.Variable) frontend.Variable { + x := utils.FromInterface(i) + var res big.Int + res.Exp(&x, big.NewInt(17), g.mod) + return &res +} + func (g gateAPI) Println(a ...frontend.Variable) { strings := make([]string, len(a)) for i := range a { diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 955ad8a354..dc2e0c9fd3 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(FrontendAPIWrapper{api}, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,3 +383,15 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } + +type FrontendAPIWrapper struct { + frontend.API +} + +func (api FrontendAPIWrapper) Exp17(i frontend.Variable) frontend.Variable { + res := api.Mul(i, i) // i^2 + res = api.Mul(res, res) // i^4 + res = api.Mul(res, res) // i^8 + res = api.Mul(res, res) // i^16 + return api.Mul(res, i) // i^17 +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e65c94752e..a884902d4c 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x small_rational.SmallRational diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 64205b80b8..d8c0985ae5 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(gadget.FrontendAPIWrapper{API: api}, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index af8a40fcd6..5acc2cda9f 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -38,6 +38,8 @@ type GateAPI interface { // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) + + Exp17(a frontend.Variable) frontend.Variable } // GateFunction is a function that evaluates a polynomial over its inputs diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 266ee00e67..2d0e4ada69 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,12 +208,7 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Mul(t, s) // s¹⁶ × s + return api.Exp17(api.Add(in[0], in[1], key)) } } From 1519d5276db1919b78a6f4627880e9dd439168ab Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 24 Sep 2025 10:02:13 -0500 Subject: [PATCH 22/25] perf: SumExp17 --- .../backend/template/gkr/gkr.go.tmpl | 12 +++++++--- internal/gkr/bls12-377/gkr.go | 23 +++++++++++++------ internal/gkr/bls12-381/gkr.go | 12 +++++++--- internal/gkr/bls24-315/gkr.go | 12 +++++++--- internal/gkr/bls24-317/gkr.go | 12 +++++++--- internal/gkr/bn254/gkr.go | 12 +++++++--- internal/gkr/bw6-633/gkr.go | 12 +++++++--- internal/gkr/bw6-761/gkr.go | 12 +++++++--- internal/gkr/engine_hints.go | 10 +++++--- internal/gkr/gkr.go | 3 ++- internal/gkr/small_rational/gkr.go | 12 +++++++--- std/gkrapi/gkr/types.go | 2 +- std/permutation/gkr-mimc/gkr-mimc.go | 2 +- 13 files changed, 99 insertions(+), 37 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 31819961e0..bf0fb4ecd9 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -772,9 +772,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a,b,c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 996f4a6b4b..7925291eeb 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -776,14 +776,23 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + if _, err := x.SetInterface(c); err != nil { // a, b are expected to be *fr.Element but not c + panic(err) + } + + x.Add(&x, api.cast(a)) + x.Add(&x, api.cast(b)) + res := api.newElement() - x := api.cast(i) - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + + res.Mul(&x, &x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, &x) // x^17 } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 863a2f71c8..0b34a40e23 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index dc6eb854ad..a0a059e6df 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 6333662f83..42f1e5e56c 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index f2361557c6..a84276a35a 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index e8b1f9d27d..4daeb83900 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 27315ad58d..9821871edd 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 35f8cd5c7c..9d53255f4e 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -182,9 +182,13 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend return &x } -func (g gateAPI) Exp17(i frontend.Variable) frontend.Variable { - x := utils.FromInterface(i) - var res big.Int +func (g gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := utils.FromInterface(a) + res := utils.FromInterface(b) + + x.Add(&x, &res) + res = utils.FromInterface(c) + x.Add(&x, &res) res.Exp(&x, big.NewInt(17), g.mod) return &res } diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index dc2e0c9fd3..b2ea42e8ba 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -388,7 +388,8 @@ type FrontendAPIWrapper struct { frontend.API } -func (api FrontendAPIWrapper) Exp17(i frontend.Variable) frontend.Variable { +func (api FrontendAPIWrapper) SumExp17(a, b, c frontend.Variable) frontend.Variable { + i := api.Add(a, b, c) res := api.Mul(i, i) // i^2 res = api.Mul(res, res) // i^4 res = api.Mul(res, res) // i^8 diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index a884902d4c..37f9490950 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index 5acc2cda9f..3c03f75df1 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -39,7 +39,7 @@ type GateAPI interface { // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) - Exp17(a frontend.Variable) frontend.Variable + SumExp17(a, b, c frontend.Variable) frontend.Variable } // GateFunction is a function that evaluates a polynomial over its inputs diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 2d0e4ada69..1fa4cbce93 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,7 +208,7 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - return api.Exp17(api.Add(in[0], in[1], key)) + return api.SumExp17(in[0], in[1], key) } } From 2e3be9f26ca8318d07f019d99a4615f67c81549c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 24 Sep 2025 10:38:14 -0500 Subject: [PATCH 23/25] perf: cache key as fr.Element --- internal/gkr/bls12-377/gkr.go | 8 ++------ std/permutation/gkr-mimc/gkr-mimc.go | 14 +++++++++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 7925291eeb..13923e0983 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -779,12 +779,8 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { var x fr.Element - if _, err := x.SetInterface(c); err != nil { // a, b are expected to be *fr.Element but not c - panic(err) - } - - x.Add(&x, api.cast(a)) - x.Add(&x, api.cast(b)) + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) res := api.newElement() diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 1fa4cbce93..b2c2b7527e 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -5,6 +5,7 @@ import ( "math/big" "github.com/consensys/gnark-crypto/ecc" + frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" @@ -204,11 +205,22 @@ func addPow7Add(key *big.Int) gkr.GateFunction { // addPow17: (in[0]+in[1]+key)¹⁷ func addPow17(key *big.Int) gkr.GateFunction { + var cachedKey frontend.Variable return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") } - return api.SumExp17(in[0], in[1], key) + if cachedKey == nil { + if _, ok := in[0].(*frBls12377.Element); ok { + var ck frBls12377.Element + ck.SetBigInt(key) + cachedKey = &ck + } else { + return api.SumExp17(in[0], in[1], key) + } + } + + return api.SumExp17(in[0], in[1], cachedKey) } } From a05f8d63e757d58cc41cbec6d4d7db5237e9f015 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 24 Sep 2025 10:59:08 -0500 Subject: [PATCH 24/25] build: generify --- .../backend/template/gkr/gkr.go.tmpl | 23 +++++++++---------- internal/gkr/bls12-377/gkr.go | 10 ++++---- internal/gkr/bls12-381/gkr.go | 23 +++++++++---------- internal/gkr/bls24-315/gkr.go | 23 +++++++++---------- internal/gkr/bls24-317/gkr.go | 23 +++++++++---------- internal/gkr/bn254/gkr.go | 23 +++++++++---------- internal/gkr/bw6-633/gkr.go | 23 +++++++++---------- internal/gkr/bw6-761/gkr.go | 23 +++++++++---------- internal/gkr/small_rational/gkr.go | 23 +++++++++---------- 9 files changed, 93 insertions(+), 101 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index bf0fb4ecd9..12b55c8d30 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -773,19 +773,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a,b,c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x {{ .ElementType }} + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 13923e0983..8f7532228e 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -784,11 +784,11 @@ func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { res := api.newElement() - res.Mul(&x, &x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, &x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0b34a40e23..b4bdff0594 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index a0a059e6df..c095671e76 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 42f1e5e56c..bfa2fb778d 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a84276a35a..d075ced6cd 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4daeb83900..e6932181a6 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 9821871edd..acb157dd5f 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 37f9490950..427bac1877 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x small_rational.SmallRational + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { From f6613ce75532aaac67454077c5bd30a4591c1481 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 30 Sep 2025 16:56:24 -0500 Subject: [PATCH 25/25] perf: store keys as fr.Elements instead of big.Int --- std/permutation/gkr-mimc/gkr-mimc.go | 100 ++++++++++++++++++--------- 1 file changed, 68 insertions(+), 32 deletions(-) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index b2c2b7527e..1645552881 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -2,16 +2,21 @@ package gkr_mimc import ( "fmt" - "math/big" "github.com/consensys/gnark-crypto/ecc" - frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + frbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + frbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + frbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + frbls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + frbn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + frbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + frbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" @@ -95,7 +100,7 @@ func RegisterGates(curves ...ecc.ID) error { return err } gateNamer := newGateNamer(curve) - var lastLayerSBox, nonLastLayerSBox func(*big.Int) gkr.GateFunction + var lastLayerSBox, nonLastLayerSBox func(frontend.Variable) gkr.GateFunction switch deg { case 5: lastLayerSBox = addPow5Add @@ -111,12 +116,12 @@ func RegisterGates(curves ...ecc.ID) error { } for i := range len(constants) - 1 { - if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if _, err = gkrgates.Register(nonLastLayerSBox(constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) } } - if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if _, err = gkrgates.Register(lastLayerSBox(constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) } } @@ -124,23 +129,65 @@ func RegisterGates(curves ...ecc.ID) error { } // getParams returns the parameters for the MiMC encryption function for the given curve. -// It also returns the degree of the s-Box -func getParams(curve ecc.ID) ([]big.Int, int, error) { +// It also returns the degree of the s-Box. +func getParams(curve ecc.ID) ([]frontend.Variable, int, error) { switch curve { case ecc.BN254: - return bn254.GetConstants(), 5, nil + c := bn254.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbn254.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS12_381: - return bls12381.GetConstants(), 5, nil + c := bls12381.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls12381.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS12_377: - return bls12377.GetConstants(), 17, nil + c := bls12377.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls12377.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 17, nil case ecc.BLS24_315: - return bls24315.GetConstants(), 5, nil + c := bls24315.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls24315.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS24_317: - return bls24317.GetConstants(), 7, nil + c := bls24317.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls24317.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 7, nil case ecc.BW6_633: - return bw6633.GetConstants(), 5, nil + c := bw6633.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbw6633.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BW6_761: - return bw6761.GetConstants(), 5, nil + c := bw6761.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbw6761.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil default: return nil, -1, fmt.Errorf("unsupported curve ID: %s", curve) } @@ -155,7 +202,7 @@ func (n gateNamer) round(i int) gkr.GateName { return gkr.GateName(fmt.Sprintf("%s%d", string(n), i)) } -func addPow5(key *big.Int) gkr.GateFunction { +func addPow5(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") @@ -167,7 +214,7 @@ func addPow5(key *big.Int) gkr.GateFunction { } // addPow5Add: (in[0]+in[1]+key)⁵ + 2*in[0] + in[2] -func addPow5Add(key *big.Int) gkr.GateFunction { +func addPow5Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") @@ -180,7 +227,7 @@ func addPow5Add(key *big.Int) gkr.GateFunction { } } -func addPow7(key *big.Int) gkr.GateFunction { +func addPow7(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") @@ -192,7 +239,7 @@ func addPow7(key *big.Int) gkr.GateFunction { } // addPow7Add: (in[0]+in[1]+key)⁷ + 2*in[0] + in[2] -func addPow7Add(key *big.Int) gkr.GateFunction { +func addPow7Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") @@ -204,28 +251,17 @@ func addPow7Add(key *big.Int) gkr.GateFunction { } // addPow17: (in[0]+in[1]+key)¹⁷ -func addPow17(key *big.Int) gkr.GateFunction { - var cachedKey frontend.Variable +func addPow17(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") } - if cachedKey == nil { - if _, ok := in[0].(*frBls12377.Element); ok { - var ck frBls12377.Element - ck.SetBigInt(key) - cachedKey = &ck - } else { - return api.SumExp17(in[0], in[1], key) - } - } - - return api.SumExp17(in[0], in[1], cachedKey) + return api.SumExp17(in[0], in[1], key) } } // addPow17Add: (in[0]+in[1]+key)¹⁷ + in[0] + in[2] -func addPow17Add(key *big.Int) gkr.GateFunction { +func addPow17Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input")