Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/generator/backend/template/gkr/gate_testing.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

// IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f
func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool {
var api gateAPI
fWrapped := api.convertFunc(f)

// fix all variables except the i-th one at random points
Expand Down Expand Up @@ -130,6 +131,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom
// FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails.
// Failure could be due to the degree being higher than max or the function not being a polynomial at all.
func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) {
var api gateAPI
fFr := api.convertFunc(f)
bound := uint64(max) + 1
for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 {
Expand All @@ -139,11 +141,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) {
}
return len(p) - 1, nil
}
api.freeElements() // not strictly necessary as few iterations are expected.
}
return -1, fmt.Errorf("could not find a degree: tried up to %d", max)
}

func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error {
var api gateAPI
fFr := api.convertFunc(f)
if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil {
return fmt.Errorf("detected a higher degree than %d", claimedDegree)
Expand All @@ -157,6 +161,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error

// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point.
func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool {
var api gateAPI
x := make({{.FieldPackageName}}.Vector, nbIn)
x.MustSetRandom()
fFr := api.convertFunc(f)
Expand Down
111 changes: 73 additions & 38 deletions internal/generator/backend/template/gkr/gkr.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []{{ .ElementType
for i, uniqueI := range injectionLeftInv { // map from all to unique
inputEvaluations[i] = &uniqueInputEvaluations[uniqueI]
}

gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*{{ .ElementType }}))
var api gateAPI
gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*{{ .ElementType }}))
}

evaluation.Mul(&evaluation, &gateEvaluation)
Expand Down Expand Up @@ -230,7 +230,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {
gJ := make([]{{ .ElementType }}, degGJ)
var mu sync.Mutex
computeAll := func(start, end int) { // compute method to allow parallelization across instances
var step {{ .ElementType }}
var (
step {{ .ElementType }}
api gateAPI
)

res := make([]{{ .ElementType }}, degGJ)

Expand Down Expand Up @@ -260,10 +263,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {
for i := range gateInput {
gateInput[i] = &mlEvals[eIndex+1+i]
}
summand := wire.Gate.Evaluate(api, gateInput...).(*{{ .ElementType }})
summand := wire.Gate.Evaluate(&api, gateInput...).(*{{ .ElementType }})
summand.Mul(summand, &mlEvals[eIndex])
res[d].Add(&res[d], summand) // collect contributions into the sum from start to end
eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml)
api.freeElements()
}
}
mu.Lock()
Expand Down Expand Up @@ -663,6 +667,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment {
}
}

var api gateAPI
ins := make([]{{ .ElementType }}, maxNbIns)
for i := range nbInstances {
for wI, w := range wires {
Expand Down Expand Up @@ -720,52 +725,69 @@ func frToBigInts(dst []*big.Int, src []{{ .ElementType }}) {


// gateAPI implements gkr.GateAPI.
type gateAPI struct{}

var api gateAPI
// It uses a synchronous memory pool underneath to minimize heap allocations.
type gateAPI struct {
allocated []*{{ .ElementType }}
nbUsed int
}

func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
var res {{ .ElementType }} // TODO Heap allocated. Keep an eye on perf
res.Add(cast(i1), cast(i2))
func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
res := api.newElement()
res.Add(api.cast(i1), api.cast(i2))
for _, v := range in {
res.Add(&res, cast(v))
res.Add(res, api.cast(v))
}
return &res
return res
}

func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable {
var prod {{ .ElementType }}
prod.Mul(cast(b), cast(c))
res := cast(a)
res.Add(res, &prod)
return &res
func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable {
prod := api.newElement()
prod.Mul(api.cast(b), api.cast(c))
res := api.cast(a)
res.Add(res, prod)
return res
}

func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable {
var res {{ .ElementType }}
res.Neg(cast(i1))
return &res
func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable {
res := api.newElement()
res.Neg(api.cast(i1))
return res
}

func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
var res {{ .ElementType }}
res.Sub(cast(i1), cast(i2))
func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
res := api.newElement()
res.Sub(api.cast(i1), api.cast(i2))
for _, v := range in {
res.Sub(&res, cast(v))
res.Sub(res, api.cast(v))
}
return &res
return res
}

func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
var res {{ .ElementType }}
res.Mul(cast(i1), cast(i2))
func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
res := api.newElement()
res.Mul(api.cast(i1), api.cast(i2))
for _, v := range in {
res.Mul(&res, cast(v))
res.Mul(res, api.cast(v))
}
return &res
return res
}

func (api *gateAPI) SumExp17(a,b,c frontend.Variable) frontend.Variable {
var x {{ .ElementType }}

x.Add(api.cast(a), api.cast(b))
x.Add(&x, api.cast(c))

res := api.newElement()

res.Mul(&x, &x) // x²
res.Mul(res, res) // x⁴
res.Mul(res, res) // x⁸
res.Mul(res, res) // x¹⁶
return res.Mul(res, &x) // x¹⁷
}

func (gateAPI) Println(a ...frontend.Variable) {
func (api *gateAPI) Println(a ...frontend.Variable) {
toPrint := make([]any, len(a))
var x {{ .ElementType }}

Expand All @@ -783,30 +805,43 @@ func (gateAPI) Println(a ...frontend.Variable) {
fmt.Println(toPrint...)
}

func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} {
func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} {
inVar := make([]frontend.Variable, len(in))
for i := range in {
inVar[i] = &in[i]
}
return f(api, inVar...).(*{{ .ElementType }})
}

// Put all elements back in the pool.
func (api *gateAPI) freeElements() {
api.nbUsed = 0
}

func (api *gateAPI) newElement() *{{ .ElementType }} {
api.nbUsed++
if api.nbUsed >= len(api.allocated) {
api.allocated = append(api.allocated, new({{ .ElementType }}))
}
return api.allocated[api.nbUsed-1]
}

type gateFunctionFr func(...{{ .ElementType }}) *{{ .ElementType }}

// convertFunc turns f into a function that accepts and returns {{ .ElementType }}.
func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr {
func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr {
return func(in ...{{ .ElementType }}) *{{ .ElementType }} {
return api.evaluate(f, in...)
}
}

func cast(v frontend.Variable) *{{ .ElementType }} {
func (api *gateAPI) cast(v frontend.Variable) *{{ .ElementType }} {
if x, ok := v.(*{{ .ElementType }}); ok { // fast path, no extra heap allocation
return x
}
var x {{ .ElementType }}
x := api.newElement()
if _, err := x.SetInterface(v); err != nil {
panic(err)
}
return &x
return x
}
6 changes: 4 additions & 2 deletions internal/generator/backend/template/gkr/solver_hints.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func SolveHint(data *SolvingData) hint.Hint {

gateIns := make([]frontend.Variable, data.maxNbIn)
outsI := 0
insI := 1 // skip the first input, which is the instance index
insI := 1 // skip the first input, which is the instance index
var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations.
for wI := range data.circuit {
w := &data.circuit[wI]
if w.IsInput() { // read from provided input
Expand All @@ -110,7 +111,8 @@ func SolveHint(data *SolvingData) hint.Hint {
gateIns[i] = &data.assignment[inWI][instanceI]
}

data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element))
data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element))
api.freeElements()
}
if w.IsOutput() {
data.assignment[wI][instanceI].BigInt(outs[outsI])
Expand Down
5 changes: 5 additions & 0 deletions internal/gkr/bls12-377/gate_testing.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading