diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index a782015cfa..7c24b0d27a 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -15,6 +15,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -130,6 +131,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -139,11 +141,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -157,6 +161,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make({{.FieldPackageName}}.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 16d5eb970b..12b55c8d30 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -97,8 +97,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []{{ .ElementType for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*{{ .ElementType }})) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*{{ .ElementType }})) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -230,7 +230,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]{{ .ElementType }}, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step {{ .ElementType }} + var ( + step {{ .ElementType }} + api gateAPI + ) res := make([]{{ .ElementType }}, degGJ) @@ -260,10 +263,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*{{ .ElementType }}) + summand := wire.Gate.Evaluate(&api, gateInput...).(*{{ .ElementType }}) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -663,6 +667,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]{{ .ElementType }}, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -720,52 +725,69 @@ func frToBigInts(dst []*big.Int, src []{{ .ElementType }}) { // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*{{ .ElementType }} + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod {{ .ElementType }} - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a,b,c frontend.Variable) frontend.Variable { + var x {{ .ElementType }} + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x {{ .ElementType }} @@ -783,7 +805,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -791,22 +813,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .E return f(api, inVar...).(*{{ .ElementType }}) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *{{ .ElementType }} { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new({{ .ElementType }})) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...{{ .ElementType }}) *{{ .ElementType }} // convertFunc turns f into a function that accepts and returns {{ .ElementType }}. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...{{ .ElementType }}) *{{ .ElementType }} { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *{{ .ElementType }} { +func (api *gateAPI) cast(v frontend.Variable) *{{ .ElementType }} { if x, ok := v.(*{{ .ElementType }}); ok { // fast path, no extra heap allocation return x } - var x {{ .ElementType }} + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 04873fdcb8..bfa2d3114a 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -97,7 +97,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -110,7 +111,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 9e5a3868f3..6088644ea0 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index b8ef9ea973..8f7532228e 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index c977d4997a..1b28a97da7 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index 5b281fd634..275ce2efb0 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 8f72898737..b4bdff0594 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 81572a4ac4..e576d3994d 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 058b53cc06..c25c46bb4d 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7aee277ba4..c095671e76 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 783cc964c8..c122606692 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index ed418ff1b0..7aac990c1b 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 7c679216fc..bfa2fb778d 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 234a327324..256d6cf9dc 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index e9311a3ea5..5d5260ee19 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 0174caa564..d075ced6cd 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index f855222636..164d353e9e 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 8074b9621c..70d8aa7d4b 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 2c2bda2037..e6932181a6 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 19d347d099..4c57f6e651 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 0bae6258dc..f534002d83 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 099b015b02..acb157dd5f 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 09e9c13f0f..679fc6270f 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 8c8bc1b797..9d53255f4e 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ *big.Int } +type gateAPI struct{ mod *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,10 +178,21 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce return &x } +func (g gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := utils.FromInterface(a) + res := utils.FromInterface(b) + + x.Add(&x, &res) + res = utils.FromInterface(c) + x.Add(&x, &res) + res.Exp(&x, big.NewInt(17), g.mod) + return &res +} + func (g gateAPI) Println(a ...frontend.Variable) { strings := make([]string, len(a)) for i := range a { diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 955ad8a354..b2ea42e8ba 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(FrontendAPIWrapper{api}, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,3 +383,16 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } + +type FrontendAPIWrapper struct { + frontend.API +} + +func (api FrontendAPIWrapper) SumExp17(a, b, c frontend.Variable) frontend.Variable { + i := api.Add(a, b, c) + res := api.Mul(i, i) // i^2 + res = api.Mul(res, res) // i^4 + res = api.Mul(res, res) // i^8 + res = api.Mul(res, res) // i^16 + return api.Mul(res, i) // i^17 +} diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go index 93c4ca4191..11c60d8e9c 100644 --- a/internal/gkr/small_rational/gate_testing.go +++ b/internal/gkr/small_rational/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -117,6 +118,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -126,11 +128,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -144,6 +148,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(small_rational.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index d085c6305f..427bac1877 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []small_rational.S for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*small_rational.SmallRational)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*small_rational.SmallRational)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]small_rational.SmallRational, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step small_rational.SmallRational + var ( + step small_rational.SmallRational + api gateAPI + ) res := make([]small_rational.SmallRational, degGJ) @@ -266,10 +269,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*small_rational.SmallRational) + summand := wire.Gate.Evaluate(&api, gateInput...).(*small_rational.SmallRational) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]small_rational.SmallRational, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,69 @@ func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*small_rational.SmallRational + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod small_rational.SmallRational - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res +} + +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x small_rational.SmallRational + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() + + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x small_rational.SmallRational @@ -787,7 +809,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +817,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRation return f(api, inVar...).(*small_rational.SmallRational) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *small_rational.SmallRational { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(small_rational.SmallRational)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...small_rational.SmallRational) *small_rational.SmallRational // convertFunc turns f into a function that accepts and returns small_rational.SmallRational. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...small_rational.SmallRational) *small_rational.SmallRational { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *small_rational.SmallRational { +func (api *gateAPI) cast(v frontend.Variable) *small_rational.SmallRational { if x, ok := v.(*small_rational.SmallRational); ok { // fast path, no extra heap allocation return x } - var x small_rational.SmallRational + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 64205b80b8..d8c0985ae5 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(gadget.FrontendAPIWrapper{API: api}, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index af8a40fcd6..3c03f75df1 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -38,6 +38,8 @@ type GateAPI interface { // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) + + SumExp17(a, b, c frontend.Variable) frontend.Variable } // GateFunction is a function that evaluates a polynomial over its inputs diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 8a9381bfd7..569201ffa1 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -2,14 +2,12 @@ package gkr_mimc import ( "errors" - "fmt" - "os" "slices" "testing" + "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/plonk" - "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -32,7 +30,13 @@ func TestGkrMiMC(t *testing.T) { In: slices.Clone(vals[:length*2]), } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment)) + allCurves := gnark.Curves() + test.NewAssert(t).CheckCircuit( + circuit, + test.WithValidAssignment(assignment), + test.WithCurves(allCurves[0], allCurves[1:]...), + test.WithBackends(backend.PLONK), + ) } } @@ -72,21 +76,11 @@ func (c *testGkrMiMCCircuit) Define(api frontend.API) error { return nil } -func TestGkrMiMCCompiles(t *testing.T) { - const n = 52000 - circuit := testGkrMiMCCircuit{ - In: make([]frontend.Variable, n), - } - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(27_000_000)) - require.NoError(t, err) - fmt.Println(cs.GetNbConstraints(), "constraints") -} - -type hashTreeCircuit struct { +type merkleTreeCircuit struct { Leaves []frontend.Variable } -func (c hashTreeCircuit) Define(api frontend.API) error { +func (c merkleTreeCircuit) Define(api frontend.API) error { if len(c.Leaves) == 0 { return errors.New("no hashing to do") } @@ -116,39 +110,13 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { - f, err := os.Open(filename) - - if os.IsNotExist(err) { - // actually compile - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) - require.NoError(t, err) - f, err = os.Create(filename) - require.NoError(t, err) - defer f.Close() - _, err = cs.WriteTo(f) - require.NoError(t, err) - return cs - } - - defer f.Close() - require.NoError(t, err) - - cs := plonk.NewCS(ecc.BLS12_377) - - _, err = cs.ReadFrom(f) - require.NoError(t, err) - - return cs -} - -func BenchmarkHashTree(b *testing.B) { +func BenchmarkMerkleTree(b *testing.B) { const size = 1 << 15 // about 2 ^ 16 total hashes - circuit := hashTreeCircuit{ + circuit := merkleTreeCircuit{ Leaves: make([]frontend.Variable, size), } - assignment := hashTreeCircuit{ + assignment := merkleTreeCircuit{ Leaves: make([]frontend.Variable, size), } @@ -156,10 +124,15 @@ func BenchmarkHashTree(b *testing.B) { assignment.Leaves[i] = i } - cs := loadCs(b, "gkrmimc_hashtree.cs", &circuit) + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) - require.NoError(b, cs.IsSolved(w)) + for b.Loop() { + s, err := cs.Solve(w) + require.NoError(b, err) + _ = s + } } diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 266ee00e67..1645552881 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -2,15 +2,21 @@ package gkr_mimc import ( "fmt" - "math/big" "github.com/consensys/gnark-crypto/ecc" + frbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + frbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + frbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + frbls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + frbn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + frbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + frbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" @@ -94,7 +100,7 @@ func RegisterGates(curves ...ecc.ID) error { return err } gateNamer := newGateNamer(curve) - var lastLayerSBox, nonLastLayerSBox func(*big.Int) gkr.GateFunction + var lastLayerSBox, nonLastLayerSBox func(frontend.Variable) gkr.GateFunction switch deg { case 5: lastLayerSBox = addPow5Add @@ -110,12 +116,12 @@ func RegisterGates(curves ...ecc.ID) error { } for i := range len(constants) - 1 { - if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if _, err = gkrgates.Register(nonLastLayerSBox(constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) } } - if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if _, err = gkrgates.Register(lastLayerSBox(constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) } } @@ -123,23 +129,65 @@ func RegisterGates(curves ...ecc.ID) error { } // getParams returns the parameters for the MiMC encryption function for the given curve. -// It also returns the degree of the s-Box -func getParams(curve ecc.ID) ([]big.Int, int, error) { +// It also returns the degree of the s-Box. +func getParams(curve ecc.ID) ([]frontend.Variable, int, error) { switch curve { case ecc.BN254: - return bn254.GetConstants(), 5, nil + c := bn254.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbn254.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS12_381: - return bls12381.GetConstants(), 5, nil + c := bls12381.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls12381.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS12_377: - return bls12377.GetConstants(), 17, nil + c := bls12377.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls12377.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 17, nil case ecc.BLS24_315: - return bls24315.GetConstants(), 5, nil + c := bls24315.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls24315.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS24_317: - return bls24317.GetConstants(), 7, nil + c := bls24317.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls24317.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 7, nil case ecc.BW6_633: - return bw6633.GetConstants(), 5, nil + c := bw6633.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbw6633.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BW6_761: - return bw6761.GetConstants(), 5, nil + c := bw6761.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbw6761.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil default: return nil, -1, fmt.Errorf("unsupported curve ID: %s", curve) } @@ -154,7 +202,7 @@ func (n gateNamer) round(i int) gkr.GateName { return gkr.GateName(fmt.Sprintf("%s%d", string(n), i)) } -func addPow5(key *big.Int) gkr.GateFunction { +func addPow5(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") @@ -166,7 +214,7 @@ func addPow5(key *big.Int) gkr.GateFunction { } // addPow5Add: (in[0]+in[1]+key)⁵ + 2*in[0] + in[2] -func addPow5Add(key *big.Int) gkr.GateFunction { +func addPow5Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") @@ -179,7 +227,7 @@ func addPow5Add(key *big.Int) gkr.GateFunction { } } -func addPow7(key *big.Int) gkr.GateFunction { +func addPow7(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") @@ -191,7 +239,7 @@ func addPow7(key *big.Int) gkr.GateFunction { } // addPow7Add: (in[0]+in[1]+key)⁷ + 2*in[0] + in[2] -func addPow7Add(key *big.Int) gkr.GateFunction { +func addPow7Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") @@ -203,22 +251,17 @@ func addPow7Add(key *big.Int) gkr.GateFunction { } // addPow17: (in[0]+in[1]+key)¹⁷ -func addPow17(key *big.Int) gkr.GateFunction { +func addPow17(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") } - s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Mul(t, s) // s¹⁶ × s + return api.SumExp17(in[0], in[1], key) } } // addPow17Add: (in[0]+in[1]+key)¹⁷ + in[0] + in[2] -func addPow17Add(key *big.Int) gkr.GateFunction { +func addPow17Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") diff --git a/std/permutation/gkr-mimc/gkr-mimc_test.go b/std/permutation/gkr-mimc/gkr-mimc_test.go index 93143b1279..6cc1bde714 100644 --- a/std/permutation/gkr-mimc/gkr-mimc_test.go +++ b/std/permutation/gkr-mimc/gkr-mimc_test.go @@ -11,11 +11,11 @@ import ( "github.com/stretchr/testify/require" ) -type hashTreeCircuit struct { +type merkleTreeCircuit struct { Leaves []frontend.Variable } -func (c hashTreeCircuit) Define(api frontend.API) error { +func (c merkleTreeCircuit) Define(api frontend.API) error { if len(c.Leaves) == 0 { return errors.New("no hashing to do") } @@ -43,7 +43,7 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func BenchmarkGkrPermutations(b *testing.B) { +func BenchmarkMerkleTree(b *testing.B) { circuit, assignment := hashTreeCircuits(50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) @@ -56,15 +56,15 @@ func BenchmarkGkrPermutations(b *testing.B) { require.NoError(b, err) } -func hashTreeCircuits(n int) (circuit, assignment hashTreeCircuit) { +func hashTreeCircuits(n int) (circuit, assignment merkleTreeCircuit) { leaves := make([]frontend.Variable, n) for i := range n { leaves[i] = i } - return hashTreeCircuit{ + return merkleTreeCircuit{ Leaves: make([]frontend.Variable, len(leaves)), - }, hashTreeCircuit{ + }, merkleTreeCircuit{ Leaves: leaves, } }