Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: check SR from batch proofs matches the SR from the DS #14

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package aggregator

import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/json"
Expand Down Expand Up @@ -775,19 +774,6 @@ func (a *Aggregator) buildFinalProof(ctx context.Context, prover proverInterface
finalProof.Public.NewLocalExitRoot = finalBatch.LocalExitRoot.Bytes()
}

// Sanity Check: state root from the proof must match the one from the final batch
finalBatch, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
if err != nil {
return nil, fmt.Errorf("failed to retrieve batch with number [%d]", proof.BatchNumberFinal)
}

if !bytes.Equal(finalProof.Public.NewStateRoot, finalBatch.StateRoot.Bytes()) {
for {
log.Errorf("State root from the proof [%#x] does not match the one from the batch [%#x]. HALTED", finalProof.Public.NewStateRoot, finalBatch.StateRoot.Bytes())
time.Sleep(a.cfg.RetryTime.Duration)
}
}

return finalProof, nil
}

Expand Down Expand Up @@ -1088,7 +1074,7 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover proverInterf
log.Infof("Proof ID for aggregated proof: %v", *proof.ProofID)
log = log.WithFields("proofId", *proof.ProofID)

recursiveProof, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
recursiveProof, _, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get aggregated proof from prover, %w", err)
log.Error(FirstToUpper(err.Error()))
Expand Down Expand Up @@ -1334,7 +1320,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInt

log = log.WithFields("proofId", *proof.ProofID)

resGetProof, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
resGetProof, stateRoot, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get proof from prover, %w", err)
log.Error(FirstToUpper(err.Error()))
Expand All @@ -1343,6 +1329,12 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInt

log.Info("Batch proof generated")

// Sanity Check: state root from the proof must match the one from the batch
if stateRoot != batchToProve.StateRoot {
log.Fatalf("State root from the proof does not match the expected for batch %d: Proof = [%s] Expected = [%s]",
batchToProve.BatchNumber, stateRoot.String(), batchToProve.StateRoot.String())
}

proof.Proof = resGetProof

// NOTE(pg): the defer func is useless from now on, use a different variable
Expand Down
2 changes: 1 addition & 1 deletion aggregator/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type proverInterface interface {
BatchProof(input *prover.StatelessInputProver) (*string, error)
AggregatedProof(inputProof1, inputProof2 string) (*string, error)
FinalProof(inputProof string, aggregatorAddr string) (*string, error)
WaitRecursiveProof(ctx context.Context, proofID string) (string, error)
WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error)
WaitFinalProof(ctx context.Context, proofID string) (*prover.FinalProof, error)
}

Expand Down
71 changes: 68 additions & 3 deletions aggregator/prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@ package prover

import (
"context"
"encoding/json"
"errors"
"fmt"
"math/big"
"net"
"strconv"
"time"

"github.com/0xPolygon/cdk/config/types"
"github.com/0xPolygon/cdk/log"
"github.com/ethereum/go-ethereum/common"
"github.com/iden3/go-iden3-crypto/poseidon"
)

const (
stateRootStartIndex = 19
stateRootFinalIndex = stateRootStartIndex + 8
)

var (
Expand Down Expand Up @@ -240,13 +250,17 @@ func (p *Prover) CancelProofRequest(proofID string) error {

// WaitRecursiveProof waits for a recursive proof to be generated by the prover
// and returns it.
func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, error) {
func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error) {
res, err := p.waitProof(ctx, proofID)
if err != nil {
return "", err
return "", common.Hash{}, err
}
stateRoot, err := GetStateRootFromProof(res.Proof.(*GetProofResponse_RecursiveProof).RecursiveProof)
if err != nil {
return "", common.Hash{}, err
}
resProof := res.Proof.(*GetProofResponse_RecursiveProof)
return resProof.RecursiveProof, nil
return resProof.RecursiveProof, stateRoot, nil
}

// WaitFinalProof waits for the final proof to be generated by the prover and
Expand Down Expand Up @@ -325,3 +339,54 @@ func (p *Prover) call(req *AggregatorMessage) (*ProverMessage, error) {
}
return res, nil
}

// GetStateRootFromProof returns the state root from the proof.
func GetStateRootFromProof(proof string) (common.Hash, error) {
type Publics struct {
Publics []string `mapstructure:"publics"`
}

var publics Publics
err := json.Unmarshal([]byte(proof), &publics)
if err != nil {
log.Errorf("Error unmarshalling proof: %v", err)
return common.Hash{}, err
}

var (
v [8]uint64
j = 0
)

for i := stateRootStartIndex; i < stateRootFinalIndex; i++ {
u64, err := strconv.ParseInt(publics.Publics[i], 10, 64)
if err != nil {
log.Fatal(err)
Comment on lines +363 to +364
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick

Suggested change
if err != nil {
log.Fatal(err)
var (
v [8]uint64
j = 0
)

}
v[j] = uint64(u64)
j++
}
bigSR := fea2scalar(v[:])
hexSR := fmt.Sprintf("%x", bigSR)
if len(hexSR)%2 != 0 {
hexSR = "0" + hexSR
}

return common.HexToHash(hexSR), nil
}

// fea2scalar converts array of uint64 values into one *big.Int.
func fea2scalar(v []uint64) *big.Int {
if len(v) != poseidon.NROUNDSF {
return big.NewInt(0)
}
res := new(big.Int).SetUint64(v[0])
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[1]), 32)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[2]), 64)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[3]), 96)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[4]), 128)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[5]), 160)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[6]), 192)) //nolint:gomnd
res.Add(res, new(big.Int).Lsh(new(big.Int).SetUint64(v[7]), 224)) //nolint:gomnd
return res
}
61 changes: 61 additions & 0 deletions aggregator/prover/prover_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package prover_test

import (
"fmt"
"log"
"os"
"testing"

"github.com/0xPolygon/cdk/aggregator/prover"
"github.com/stretchr/testify/require"
)

const (
dir = "../../test/vectors/proofs"
)

type TestStateRoot struct {
Publics []string `mapstructure:"publics"`
}

func TestCalculateStateRoots(t *testing.T) {
var expectedStateRoots = map[string]string{
"1871.json": "0x0ed594d8bc0bb38f3190ff25fb1e5b4fe1baf0e2e0c1d7bf3307f07a55d3a60f",
"1872.json": "0xb6aac97ebb0eb2d4a3bdd40cfe49b6a22d42fe7deff1a8fae182a9c11cc8a7b1",
"1873.json": "0x6f88be87a2ad2928a655bbd38c6f1b59ca8c0f53fd8e9e9d5806e90783df701f",
"1874.json": "0x6f88be87a2ad2928a655bbd38c6f1b59ca8c0f53fd8e9e9d5806e90783df701f",
"1875.json": "0xf4a439c5642a182d9e27c8ab82c64b44418ba5fa04c175a013bed452c19908c9"}

// Read all files in the directory
files, err := os.ReadDir(dir)
if err != nil {
log.Fatal(err)
}

for _, file := range files {
if file.IsDir() {
continue
}

// Read the file
data, err := os.ReadFile(fmt.Sprintf("%s/%s", dir, file.Name()))
if err != nil {
log.Fatal(err)
}

// Get the state root from the batch proof
fileStateRoot, err := prover.GetStateRootFromProof(string(data))
if err != nil {
log.Fatal(err)
}

// Get the expected state root
expectedStateRoot, ok := expectedStateRoots[file.Name()]
if !ok {
log.Fatal("Expected state root not found")
}

// Compare the state roots
require.Equal(t, expectedStateRoot, fileStateRoot.String(), "State roots do not match")
}
}
1 change: 1 addition & 0 deletions test/vectors/proofs/1871.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1872.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1873.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1874.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/vectors/proofs/1875.json

Large diffs are not rendered by default.

Loading