Skip to content

Commit

Permalink
feat: implemented gates and single entry point (#7)
Browse files Browse the repository at this point in the history
* Added Lean template. Implemented AssertIsDifferent

* Added support for chained ops

* Added du, neg, mac

* Refactored Neg, MulAcc, Div, DivUnchecked, Inverse

* Added Boolean gates and refactoring

* Added lookup and select ops. Added support for callback format

* Added cmp and le

* Moved Lean files to other repo. Added function to create Proj array

* Added ConstantValue implementation

* Prototype single entry point

* Added support for vector arguments

* Added ToBinary and FromBinary. Improved panic messages

* Fixed implicit parameter for from_binary. Added optional gate explicit type.

* Refactored tests and added second example

* Added field in code extractor

* Fixed dependency file

* Replaced Bit with F

* Improved interface to initialise circuit

* Added readme

* Streamlined user API and using NewSchema to identify fields to initialise

* Added support for nested arrays as circuit fields
  • Loading branch information
Eagle941 authored Jun 26, 2023
1 parent 54cbd48 commit f08f001
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 73 deletions.
110 changes: 78 additions & 32 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package extractor

import (
"fmt"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend"
"gnark-extractor/abstractor"
"math/big"
"reflect"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend"
)

type Operand interface {
Expand All @@ -24,15 +27,19 @@ type Gate struct {

func (_ Gate) isOperand() {}

// Input is used to save the position of the argument in the
// list of arguments of the circuit function.
type Input struct {
Index int
}

func (_ Input) isOperand() {}

// Index is the index to be accessed in the array
// Operand[Index]
type Proj struct {
Index int
Operand Operand
Index int
}

func (_ Proj) isOperand() {}
Expand All @@ -46,18 +53,25 @@ type OpKind int
const (
OpAdd OpKind = iota
OpMulAcc
OpMul
OpNeg
OpNegative
OpSub
OpMul
OpDiv
OpDivUnchecked
OpInverse
OpToBinary
OpFromBinary
OpXor
OpAnd
OpOr
OpAnd
OpSelect
OpLookup
OpIsZero
OpCmp
OpAssertEq
OpAssertNotEq
OpAssertIsBool
OpAssertLessEqual
)

func (_ OpKind) isOp() {}
Expand Down Expand Up @@ -91,21 +105,37 @@ func (g *ExGadget) Call(args ...frontend.Variable) []frontend.Variable {
outs[0] = gate
} else {
for i := range g.Outputs {
outs[i] = Proj{i, gate}
outs[i] = Proj{gate, i}
}
}
return outs
}

type ExArgType struct {
Size int
Type *ExArgType
}

type ExArg struct {
Name string
Kind reflect.Kind
Type ExArgType
}

type ExCircuit struct {
Inputs []string
Inputs []ExArg
Gadgets []ExGadget
Code []App
}

type CodeExtractor struct {
Code []App
Gadgets []ExGadget
Field ecc.ID
}

func operandFromArray(arg []frontend.Variable) Operand {
return arg[0].(Proj).Operand
}

func sanitizeVars(args ...frontend.Variable) []Operand {
Expand All @@ -119,8 +149,10 @@ func sanitizeVars(args ...frontend.Variable) []Operand {
case big.Int:
casted := arg.(big.Int)
ops[i] = Const{&casted}
case []frontend.Variable:
ops[i] = operandFromArray(arg.([]frontend.Variable))
default:
fmt.Printf("invalid argument %#v\n", arg)
fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg)
panic("invalid argument")
}
}
Expand All @@ -141,7 +173,7 @@ func (ce *CodeExtractor) MulAcc(a, b, c frontend.Variable) frontend.Variable {
}

func (ce *CodeExtractor) Neg(i1 frontend.Variable) frontend.Variable {
return ce.AddApp(OpNeg, i1, -1)
return ce.AddApp(OpNegative, i1)
}

func (ce *CodeExtractor) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
Expand All @@ -165,12 +197,20 @@ func (ce *CodeExtractor) Inverse(i1 frontend.Variable) frontend.Variable {
}

func (ce *CodeExtractor) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable {
//TODO implement me
panic("implement me")
nbBits := ce.Field.ScalarField().BitLen()
if len(n) == 1 {
nbBits = n[0]
if nbBits < 0 {
panic("Number of bits in ToBinary must be > 0")
}
}
gate := ce.AddApp(OpToBinary, i1, nbBits)
return []frontend.Variable{gate}
}

func (ce *CodeExtractor) FromBinary(b ...frontend.Variable) frontend.Variable {
return ce.AddApp(OpFromBinary, b...)
// Packs in little-endian
return ce.AddApp(OpFromBinary, append([]frontend.Variable{}, b...)...)
}

func (ce *CodeExtractor) Xor(a, b frontend.Variable) frontend.Variable {
Expand All @@ -186,62 +226,68 @@ func (ce *CodeExtractor) And(a, b frontend.Variable) frontend.Variable {
}

func (ce *CodeExtractor) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable {
//TODO implement me
panic("implement me")
return ce.AddApp(OpSelect, b, i1, i2)
}

func (ce *CodeExtractor) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable {
//TODO implement me
panic("implement me")
return ce.AddApp(OpLookup, b0, b1, i0, i1, i2, i3)
}

func (ce *CodeExtractor) IsZero(i1 frontend.Variable) frontend.Variable {
//TODO implement me
panic("implement me")
return ce.AddApp(OpIsZero, i1)
}

func (ce *CodeExtractor) Cmp(i1, i2 frontend.Variable) frontend.Variable {
//TODO implement me
panic("implement me")
return ce.AddApp(OpCmp, i1, i2)
}

func (ce *CodeExtractor) AssertIsEqual(i1, i2 frontend.Variable) {
ce.AddApp(OpAssertEq, i1, i2)
}

func (ce *CodeExtractor) AssertIsDifferent(i1, i2 frontend.Variable) {
//TODO implement me
panic("implement me")
ce.AddApp(OpAssertNotEq, i1, i2)
}

func (ce *CodeExtractor) AssertIsBoolean(i1 frontend.Variable) {
//TODO implement me
panic("implement me")
ce.AddApp(OpAssertIsBool, i1)
}

func (ce *CodeExtractor) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) {
//TODO implement me
panic("implement me")
ce.AddApp(OpAssertLessEqual, v, bound)
}

func (ce *CodeExtractor) Println(a ...frontend.Variable) {
//TODO implement me
panic("implement me")
}

func (ce *CodeExtractor) Compiler() frontend.Compiler {
//TODO implement me
panic("implement me")
}

func (ce *CodeExtractor) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) {
//TODO implement me
panic("implement me")
}

func (ce *CodeExtractor) ConstantValue(v frontend.Variable) (*big.Int, bool) {
//TODO implement me
panic("implement me")
switch v.(type) {
case Const:
return v.(Const).Value, true
case Proj:
switch v.(Proj).Operand.(type) {
case Const:
return v.(Proj).Operand.(Const).Value, true
default:
return nil, false
}
case int64:
return big.NewInt(v.(int64)), true
case big.Int:
casted := v.(big.Int)
return &casted, true
default:
return nil, false
}
}

func (ce *CodeExtractor) DefineGadget(name string, arity int, constructor func(api abstractor.API, args ...frontend.Variable) []frontend.Variable) abstractor.Gadget {
Expand Down
102 changes: 85 additions & 17 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,80 @@ package extractor

import (
"fmt"
"github.com/consensys/gnark/frontend"
"gnark-extractor/abstractor"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
)

func defineExample(api abstractor.API) {
type CircuitWithParameter struct {
In frontend.Variable `gnark:",public"`
Param int
}

func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error {
api.AssertIsEqual(circuit.Param, circuit.In)

return nil
}

func (circuit CircuitWithParameter) Define(api frontend.API) error {
return abstractor.Concretize(api, &circuit)
}

func TestCircuitWithParameter(t *testing.T) {
assignment := CircuitWithParameter{}
assignment.Param = 20
err := CircuitToLean(&assignment, ecc.BW6_756)
if err != nil {
fmt.Println("CircuitToLean error!")
fmt.Println(err.Error())
}
}

type MerkleRecover struct {
Root frontend.Variable `gnark:",public"`
Element frontend.Variable `gnark:",public"`
Path [20]frontend.Variable `gnark:",secret"`
Proof [20]frontend.Variable `gnark:",secret"`
}

func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error {
hash := api.DefineGadget("hash", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
return []frontend.Variable{api.Mul(args[0], args[1])}
})

current := circuit.Element
for i := 0; i < len(circuit.Path); i++ {
leftHash := hash.Call(current, circuit.Proof[i])[0]
rightHash := hash.Call(circuit.Proof[i], current)[0]
current = api.Select(circuit.Path[i], rightHash, leftHash)
}
api.AssertIsEqual(current, circuit.Root)

return nil
}

func (circuit MerkleRecover) Define(api frontend.API) error {
return abstractor.Concretize(api, &circuit)
}

func TestMerkleRecover(t *testing.T) {
assignment := MerkleRecover{}
err := CircuitToLean(&assignment, ecc.BW6_756)
if err != nil {
fmt.Println("CircuitToLean error!")
fmt.Println(err.Error())
}
}

type TwoGadgets struct {
In_1 frontend.Variable
In_2 frontend.Variable
}

func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error {
my_widget := api.DefineGadget("my_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
sum := api.Add(args[0], args[1])
mul := api.Mul(args[0], args[1])
Expand All @@ -20,23 +88,23 @@ func defineExample(api abstractor.API) {
r := api.Mul(mul, snd[0])
return []frontend.Variable{r}
})
i1 := Input{0}
i2 := Input{1}
sum := api.Add(i1, i2)
prod := api.Mul(i1, i2)

sum := api.Add(circuit.In_1, circuit.In_2)
prod := api.Mul(circuit.In_1, circuit.In_2)
my_snd_widget.Call(sum, prod)

return nil
}

func TestExtractor(t *testing.T) {
api := CodeExtractor{
Code: []App{},
Gadgets: []ExGadget{},
}
defineExample(&api)
circuit := ExCircuit{
Inputs: []string{"i1", "i2"},
Gadgets: api.Gadgets,
Code: api.Code,
func (circuit TwoGadgets) Define(api frontend.API) error {
return abstractor.Concretize(api, &circuit)
}

func TestTwoGadgets(t *testing.T) {
assignment := TwoGadgets{}
err := CircuitToLean(&assignment, ecc.BW6_756)
if err != nil {
fmt.Println("CircuitToLean error!")
fmt.Println(err.Error())
}
fmt.Println(ExportCircuit(circuit))
}
Loading

0 comments on commit f08f001

Please sign in to comment.