Skip to content

Commit

Permalink
feat: check SR from batch proofs matches the SR from the DS (#14)
Browse files Browse the repository at this point in the history
* feat: check SR from batch proofs matches the SR from the DS
  • Loading branch information
ToniRamirezM authored Aug 8, 2024
1 parent d0aafd1 commit 624ecc3
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 20 deletions.
24 changes: 8 additions & 16 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package aggregator

import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/json"
Expand Down Expand Up @@ -775,19 +774,6 @@ func (a *Aggregator) buildFinalProof(ctx context.Context, prover proverInterface
finalProof.Public.NewLocalExitRoot = finalBatch.LocalExitRoot.Bytes()
}

// Sanity Check: state root from the proof must match the one from the final batch
finalBatch, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
if err != nil {
return nil, fmt.Errorf("failed to retrieve batch with number [%d]", proof.BatchNumberFinal)
}

if !bytes.Equal(finalProof.Public.NewStateRoot, finalBatch.StateRoot.Bytes()) {
for {
log.Errorf("State root from the proof [%#x] does not match the one from the batch [%#x]. HALTED", finalProof.Public.NewStateRoot, finalBatch.StateRoot.Bytes())
time.Sleep(a.cfg.RetryTime.Duration)
}
}

return finalProof, nil
}

Expand Down Expand Up @@ -1088,7 +1074,7 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover proverInterf
log.Infof("Proof ID for aggregated proof: %v", *proof.ProofID)
log = log.WithFields("proofId", *proof.ProofID)

recursiveProof, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
recursiveProof, _, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get aggregated proof from prover, %w", err)
log.Error(FirstToUpper(err.Error()))
Expand Down Expand Up @@ -1334,7 +1320,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInt

log = log.WithFields("proofId", *proof.ProofID)

resGetProof, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
resGetProof, stateRoot, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get proof from prover, %w", err)
log.Error(FirstToUpper(err.Error()))
Expand All @@ -1343,6 +1329,12 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInt

log.Info("Batch proof generated")

// Sanity Check: state root from the proof must match the one from the batch
if stateRoot != batchToProve.StateRoot {
log.Fatalf("State root from the proof does not match the expected for batch %d: Proof = [%s] Expected = [%s]",
batchToProve.BatchNumber, stateRoot.String(), batchToProve.StateRoot.String())
}

proof.Proof = resGetProof

// NOTE(pg): the defer func is useless from now on, use a different variable
Expand Down
2 changes: 1 addition & 1 deletion aggregator/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type proverInterface interface {
BatchProof(input *prover.StatelessInputProver) (*string, error)
AggregatedProof(inputProof1, inputProof2 string) (*string, error)
FinalProof(inputProof string, aggregatorAddr string) (*string, error)
WaitRecursiveProof(ctx context.Context, proofID string) (string, error)
WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error)
WaitFinalProof(ctx context.Context, proofID string) (*prover.FinalProof, error)
}

Expand Down
71 changes: 68 additions & 3 deletions aggregator/prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@ package prover

import (
"context"
"encoding/json"
"errors"
"fmt"
"math/big"
"net"
"strconv"
"time"

"github.com/0xPolygon/cdk/config/types"
"github.com/0xPolygon/cdk/log"
"github.com/ethereum/go-ethereum/common"
"github.com/iden3/go-iden3-crypto/poseidon"
)

const (
stateRootStartIndex = 19
stateRootFinalIndex = stateRootStartIndex + 8
)

var (
Expand Down Expand Up @@ -240,13 +250,17 @@ func (p *Prover) CancelProofRequest(proofID string) error {

// WaitRecursiveProof waits for a recursive proof to be generated by the prover
// and returns it.
func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, error) {
func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error) {
res, err := p.waitProof(ctx, proofID)
if err != nil {
return "", err
return "", common.Hash{}, err
}
stateRoot, err := GetStateRootFromProof(res.Proof.(*GetProofResponse_RecursiveProof).RecursiveProof)
if err != nil {
return "", common.Hash{}, err
}
resProof := res.Proof.(*GetProofResponse_RecursiveProof)
return resProof.RecursiveProof, nil
return resProof.RecursiveProof, stateRoot, nil
}

// WaitFinalProof waits for the final proof to be generated by the prover and
Expand Down Expand Up @@ -325,3 +339,54 @@ func (p *Prover) call(req *AggregatorMessage) (*ProverMessage, error) {
}
return res, nil
}

// GetStateRootFromProof returns the state root from the proof.
func GetStateRootFromProof(proof string) (common.Hash, error) {
type Publics struct {
Publics []string `mapstructure:"publics"`
}

var publics Publics
err := json.Unmarshal([]byte(proof), &publics)
if err != nil {
log.Errorf("Error unmarshalling proof: %v", err)
return common.Hash{}, err
}

var (
v [8]uint64
j = 0
)

for i := stateRootStartIndex; i < stateRootFinalIndex; i++ {
u64, err := strconv.ParseInt(publics.Publics[i], 10, 64)
if err != nil {
log.Fatal(err)
}
v[j] = uint64(u64)
j++
}
bigSR := fea2scalar(v[:])
hexSR := fmt.Sprintf("%x", bigSR)
if len(hexSR)%2 != 0 {
hexSR = "0" + hexSR
}

return common.HexToHash(hexSR), nil
}

// fea2scalar converts array of uint64 values into one *big.Int.
func fea2scalar(v []uint64) *big.Int {
if len(v) != poseidon.NROUNDSF {
return big.NewInt(0)
}
res := new(big.Int).SetUint64(v[0])
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[1]), 32)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[2]), 64)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[3]), 96)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[4]), 128)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[5]), 160)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[6]), 192)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[7]), 224)) //nolint:gomnd
return res
}
61 changes: 61 additions & 0 deletions aggregator/prover/prover_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package prover_test

import (
"fmt"
"log"
"os"
"testing"

"github.com/0xPolygon/cdk/aggregator/prover"
"github.com/stretchr/testify/require"
)

const (
dir = "../../test/vectors/proofs"
)

type TestStateRoot struct {
Publics []string `mapstructure:"publics"`
}

func TestCalculateStateRoots(t *testing.T) {
var expectedStateRoots = map[string]string{
"1871.json": "0x0ed594d8bc0bb38f3190ff25fb1e5b4fe1baf0e2e0c1d7bf3307f07a55d3a60f",
"1872.json": "0xb6aac97ebb0eb2d4a3bdd40cfe49b6a22d42fe7deff1a8fae182a9c11cc8a7b1",
"1873.json": "0x6f88be87a2ad2928a655bbd38c6f1b59ca8c0f53fd8e9e9d5806e90783df701f",
"1874.json": "0x6f88be87a2ad2928a655bbd38c6f1b59ca8c0f53fd8e9e9d5806e90783df701f",
"1875.json": "0xf4a439c5642a182d9e27c8ab82c64b44418ba5fa04c175a013bed452c19908c9"}

// Read all files in the directory
files, err := os.ReadDir(dir)
if err != nil {
log.Fatal(err)
}

for _, file := range files {
if file.IsDir() {
continue
}

// Read the file
data, err := os.ReadFile(fmt.Sprintf("%s/%s", dir, file.Name()))
if err != nil {
log.Fatal(err)
}

// Get the state root from the batch proof
fileStateRoot, err := prover.GetStateRootFromProof(string(data))
if err != nil {
log.Fatal(err)
}

// Get the expected state root
expectedStateRoot, ok := expectedStateRoots[file.Name()]
if !ok {
log.Fatal("Expected state root not found")
}

// Compare the state roots
require.Equal(t, expectedStateRoot, fileStateRoot.String(), "State roots do not match")
}
}
1 change: 1 addition & 0 deletions test/vectors/proofs/1871.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1872.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1873.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1874.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1875.json

Large diffs are not rendered by default.

0 comments on commit 624ecc3

Please sign in to comment.